]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - contrib/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp
Merge ^/head r278499 through r278755.
[FreeBSD/FreeBSD.git] / contrib / llvm / tools / clang / lib / CodeGen / CodeGenPGO.cpp
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/ProfileData/InstrProfReader.h"
22 #include "llvm/Support/Endian.h"
23 #include "llvm/Support/FileSystem.h"
24 #include "llvm/Support/MD5.h"
25
26 using namespace clang;
27 using namespace CodeGen;
28
29 void CodeGenPGO::setFuncName(StringRef Name,
30                              llvm::GlobalValue::LinkageTypes Linkage) {
31   StringRef RawFuncName = Name;
32
33   // Function names may be prefixed with a binary '1' to indicate
34   // that the backend should not modify the symbols due to any platform
35   // naming convention. Do not include that '1' in the PGO profile name.
36   if (RawFuncName[0] == '\1')
37     RawFuncName = RawFuncName.substr(1);
38
39   FuncName = RawFuncName;
40   if (llvm::GlobalValue::isLocalLinkage(Linkage)) {
41     // For local symbols, prepend the main file name to distinguish them.
42     // Do not include the full path in the file name since there's no guarantee
43     // that it will stay the same, e.g., if the files are checked out from
44     // version control in different locations.
45     if (CGM.getCodeGenOpts().MainFileName.empty())
46       FuncName = FuncName.insert(0, "<unknown>:");
47     else
48       FuncName = FuncName.insert(0, CGM.getCodeGenOpts().MainFileName + ":");
49   }
50
51   // If we're generating a profile, create a variable for the name.
52   if (CGM.getCodeGenOpts().ProfileInstrGenerate)
53     createFuncNameVar(Linkage);
54 }
55
56 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
57   setFuncName(Fn->getName(), Fn->getLinkage());
58 }
59
60 void CodeGenPGO::createFuncNameVar(llvm::GlobalValue::LinkageTypes Linkage) {
61   // Usually, we want to match the function's linkage, but
62   // available_externally and extern_weak both have the wrong semantics.
63   if (Linkage == llvm::GlobalValue::ExternalWeakLinkage)
64     Linkage = llvm::GlobalValue::LinkOnceAnyLinkage;
65   else if (Linkage == llvm::GlobalValue::AvailableExternallyLinkage)
66     Linkage = llvm::GlobalValue::LinkOnceODRLinkage;
67
68   auto *Value =
69       llvm::ConstantDataArray::getString(CGM.getLLVMContext(), FuncName, false);
70   FuncNameVar =
71       new llvm::GlobalVariable(CGM.getModule(), Value->getType(), true, Linkage,
72                                Value, "__llvm_profile_name_" + FuncName);
73
74   // Hide the symbol so that we correctly get a copy for each executable.
75   if (!llvm::GlobalValue::isLocalLinkage(FuncNameVar->getLinkage()))
76     FuncNameVar->setVisibility(llvm::GlobalValue::HiddenVisibility);
77 }
78
79 namespace {
80 /// \brief Stable hasher for PGO region counters.
81 ///
82 /// PGOHash produces a stable hash of a given function's control flow.
83 ///
84 /// Changing the output of this hash will invalidate all previously generated
85 /// profiles -- i.e., don't do it.
86 ///
87 /// \note  When this hash does eventually change (years?), we still need to
88 /// support old hashes.  We'll need to pull in the version number from the
89 /// profile data format and use the matching hash function.
90 class PGOHash {
91   uint64_t Working;
92   unsigned Count;
93   llvm::MD5 MD5;
94
95   static const int NumBitsPerType = 6;
96   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
97   static const unsigned TooBig = 1u << NumBitsPerType;
98
99 public:
100   /// \brief Hash values for AST nodes.
101   ///
102   /// Distinct values for AST nodes that have region counters attached.
103   ///
104   /// These values must be stable.  All new members must be added at the end,
105   /// and no members should be removed.  Changing the enumeration value for an
106   /// AST node will affect the hash of every function that contains that node.
107   enum HashType : unsigned char {
108     None = 0,
109     LabelStmt = 1,
110     WhileStmt,
111     DoStmt,
112     ForStmt,
113     CXXForRangeStmt,
114     ObjCForCollectionStmt,
115     SwitchStmt,
116     CaseStmt,
117     DefaultStmt,
118     IfStmt,
119     CXXTryStmt,
120     CXXCatchStmt,
121     ConditionalOperator,
122     BinaryOperatorLAnd,
123     BinaryOperatorLOr,
124     BinaryConditionalOperator,
125
126     // Keep this last.  It's for the static assert that follows.
127     LastHashType
128   };
129   static_assert(LastHashType <= TooBig, "Too many types in HashType");
130
131   // TODO: When this format changes, take in a version number here, and use the
132   // old hash calculation for file formats that used the old hash.
133   PGOHash() : Working(0), Count(0) {}
134   void combine(HashType Type);
135   uint64_t finalize();
136 };
137 const int PGOHash::NumBitsPerType;
138 const unsigned PGOHash::NumTypesPerWord;
139 const unsigned PGOHash::TooBig;
140
141   /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
142   struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
143     /// The next counter value to assign.
144     unsigned NextCounter;
145     /// The function hash.
146     PGOHash Hash;
147     /// The map of statements to counters.
148     llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
149
150     MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
151         : NextCounter(0), CounterMap(CounterMap) {}
152
153     // Blocks and lambdas are handled as separate functions, so we need not
154     // traverse them in the parent context.
155     bool TraverseBlockExpr(BlockExpr *BE) { return true; }
156     bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
157     bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
158
159     bool VisitDecl(const Decl *D) {
160       switch (D->getKind()) {
161       default:
162         break;
163       case Decl::Function:
164       case Decl::CXXMethod:
165       case Decl::CXXConstructor:
166       case Decl::CXXDestructor:
167       case Decl::CXXConversion:
168       case Decl::ObjCMethod:
169       case Decl::Block:
170       case Decl::Captured:
171         CounterMap[D->getBody()] = NextCounter++;
172         break;
173       }
174       return true;
175     }
176
177     bool VisitStmt(const Stmt *S) {
178       auto Type = getHashType(S);
179       if (Type == PGOHash::None)
180         return true;
181
182       CounterMap[S] = NextCounter++;
183       Hash.combine(Type);
184       return true;
185     }
186     PGOHash::HashType getHashType(const Stmt *S) {
187       switch (S->getStmtClass()) {
188       default:
189         break;
190       case Stmt::LabelStmtClass:
191         return PGOHash::LabelStmt;
192       case Stmt::WhileStmtClass:
193         return PGOHash::WhileStmt;
194       case Stmt::DoStmtClass:
195         return PGOHash::DoStmt;
196       case Stmt::ForStmtClass:
197         return PGOHash::ForStmt;
198       case Stmt::CXXForRangeStmtClass:
199         return PGOHash::CXXForRangeStmt;
200       case Stmt::ObjCForCollectionStmtClass:
201         return PGOHash::ObjCForCollectionStmt;
202       case Stmt::SwitchStmtClass:
203         return PGOHash::SwitchStmt;
204       case Stmt::CaseStmtClass:
205         return PGOHash::CaseStmt;
206       case Stmt::DefaultStmtClass:
207         return PGOHash::DefaultStmt;
208       case Stmt::IfStmtClass:
209         return PGOHash::IfStmt;
210       case Stmt::CXXTryStmtClass:
211         return PGOHash::CXXTryStmt;
212       case Stmt::CXXCatchStmtClass:
213         return PGOHash::CXXCatchStmt;
214       case Stmt::ConditionalOperatorClass:
215         return PGOHash::ConditionalOperator;
216       case Stmt::BinaryConditionalOperatorClass:
217         return PGOHash::BinaryConditionalOperator;
218       case Stmt::BinaryOperatorClass: {
219         const BinaryOperator *BO = cast<BinaryOperator>(S);
220         if (BO->getOpcode() == BO_LAnd)
221           return PGOHash::BinaryOperatorLAnd;
222         if (BO->getOpcode() == BO_LOr)
223           return PGOHash::BinaryOperatorLOr;
224         break;
225       }
226       }
227       return PGOHash::None;
228     }
229   };
230
231   /// A StmtVisitor that propagates the raw counts through the AST and
232   /// records the count at statements where the value may change.
233   struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
234     /// PGO state.
235     CodeGenPGO &PGO;
236
237     /// A flag that is set when the current count should be recorded on the
238     /// next statement, such as at the exit of a loop.
239     bool RecordNextStmtCount;
240
241     /// The map of statements to count values.
242     llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
243
244     /// BreakContinueStack - Keep counts of breaks and continues inside loops.
245     struct BreakContinue {
246       uint64_t BreakCount;
247       uint64_t ContinueCount;
248       BreakContinue() : BreakCount(0), ContinueCount(0) {}
249     };
250     SmallVector<BreakContinue, 8> BreakContinueStack;
251
252     ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
253                         CodeGenPGO &PGO)
254         : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
255
256     void RecordStmtCount(const Stmt *S) {
257       if (RecordNextStmtCount) {
258         CountMap[S] = PGO.getCurrentRegionCount();
259         RecordNextStmtCount = false;
260       }
261     }
262
263     void VisitStmt(const Stmt *S) {
264       RecordStmtCount(S);
265       for (Stmt::const_child_range I = S->children(); I; ++I) {
266         if (*I)
267          this->Visit(*I);
268       }
269     }
270
271     void VisitFunctionDecl(const FunctionDecl *D) {
272       // Counter tracks entry to the function body.
273       RegionCounter Cnt(PGO, D->getBody());
274       Cnt.beginRegion();
275       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
276       Visit(D->getBody());
277     }
278
279     // Skip lambda expressions. We visit these as FunctionDecls when we're
280     // generating them and aren't interested in the body when generating a
281     // parent context.
282     void VisitLambdaExpr(const LambdaExpr *LE) {}
283
284     void VisitCapturedDecl(const CapturedDecl *D) {
285       // Counter tracks entry to the capture body.
286       RegionCounter Cnt(PGO, D->getBody());
287       Cnt.beginRegion();
288       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
289       Visit(D->getBody());
290     }
291
292     void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
293       // Counter tracks entry to the method body.
294       RegionCounter Cnt(PGO, D->getBody());
295       Cnt.beginRegion();
296       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
297       Visit(D->getBody());
298     }
299
300     void VisitBlockDecl(const BlockDecl *D) {
301       // Counter tracks entry to the block body.
302       RegionCounter Cnt(PGO, D->getBody());
303       Cnt.beginRegion();
304       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
305       Visit(D->getBody());
306     }
307
308     void VisitReturnStmt(const ReturnStmt *S) {
309       RecordStmtCount(S);
310       if (S->getRetValue())
311         Visit(S->getRetValue());
312       PGO.setCurrentRegionUnreachable();
313       RecordNextStmtCount = true;
314     }
315
316     void VisitGotoStmt(const GotoStmt *S) {
317       RecordStmtCount(S);
318       PGO.setCurrentRegionUnreachable();
319       RecordNextStmtCount = true;
320     }
321
322     void VisitLabelStmt(const LabelStmt *S) {
323       RecordNextStmtCount = false;
324       // Counter tracks the block following the label.
325       RegionCounter Cnt(PGO, S);
326       Cnt.beginRegion();
327       CountMap[S] = PGO.getCurrentRegionCount();
328       Visit(S->getSubStmt());
329     }
330
331     void VisitBreakStmt(const BreakStmt *S) {
332       RecordStmtCount(S);
333       assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
334       BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
335       PGO.setCurrentRegionUnreachable();
336       RecordNextStmtCount = true;
337     }
338
339     void VisitContinueStmt(const ContinueStmt *S) {
340       RecordStmtCount(S);
341       assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
342       BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
343       PGO.setCurrentRegionUnreachable();
344       RecordNextStmtCount = true;
345     }
346
347     void VisitWhileStmt(const WhileStmt *S) {
348       RecordStmtCount(S);
349       // Counter tracks the body of the loop.
350       RegionCounter Cnt(PGO, S);
351       BreakContinueStack.push_back(BreakContinue());
352       // Visit the body region first so the break/continue adjustments can be
353       // included when visiting the condition.
354       Cnt.beginRegion();
355       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
356       Visit(S->getBody());
357       Cnt.adjustForControlFlow();
358
359       // ...then go back and propagate counts through the condition. The count
360       // at the start of the condition is the sum of the incoming edges,
361       // the backedge from the end of the loop body, and the edges from
362       // continue statements.
363       BreakContinue BC = BreakContinueStack.pop_back_val();
364       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
365                                 Cnt.getAdjustedCount() + BC.ContinueCount);
366       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
367       Visit(S->getCond());
368       Cnt.adjustForControlFlow();
369       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
370       RecordNextStmtCount = true;
371     }
372
373     void VisitDoStmt(const DoStmt *S) {
374       RecordStmtCount(S);
375       // Counter tracks the body of the loop.
376       RegionCounter Cnt(PGO, S);
377       BreakContinueStack.push_back(BreakContinue());
378       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
379       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
380       Visit(S->getBody());
381       Cnt.adjustForControlFlow();
382
383       BreakContinue BC = BreakContinueStack.pop_back_val();
384       // The count at the start of the condition is equal to the count at the
385       // end of the body. The adjusted count does not include either the
386       // fall-through count coming into the loop or the continue count, so add
387       // both of those separately. This is coincidentally the same equation as
388       // with while loops but for different reasons.
389       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
390                                 Cnt.getAdjustedCount() + BC.ContinueCount);
391       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
392       Visit(S->getCond());
393       Cnt.adjustForControlFlow();
394       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
395       RecordNextStmtCount = true;
396     }
397
398     void VisitForStmt(const ForStmt *S) {
399       RecordStmtCount(S);
400       if (S->getInit())
401         Visit(S->getInit());
402       // Counter tracks the body of the loop.
403       RegionCounter Cnt(PGO, S);
404       BreakContinueStack.push_back(BreakContinue());
405       // Visit the body region first. (This is basically the same as a while
406       // loop; see further comments in VisitWhileStmt.)
407       Cnt.beginRegion();
408       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
409       Visit(S->getBody());
410       Cnt.adjustForControlFlow();
411
412       // The increment is essentially part of the body but it needs to include
413       // the count for all the continue statements.
414       if (S->getInc()) {
415         Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
416                                   BreakContinueStack.back().ContinueCount);
417         CountMap[S->getInc()] = PGO.getCurrentRegionCount();
418         Visit(S->getInc());
419         Cnt.adjustForControlFlow();
420       }
421
422       BreakContinue BC = BreakContinueStack.pop_back_val();
423
424       // ...then go back and propagate counts through the condition.
425       if (S->getCond()) {
426         Cnt.setCurrentRegionCount(Cnt.getParentCount() +
427                                   Cnt.getAdjustedCount() +
428                                   BC.ContinueCount);
429         CountMap[S->getCond()] = PGO.getCurrentRegionCount();
430         Visit(S->getCond());
431         Cnt.adjustForControlFlow();
432       }
433       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
434       RecordNextStmtCount = true;
435     }
436
437     void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
438       RecordStmtCount(S);
439       Visit(S->getRangeStmt());
440       Visit(S->getBeginEndStmt());
441       // Counter tracks the body of the loop.
442       RegionCounter Cnt(PGO, S);
443       BreakContinueStack.push_back(BreakContinue());
444       // Visit the body region first. (This is basically the same as a while
445       // loop; see further comments in VisitWhileStmt.)
446       Cnt.beginRegion();
447       CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
448       Visit(S->getLoopVarStmt());
449       Visit(S->getBody());
450       Cnt.adjustForControlFlow();
451
452       // The increment is essentially part of the body but it needs to include
453       // the count for all the continue statements.
454       Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
455                                 BreakContinueStack.back().ContinueCount);
456       CountMap[S->getInc()] = PGO.getCurrentRegionCount();
457       Visit(S->getInc());
458       Cnt.adjustForControlFlow();
459
460       BreakContinue BC = BreakContinueStack.pop_back_val();
461
462       // ...then go back and propagate counts through the condition.
463       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
464                                 Cnt.getAdjustedCount() +
465                                 BC.ContinueCount);
466       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
467       Visit(S->getCond());
468       Cnt.adjustForControlFlow();
469       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
470       RecordNextStmtCount = true;
471     }
472
473     void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
474       RecordStmtCount(S);
475       Visit(S->getElement());
476       // Counter tracks the body of the loop.
477       RegionCounter Cnt(PGO, S);
478       BreakContinueStack.push_back(BreakContinue());
479       Cnt.beginRegion();
480       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
481       Visit(S->getBody());
482       BreakContinue BC = BreakContinueStack.pop_back_val();
483       Cnt.adjustForControlFlow();
484       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
485       RecordNextStmtCount = true;
486     }
487
488     void VisitSwitchStmt(const SwitchStmt *S) {
489       RecordStmtCount(S);
490       Visit(S->getCond());
491       PGO.setCurrentRegionUnreachable();
492       BreakContinueStack.push_back(BreakContinue());
493       Visit(S->getBody());
494       // If the switch is inside a loop, add the continue counts.
495       BreakContinue BC = BreakContinueStack.pop_back_val();
496       if (!BreakContinueStack.empty())
497         BreakContinueStack.back().ContinueCount += BC.ContinueCount;
498       // Counter tracks the exit block of the switch.
499       RegionCounter ExitCnt(PGO, S);
500       ExitCnt.beginRegion();
501       RecordNextStmtCount = true;
502     }
503
504     void VisitCaseStmt(const CaseStmt *S) {
505       RecordNextStmtCount = false;
506       // Counter for this particular case. This counts only jumps from the
507       // switch header and does not include fallthrough from the case before
508       // this one.
509       RegionCounter Cnt(PGO, S);
510       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
511       CountMap[S] = Cnt.getCount();
512       RecordNextStmtCount = true;
513       Visit(S->getSubStmt());
514     }
515
516     void VisitDefaultStmt(const DefaultStmt *S) {
517       RecordNextStmtCount = false;
518       // Counter for this default case. This does not include fallthrough from
519       // the previous case.
520       RegionCounter Cnt(PGO, S);
521       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
522       CountMap[S] = Cnt.getCount();
523       RecordNextStmtCount = true;
524       Visit(S->getSubStmt());
525     }
526
527     void VisitIfStmt(const IfStmt *S) {
528       RecordStmtCount(S);
529       // Counter tracks the "then" part of an if statement. The count for
530       // the "else" part, if it exists, will be calculated from this counter.
531       RegionCounter Cnt(PGO, S);
532       Visit(S->getCond());
533
534       Cnt.beginRegion();
535       CountMap[S->getThen()] = PGO.getCurrentRegionCount();
536       Visit(S->getThen());
537       Cnt.adjustForControlFlow();
538
539       if (S->getElse()) {
540         Cnt.beginElseRegion();
541         CountMap[S->getElse()] = PGO.getCurrentRegionCount();
542         Visit(S->getElse());
543         Cnt.adjustForControlFlow();
544       }
545       Cnt.applyAdjustmentsToRegion(0);
546       RecordNextStmtCount = true;
547     }
548
549     void VisitCXXTryStmt(const CXXTryStmt *S) {
550       RecordStmtCount(S);
551       Visit(S->getTryBlock());
552       for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
553         Visit(S->getHandler(I));
554       // Counter tracks the continuation block of the try statement.
555       RegionCounter Cnt(PGO, S);
556       Cnt.beginRegion();
557       RecordNextStmtCount = true;
558     }
559
560     void VisitCXXCatchStmt(const CXXCatchStmt *S) {
561       RecordNextStmtCount = false;
562       // Counter tracks the catch statement's handler block.
563       RegionCounter Cnt(PGO, S);
564       Cnt.beginRegion();
565       CountMap[S] = PGO.getCurrentRegionCount();
566       Visit(S->getHandlerBlock());
567     }
568
569     void VisitAbstractConditionalOperator(
570         const AbstractConditionalOperator *E) {
571       RecordStmtCount(E);
572       // Counter tracks the "true" part of a conditional operator. The
573       // count in the "false" part will be calculated from this counter.
574       RegionCounter Cnt(PGO, E);
575       Visit(E->getCond());
576
577       Cnt.beginRegion();
578       CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount();
579       Visit(E->getTrueExpr());
580       Cnt.adjustForControlFlow();
581
582       Cnt.beginElseRegion();
583       CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount();
584       Visit(E->getFalseExpr());
585       Cnt.adjustForControlFlow();
586
587       Cnt.applyAdjustmentsToRegion(0);
588       RecordNextStmtCount = true;
589     }
590
591     void VisitBinLAnd(const BinaryOperator *E) {
592       RecordStmtCount(E);
593       // Counter tracks the right hand side of a logical and operator.
594       RegionCounter Cnt(PGO, E);
595       Visit(E->getLHS());
596       Cnt.beginRegion();
597       CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
598       Visit(E->getRHS());
599       Cnt.adjustForControlFlow();
600       Cnt.applyAdjustmentsToRegion(0);
601       RecordNextStmtCount = true;
602     }
603
604     void VisitBinLOr(const BinaryOperator *E) {
605       RecordStmtCount(E);
606       // Counter tracks the right hand side of a logical or operator.
607       RegionCounter Cnt(PGO, E);
608       Visit(E->getLHS());
609       Cnt.beginRegion();
610       CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
611       Visit(E->getRHS());
612       Cnt.adjustForControlFlow();
613       Cnt.applyAdjustmentsToRegion(0);
614       RecordNextStmtCount = true;
615     }
616   };
617 }
618
619 void PGOHash::combine(HashType Type) {
620   // Check that we never combine 0 and only have six bits.
621   assert(Type && "Hash is invalid: unexpected type 0");
622   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
623
624   // Pass through MD5 if enough work has built up.
625   if (Count && Count % NumTypesPerWord == 0) {
626     using namespace llvm::support;
627     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
628     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
629     Working = 0;
630   }
631
632   // Accumulate the current type.
633   ++Count;
634   Working = Working << NumBitsPerType | Type;
635 }
636
637 uint64_t PGOHash::finalize() {
638   // Use Working as the hash directly if we never used MD5.
639   if (Count <= NumTypesPerWord)
640     // No need to byte swap here, since none of the math was endian-dependent.
641     // This number will be byte-swapped as required on endianness transitions,
642     // so we will see the same value on the other side.
643     return Working;
644
645   // Check for remaining work in Working.
646   if (Working)
647     MD5.update(Working);
648
649   // Finalize the MD5 and return the hash.
650   llvm::MD5::MD5Result Result;
651   MD5.final(Result);
652   using namespace llvm::support;
653   return endian::read<uint64_t, little, unaligned>(Result);
654 }
655
656 void CodeGenPGO::checkGlobalDecl(GlobalDecl GD) {
657   // Make sure we only emit coverage mapping for one constructor/destructor.
658   // Clang emits several functions for the constructor and the destructor of
659   // a class. Every function is instrumented, but we only want to provide
660   // coverage for one of them. Because of that we only emit the coverage mapping
661   // for the base constructor/destructor.
662   if ((isa<CXXConstructorDecl>(GD.getDecl()) &&
663        GD.getCtorType() != Ctor_Base) ||
664       (isa<CXXDestructorDecl>(GD.getDecl()) &&
665        GD.getDtorType() != Dtor_Base)) {
666     SkipCoverageMapping = true;
667   }
668 }
669
670 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
671   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
672   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
673   if (!InstrumentRegions && !PGOReader)
674     return;
675   if (D->isImplicit())
676     return;
677   CGM.ClearUnusedCoverageMapping(D);
678   setFuncName(Fn);
679
680   mapRegionCounters(D);
681   if (CGM.getCodeGenOpts().CoverageMapping)
682     emitCounterRegionMapping(D);
683   if (PGOReader) {
684     SourceManager &SM = CGM.getContext().getSourceManager();
685     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
686     computeRegionCounts(D);
687     applyFunctionAttributes(PGOReader, Fn);
688   }
689 }
690
691 void CodeGenPGO::mapRegionCounters(const Decl *D) {
692   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
693   MapRegionCounters Walker(*RegionCounterMap);
694   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
695     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
696   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
697     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
698   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
699     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
700   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
701     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
702   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
703   NumRegionCounters = Walker.NextCounter;
704   FunctionHash = Walker.Hash.finalize();
705 }
706
707 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
708   if (SkipCoverageMapping)
709     return;
710   // Don't map the functions inside the system headers
711   auto Loc = D->getBody()->getLocStart();
712   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
713     return;
714
715   std::string CoverageMapping;
716   llvm::raw_string_ostream OS(CoverageMapping);
717   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
718                                 CGM.getContext().getSourceManager(),
719                                 CGM.getLangOpts(), RegionCounterMap.get());
720   MappingGen.emitCounterMapping(D, OS);
721   OS.flush();
722
723   if (CoverageMapping.empty())
724     return;
725
726   CGM.getCoverageMapping()->addFunctionMappingRecord(
727       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
728 }
729
730 void
731 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef FuncName,
732                                     llvm::GlobalValue::LinkageTypes Linkage) {
733   if (SkipCoverageMapping)
734     return;
735   setFuncName(FuncName, Linkage);
736
737   // Don't map the functions inside the system headers
738   auto Loc = D->getBody()->getLocStart();
739   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
740     return;
741
742   std::string CoverageMapping;
743   llvm::raw_string_ostream OS(CoverageMapping);
744   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
745                                 CGM.getContext().getSourceManager(),
746                                 CGM.getLangOpts());
747   MappingGen.emitEmptyMapping(D, OS);
748   OS.flush();
749
750   if (CoverageMapping.empty())
751     return;
752
753   CGM.getCoverageMapping()->addFunctionMappingRecord(
754       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
755 }
756
757 void CodeGenPGO::computeRegionCounts(const Decl *D) {
758   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
759   ComputeRegionCounts Walker(*StmtCountMap, *this);
760   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
761     Walker.VisitFunctionDecl(FD);
762   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
763     Walker.VisitObjCMethodDecl(MD);
764   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
765     Walker.VisitBlockDecl(BD);
766   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
767     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
768 }
769
770 void
771 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
772                                     llvm::Function *Fn) {
773   if (!haveRegionCounts())
774     return;
775
776   uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
777   uint64_t FunctionCount = getRegionCount(0);
778   if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
779     // Turn on InlineHint attribute for hot functions.
780     // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
781     Fn->addFnAttr(llvm::Attribute::InlineHint);
782   else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
783     // Turn on Cold attribute for cold functions.
784     // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
785     Fn->addFnAttr(llvm::Attribute::Cold);
786 }
787
788 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
789   if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap)
790     return;
791   if (!Builder.GetInsertPoint())
792     return;
793   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
794   Builder.CreateCall4(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
795                       llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
796                       Builder.getInt64(FunctionHash),
797                       Builder.getInt32(NumRegionCounters),
798                       Builder.getInt32(Counter));
799 }
800
801 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
802                                   bool IsInMainFile) {
803   CGM.getPGOStats().addVisited(IsInMainFile);
804   RegionCounts.clear();
805   if (std::error_code EC =
806           PGOReader->getFunctionCounts(FuncName, FunctionHash, RegionCounts)) {
807     if (EC == llvm::instrprof_error::unknown_function)
808       CGM.getPGOStats().addMissing(IsInMainFile);
809     else if (EC == llvm::instrprof_error::hash_mismatch)
810       CGM.getPGOStats().addMismatched(IsInMainFile);
811     else if (EC == llvm::instrprof_error::malformed)
812       // TODO: Consider a more specific warning for this case.
813       CGM.getPGOStats().addMismatched(IsInMainFile);
814     RegionCounts.clear();
815   }
816 }
817
818 /// \brief Calculate what to divide by to scale weights.
819 ///
820 /// Given the maximum weight, calculate a divisor that will scale all the
821 /// weights to strictly less than UINT32_MAX.
822 static uint64_t calculateWeightScale(uint64_t MaxWeight) {
823   return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
824 }
825
826 /// \brief Scale an individual branch weight (and add 1).
827 ///
828 /// Scale a 64-bit weight down to 32-bits using \c Scale.
829 ///
830 /// According to Laplace's Rule of Succession, it is better to compute the
831 /// weight based on the count plus 1, so universally add 1 to the value.
832 ///
833 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
834 /// greater than \c Weight.
835 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
836   assert(Scale && "scale by 0?");
837   uint64_t Scaled = Weight / Scale + 1;
838   assert(Scaled <= UINT32_MAX && "overflow 32-bits");
839   return Scaled;
840 }
841
842 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
843                                               uint64_t FalseCount) {
844   // Check for empty weights.
845   if (!TrueCount && !FalseCount)
846     return nullptr;
847
848   // Calculate how to scale down to 32-bits.
849   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
850
851   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
852   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
853                                       scaleBranchWeight(FalseCount, Scale));
854 }
855
856 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
857   // We need at least two elements to create meaningful weights.
858   if (Weights.size() < 2)
859     return nullptr;
860
861   // Check for empty weights.
862   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
863   if (MaxWeight == 0)
864     return nullptr;
865
866   // Calculate how to scale down to 32-bits.
867   uint64_t Scale = calculateWeightScale(MaxWeight);
868
869   SmallVector<uint32_t, 16> ScaledWeights;
870   ScaledWeights.reserve(Weights.size());
871   for (uint64_t W : Weights)
872     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
873
874   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
875   return MDHelper.createBranchWeights(ScaledWeights);
876 }
877
878 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
879                                             RegionCounter &Cnt) {
880   if (!haveRegionCounts())
881     return nullptr;
882   uint64_t LoopCount = Cnt.getCount();
883   uint64_t CondCount = 0;
884   bool Found = getStmtCount(Cond, CondCount);
885   assert(Found && "missing expected loop condition count");
886   (void)Found;
887   if (CondCount == 0)
888     return nullptr;
889   return createBranchWeights(LoopCount,
890                              std::max(CondCount, LoopCount) - LoopCount);
891 }