]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - contrib/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp
Merge clang 7.0.1 and several follow-up changes
[FreeBSD/FreeBSD.git] / contrib / llvm / tools / clang / lib / CodeGen / CodeGenPGO.cpp
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24
25 static llvm::cl::opt<bool>
26     EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
27                          llvm::cl::desc("Enable value profiling"),
28                          llvm::cl::Hidden, llvm::cl::init(false));
29
30 using namespace clang;
31 using namespace CodeGen;
32
33 void CodeGenPGO::setFuncName(StringRef Name,
34                              llvm::GlobalValue::LinkageTypes Linkage) {
35   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36   FuncName = llvm::getPGOFuncName(
37       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
39
40   // If we're generating a profile, create a variable for the name.
41   if (CGM.getCodeGenOpts().hasProfileClangInstr())
42     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
43 }
44
45 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46   setFuncName(Fn->getName(), Fn->getLinkage());
47   // Create PGOFuncName meta data.
48   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
49 }
50
51 /// The version of the PGO hash algorithm.
52 enum PGOHashVersion : unsigned {
53   PGO_HASH_V1,
54   PGO_HASH_V2,
55
56   // Keep this set to the latest hash version.
57   PGO_HASH_LATEST = PGO_HASH_V2
58 };
59
60 namespace {
61 /// Stable hasher for PGO region counters.
62 ///
63 /// PGOHash produces a stable hash of a given function's control flow.
64 ///
65 /// Changing the output of this hash will invalidate all previously generated
66 /// profiles -- i.e., don't do it.
67 ///
68 /// \note  When this hash does eventually change (years?), we still need to
69 /// support old hashes.  We'll need to pull in the version number from the
70 /// profile data format and use the matching hash function.
71 class PGOHash {
72   uint64_t Working;
73   unsigned Count;
74   PGOHashVersion HashVersion;
75   llvm::MD5 MD5;
76
77   static const int NumBitsPerType = 6;
78   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
79   static const unsigned TooBig = 1u << NumBitsPerType;
80
81 public:
82   /// Hash values for AST nodes.
83   ///
84   /// Distinct values for AST nodes that have region counters attached.
85   ///
86   /// These values must be stable.  All new members must be added at the end,
87   /// and no members should be removed.  Changing the enumeration value for an
88   /// AST node will affect the hash of every function that contains that node.
89   enum HashType : unsigned char {
90     None = 0,
91     LabelStmt = 1,
92     WhileStmt,
93     DoStmt,
94     ForStmt,
95     CXXForRangeStmt,
96     ObjCForCollectionStmt,
97     SwitchStmt,
98     CaseStmt,
99     DefaultStmt,
100     IfStmt,
101     CXXTryStmt,
102     CXXCatchStmt,
103     ConditionalOperator,
104     BinaryOperatorLAnd,
105     BinaryOperatorLOr,
106     BinaryConditionalOperator,
107     // The preceding values are available with PGO_HASH_V1.
108
109     EndOfScope,
110     IfThenBranch,
111     IfElseBranch,
112     GotoStmt,
113     IndirectGotoStmt,
114     BreakStmt,
115     ContinueStmt,
116     ReturnStmt,
117     ThrowExpr,
118     UnaryOperatorLNot,
119     BinaryOperatorLT,
120     BinaryOperatorGT,
121     BinaryOperatorLE,
122     BinaryOperatorGE,
123     BinaryOperatorEQ,
124     BinaryOperatorNE,
125     // The preceding values are available with PGO_HASH_V2.
126
127     // Keep this last.  It's for the static assert that follows.
128     LastHashType
129   };
130   static_assert(LastHashType <= TooBig, "Too many types in HashType");
131
132   PGOHash(PGOHashVersion HashVersion)
133       : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
134   void combine(HashType Type);
135   uint64_t finalize();
136   PGOHashVersion getHashVersion() const { return HashVersion; }
137 };
138 const int PGOHash::NumBitsPerType;
139 const unsigned PGOHash::NumTypesPerWord;
140 const unsigned PGOHash::TooBig;
141
142 /// Get the PGO hash version used in the given indexed profile.
143 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
144                                         CodeGenModule &CGM) {
145   if (PGOReader->getVersion() <= 4)
146     return PGO_HASH_V1;
147   return PGO_HASH_V2;
148 }
149
150 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
151 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
152   using Base = RecursiveASTVisitor<MapRegionCounters>;
153
154   /// The next counter value to assign.
155   unsigned NextCounter;
156   /// The function hash.
157   PGOHash Hash;
158   /// The map of statements to counters.
159   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
160
161   MapRegionCounters(PGOHashVersion HashVersion,
162                     llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
163       : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}
164
165   // Blocks and lambdas are handled as separate functions, so we need not
166   // traverse them in the parent context.
167   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
168   bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
169   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
170
171   bool VisitDecl(const Decl *D) {
172     switch (D->getKind()) {
173     default:
174       break;
175     case Decl::Function:
176     case Decl::CXXMethod:
177     case Decl::CXXConstructor:
178     case Decl::CXXDestructor:
179     case Decl::CXXConversion:
180     case Decl::ObjCMethod:
181     case Decl::Block:
182     case Decl::Captured:
183       CounterMap[D->getBody()] = NextCounter++;
184       break;
185     }
186     return true;
187   }
188
189   /// If \p S gets a fresh counter, update the counter mappings. Return the
190   /// V1 hash of \p S.
191   PGOHash::HashType updateCounterMappings(Stmt *S) {
192     auto Type = getHashType(PGO_HASH_V1, S);
193     if (Type != PGOHash::None)
194       CounterMap[S] = NextCounter++;
195     return Type;
196   }
197
198   /// Include \p S in the function hash.
199   bool VisitStmt(Stmt *S) {
200     auto Type = updateCounterMappings(S);
201     if (Hash.getHashVersion() != PGO_HASH_V1)
202       Type = getHashType(Hash.getHashVersion(), S);
203     if (Type != PGOHash::None)
204       Hash.combine(Type);
205     return true;
206   }
207
208   bool TraverseIfStmt(IfStmt *If) {
209     // If we used the V1 hash, use the default traversal.
210     if (Hash.getHashVersion() == PGO_HASH_V1)
211       return Base::TraverseIfStmt(If);
212
213     // Otherwise, keep track of which branch we're in while traversing.
214     VisitStmt(If);
215     for (Stmt *CS : If->children()) {
216       if (!CS)
217         continue;
218       if (CS == If->getThen())
219         Hash.combine(PGOHash::IfThenBranch);
220       else if (CS == If->getElse())
221         Hash.combine(PGOHash::IfElseBranch);
222       TraverseStmt(CS);
223     }
224     Hash.combine(PGOHash::EndOfScope);
225     return true;
226   }
227
228 // If the statement type \p N is nestable, and its nesting impacts profile
229 // stability, define a custom traversal which tracks the end of the statement
230 // in the hash (provided we're not using the V1 hash).
231 #define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
232   bool Traverse##N(N *S) {                                                     \
233     Base::Traverse##N(S);                                                      \
234     if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
235       Hash.combine(PGOHash::EndOfScope);                                       \
236     return true;                                                               \
237   }
238
239   DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
240   DEFINE_NESTABLE_TRAVERSAL(DoStmt)
241   DEFINE_NESTABLE_TRAVERSAL(ForStmt)
242   DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
243   DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
244   DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
245   DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
246
247   /// Get version \p HashVersion of the PGO hash for \p S.
248   PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
249     switch (S->getStmtClass()) {
250     default:
251       break;
252     case Stmt::LabelStmtClass:
253       return PGOHash::LabelStmt;
254     case Stmt::WhileStmtClass:
255       return PGOHash::WhileStmt;
256     case Stmt::DoStmtClass:
257       return PGOHash::DoStmt;
258     case Stmt::ForStmtClass:
259       return PGOHash::ForStmt;
260     case Stmt::CXXForRangeStmtClass:
261       return PGOHash::CXXForRangeStmt;
262     case Stmt::ObjCForCollectionStmtClass:
263       return PGOHash::ObjCForCollectionStmt;
264     case Stmt::SwitchStmtClass:
265       return PGOHash::SwitchStmt;
266     case Stmt::CaseStmtClass:
267       return PGOHash::CaseStmt;
268     case Stmt::DefaultStmtClass:
269       return PGOHash::DefaultStmt;
270     case Stmt::IfStmtClass:
271       return PGOHash::IfStmt;
272     case Stmt::CXXTryStmtClass:
273       return PGOHash::CXXTryStmt;
274     case Stmt::CXXCatchStmtClass:
275       return PGOHash::CXXCatchStmt;
276     case Stmt::ConditionalOperatorClass:
277       return PGOHash::ConditionalOperator;
278     case Stmt::BinaryConditionalOperatorClass:
279       return PGOHash::BinaryConditionalOperator;
280     case Stmt::BinaryOperatorClass: {
281       const BinaryOperator *BO = cast<BinaryOperator>(S);
282       if (BO->getOpcode() == BO_LAnd)
283         return PGOHash::BinaryOperatorLAnd;
284       if (BO->getOpcode() == BO_LOr)
285         return PGOHash::BinaryOperatorLOr;
286       if (HashVersion == PGO_HASH_V2) {
287         switch (BO->getOpcode()) {
288         default:
289           break;
290         case BO_LT:
291           return PGOHash::BinaryOperatorLT;
292         case BO_GT:
293           return PGOHash::BinaryOperatorGT;
294         case BO_LE:
295           return PGOHash::BinaryOperatorLE;
296         case BO_GE:
297           return PGOHash::BinaryOperatorGE;
298         case BO_EQ:
299           return PGOHash::BinaryOperatorEQ;
300         case BO_NE:
301           return PGOHash::BinaryOperatorNE;
302         }
303       }
304       break;
305     }
306     }
307
308     if (HashVersion == PGO_HASH_V2) {
309       switch (S->getStmtClass()) {
310       default:
311         break;
312       case Stmt::GotoStmtClass:
313         return PGOHash::GotoStmt;
314       case Stmt::IndirectGotoStmtClass:
315         return PGOHash::IndirectGotoStmt;
316       case Stmt::BreakStmtClass:
317         return PGOHash::BreakStmt;
318       case Stmt::ContinueStmtClass:
319         return PGOHash::ContinueStmt;
320       case Stmt::ReturnStmtClass:
321         return PGOHash::ReturnStmt;
322       case Stmt::CXXThrowExprClass:
323         return PGOHash::ThrowExpr;
324       case Stmt::UnaryOperatorClass: {
325         const UnaryOperator *UO = cast<UnaryOperator>(S);
326         if (UO->getOpcode() == UO_LNot)
327           return PGOHash::UnaryOperatorLNot;
328         break;
329       }
330       }
331     }
332
333     return PGOHash::None;
334   }
335 };
336
337 /// A StmtVisitor that propagates the raw counts through the AST and
338 /// records the count at statements where the value may change.
339 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
340   /// PGO state.
341   CodeGenPGO &PGO;
342
343   /// A flag that is set when the current count should be recorded on the
344   /// next statement, such as at the exit of a loop.
345   bool RecordNextStmtCount;
346
347   /// The count at the current location in the traversal.
348   uint64_t CurrentCount;
349
350   /// The map of statements to count values.
351   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
352
353   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
354   struct BreakContinue {
355     uint64_t BreakCount;
356     uint64_t ContinueCount;
357     BreakContinue() : BreakCount(0), ContinueCount(0) {}
358   };
359   SmallVector<BreakContinue, 8> BreakContinueStack;
360
361   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
362                       CodeGenPGO &PGO)
363       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
364
365   void RecordStmtCount(const Stmt *S) {
366     if (RecordNextStmtCount) {
367       CountMap[S] = CurrentCount;
368       RecordNextStmtCount = false;
369     }
370   }
371
372   /// Set and return the current count.
373   uint64_t setCount(uint64_t Count) {
374     CurrentCount = Count;
375     return Count;
376   }
377
378   void VisitStmt(const Stmt *S) {
379     RecordStmtCount(S);
380     for (const Stmt *Child : S->children())
381       if (Child)
382         this->Visit(Child);
383   }
384
385   void VisitFunctionDecl(const FunctionDecl *D) {
386     // Counter tracks entry to the function body.
387     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
388     CountMap[D->getBody()] = BodyCount;
389     Visit(D->getBody());
390   }
391
392   // Skip lambda expressions. We visit these as FunctionDecls when we're
393   // generating them and aren't interested in the body when generating a
394   // parent context.
395   void VisitLambdaExpr(const LambdaExpr *LE) {}
396
397   void VisitCapturedDecl(const CapturedDecl *D) {
398     // Counter tracks entry to the capture body.
399     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
400     CountMap[D->getBody()] = BodyCount;
401     Visit(D->getBody());
402   }
403
404   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
405     // Counter tracks entry to the method body.
406     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
407     CountMap[D->getBody()] = BodyCount;
408     Visit(D->getBody());
409   }
410
411   void VisitBlockDecl(const BlockDecl *D) {
412     // Counter tracks entry to the block body.
413     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
414     CountMap[D->getBody()] = BodyCount;
415     Visit(D->getBody());
416   }
417
418   void VisitReturnStmt(const ReturnStmt *S) {
419     RecordStmtCount(S);
420     if (S->getRetValue())
421       Visit(S->getRetValue());
422     CurrentCount = 0;
423     RecordNextStmtCount = true;
424   }
425
426   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
427     RecordStmtCount(E);
428     if (E->getSubExpr())
429       Visit(E->getSubExpr());
430     CurrentCount = 0;
431     RecordNextStmtCount = true;
432   }
433
434   void VisitGotoStmt(const GotoStmt *S) {
435     RecordStmtCount(S);
436     CurrentCount = 0;
437     RecordNextStmtCount = true;
438   }
439
440   void VisitLabelStmt(const LabelStmt *S) {
441     RecordNextStmtCount = false;
442     // Counter tracks the block following the label.
443     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
444     CountMap[S] = BlockCount;
445     Visit(S->getSubStmt());
446   }
447
448   void VisitBreakStmt(const BreakStmt *S) {
449     RecordStmtCount(S);
450     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
451     BreakContinueStack.back().BreakCount += CurrentCount;
452     CurrentCount = 0;
453     RecordNextStmtCount = true;
454   }
455
456   void VisitContinueStmt(const ContinueStmt *S) {
457     RecordStmtCount(S);
458     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
459     BreakContinueStack.back().ContinueCount += CurrentCount;
460     CurrentCount = 0;
461     RecordNextStmtCount = true;
462   }
463
464   void VisitWhileStmt(const WhileStmt *S) {
465     RecordStmtCount(S);
466     uint64_t ParentCount = CurrentCount;
467
468     BreakContinueStack.push_back(BreakContinue());
469     // Visit the body region first so the break/continue adjustments can be
470     // included when visiting the condition.
471     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
472     CountMap[S->getBody()] = CurrentCount;
473     Visit(S->getBody());
474     uint64_t BackedgeCount = CurrentCount;
475
476     // ...then go back and propagate counts through the condition. The count
477     // at the start of the condition is the sum of the incoming edges,
478     // the backedge from the end of the loop body, and the edges from
479     // continue statements.
480     BreakContinue BC = BreakContinueStack.pop_back_val();
481     uint64_t CondCount =
482         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
483     CountMap[S->getCond()] = CondCount;
484     Visit(S->getCond());
485     setCount(BC.BreakCount + CondCount - BodyCount);
486     RecordNextStmtCount = true;
487   }
488
489   void VisitDoStmt(const DoStmt *S) {
490     RecordStmtCount(S);
491     uint64_t LoopCount = PGO.getRegionCount(S);
492
493     BreakContinueStack.push_back(BreakContinue());
494     // The count doesn't include the fallthrough from the parent scope. Add it.
495     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
496     CountMap[S->getBody()] = BodyCount;
497     Visit(S->getBody());
498     uint64_t BackedgeCount = CurrentCount;
499
500     BreakContinue BC = BreakContinueStack.pop_back_val();
501     // The count at the start of the condition is equal to the count at the
502     // end of the body, plus any continues.
503     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
504     CountMap[S->getCond()] = CondCount;
505     Visit(S->getCond());
506     setCount(BC.BreakCount + CondCount - LoopCount);
507     RecordNextStmtCount = true;
508   }
509
510   void VisitForStmt(const ForStmt *S) {
511     RecordStmtCount(S);
512     if (S->getInit())
513       Visit(S->getInit());
514
515     uint64_t ParentCount = CurrentCount;
516
517     BreakContinueStack.push_back(BreakContinue());
518     // Visit the body region first. (This is basically the same as a while
519     // loop; see further comments in VisitWhileStmt.)
520     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
521     CountMap[S->getBody()] = BodyCount;
522     Visit(S->getBody());
523     uint64_t BackedgeCount = CurrentCount;
524     BreakContinue BC = BreakContinueStack.pop_back_val();
525
526     // The increment is essentially part of the body but it needs to include
527     // the count for all the continue statements.
528     if (S->getInc()) {
529       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
530       CountMap[S->getInc()] = IncCount;
531       Visit(S->getInc());
532     }
533
534     // ...then go back and propagate counts through the condition.
535     uint64_t CondCount =
536         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
537     if (S->getCond()) {
538       CountMap[S->getCond()] = CondCount;
539       Visit(S->getCond());
540     }
541     setCount(BC.BreakCount + CondCount - BodyCount);
542     RecordNextStmtCount = true;
543   }
544
545   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
546     RecordStmtCount(S);
547     Visit(S->getLoopVarStmt());
548     Visit(S->getRangeStmt());
549     Visit(S->getBeginStmt());
550     Visit(S->getEndStmt());
551
552     uint64_t ParentCount = CurrentCount;
553     BreakContinueStack.push_back(BreakContinue());
554     // Visit the body region first. (This is basically the same as a while
555     // loop; see further comments in VisitWhileStmt.)
556     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
557     CountMap[S->getBody()] = BodyCount;
558     Visit(S->getBody());
559     uint64_t BackedgeCount = CurrentCount;
560     BreakContinue BC = BreakContinueStack.pop_back_val();
561
562     // The increment is essentially part of the body but it needs to include
563     // the count for all the continue statements.
564     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
565     CountMap[S->getInc()] = IncCount;
566     Visit(S->getInc());
567
568     // ...then go back and propagate counts through the condition.
569     uint64_t CondCount =
570         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
571     CountMap[S->getCond()] = CondCount;
572     Visit(S->getCond());
573     setCount(BC.BreakCount + CondCount - BodyCount);
574     RecordNextStmtCount = true;
575   }
576
577   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
578     RecordStmtCount(S);
579     Visit(S->getElement());
580     uint64_t ParentCount = CurrentCount;
581     BreakContinueStack.push_back(BreakContinue());
582     // Counter tracks the body of the loop.
583     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
584     CountMap[S->getBody()] = BodyCount;
585     Visit(S->getBody());
586     uint64_t BackedgeCount = CurrentCount;
587     BreakContinue BC = BreakContinueStack.pop_back_val();
588
589     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
590              BodyCount);
591     RecordNextStmtCount = true;
592   }
593
594   void VisitSwitchStmt(const SwitchStmt *S) {
595     RecordStmtCount(S);
596     if (S->getInit())
597       Visit(S->getInit());
598     Visit(S->getCond());
599     CurrentCount = 0;
600     BreakContinueStack.push_back(BreakContinue());
601     Visit(S->getBody());
602     // If the switch is inside a loop, add the continue counts.
603     BreakContinue BC = BreakContinueStack.pop_back_val();
604     if (!BreakContinueStack.empty())
605       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
606     // Counter tracks the exit block of the switch.
607     setCount(PGO.getRegionCount(S));
608     RecordNextStmtCount = true;
609   }
610
611   void VisitSwitchCase(const SwitchCase *S) {
612     RecordNextStmtCount = false;
613     // Counter for this particular case. This counts only jumps from the
614     // switch header and does not include fallthrough from the case before
615     // this one.
616     uint64_t CaseCount = PGO.getRegionCount(S);
617     setCount(CurrentCount + CaseCount);
618     // We need the count without fallthrough in the mapping, so it's more useful
619     // for branch probabilities.
620     CountMap[S] = CaseCount;
621     RecordNextStmtCount = true;
622     Visit(S->getSubStmt());
623   }
624
625   void VisitIfStmt(const IfStmt *S) {
626     RecordStmtCount(S);
627     uint64_t ParentCount = CurrentCount;
628     if (S->getInit())
629       Visit(S->getInit());
630     Visit(S->getCond());
631
632     // Counter tracks the "then" part of an if statement. The count for
633     // the "else" part, if it exists, will be calculated from this counter.
634     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
635     CountMap[S->getThen()] = ThenCount;
636     Visit(S->getThen());
637     uint64_t OutCount = CurrentCount;
638
639     uint64_t ElseCount = ParentCount - ThenCount;
640     if (S->getElse()) {
641       setCount(ElseCount);
642       CountMap[S->getElse()] = ElseCount;
643       Visit(S->getElse());
644       OutCount += CurrentCount;
645     } else
646       OutCount += ElseCount;
647     setCount(OutCount);
648     RecordNextStmtCount = true;
649   }
650
651   void VisitCXXTryStmt(const CXXTryStmt *S) {
652     RecordStmtCount(S);
653     Visit(S->getTryBlock());
654     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
655       Visit(S->getHandler(I));
656     // Counter tracks the continuation block of the try statement.
657     setCount(PGO.getRegionCount(S));
658     RecordNextStmtCount = true;
659   }
660
661   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
662     RecordNextStmtCount = false;
663     // Counter tracks the catch statement's handler block.
664     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
665     CountMap[S] = CatchCount;
666     Visit(S->getHandlerBlock());
667   }
668
669   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
670     RecordStmtCount(E);
671     uint64_t ParentCount = CurrentCount;
672     Visit(E->getCond());
673
674     // Counter tracks the "true" part of a conditional operator. The
675     // count in the "false" part will be calculated from this counter.
676     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
677     CountMap[E->getTrueExpr()] = TrueCount;
678     Visit(E->getTrueExpr());
679     uint64_t OutCount = CurrentCount;
680
681     uint64_t FalseCount = setCount(ParentCount - TrueCount);
682     CountMap[E->getFalseExpr()] = FalseCount;
683     Visit(E->getFalseExpr());
684     OutCount += CurrentCount;
685
686     setCount(OutCount);
687     RecordNextStmtCount = true;
688   }
689
690   void VisitBinLAnd(const BinaryOperator *E) {
691     RecordStmtCount(E);
692     uint64_t ParentCount = CurrentCount;
693     Visit(E->getLHS());
694     // Counter tracks the right hand side of a logical and operator.
695     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
696     CountMap[E->getRHS()] = RHSCount;
697     Visit(E->getRHS());
698     setCount(ParentCount + RHSCount - CurrentCount);
699     RecordNextStmtCount = true;
700   }
701
702   void VisitBinLOr(const BinaryOperator *E) {
703     RecordStmtCount(E);
704     uint64_t ParentCount = CurrentCount;
705     Visit(E->getLHS());
706     // Counter tracks the right hand side of a logical or operator.
707     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
708     CountMap[E->getRHS()] = RHSCount;
709     Visit(E->getRHS());
710     setCount(ParentCount + RHSCount - CurrentCount);
711     RecordNextStmtCount = true;
712   }
713 };
714 } // end anonymous namespace
715
716 void PGOHash::combine(HashType Type) {
717   // Check that we never combine 0 and only have six bits.
718   assert(Type && "Hash is invalid: unexpected type 0");
719   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
720
721   // Pass through MD5 if enough work has built up.
722   if (Count && Count % NumTypesPerWord == 0) {
723     using namespace llvm::support;
724     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
725     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
726     Working = 0;
727   }
728
729   // Accumulate the current type.
730   ++Count;
731   Working = Working << NumBitsPerType | Type;
732 }
733
734 uint64_t PGOHash::finalize() {
735   // Use Working as the hash directly if we never used MD5.
736   if (Count <= NumTypesPerWord)
737     // No need to byte swap here, since none of the math was endian-dependent.
738     // This number will be byte-swapped as required on endianness transitions,
739     // so we will see the same value on the other side.
740     return Working;
741
742   // Check for remaining work in Working.
743   if (Working)
744     MD5.update(Working);
745
746   // Finalize the MD5 and return the hash.
747   llvm::MD5::MD5Result Result;
748   MD5.final(Result);
749   using namespace llvm::support;
750   return Result.low();
751 }
752
753 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
754   const Decl *D = GD.getDecl();
755   if (!D->hasBody())
756     return;
757
758   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
759   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
760   if (!InstrumentRegions && !PGOReader)
761     return;
762   if (D->isImplicit())
763     return;
764   // Constructors and destructors may be represented by several functions in IR.
765   // If so, instrument only base variant, others are implemented by delegation
766   // to the base one, it would be counted twice otherwise.
767   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
768     if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
769       return;
770
771     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
772       if (GD.getCtorType() != Ctor_Base &&
773           CodeGenFunction::IsConstructorDelegationValid(CCD))
774         return;
775   }
776   CGM.ClearUnusedCoverageMapping(D);
777   setFuncName(Fn);
778
779   mapRegionCounters(D);
780   if (CGM.getCodeGenOpts().CoverageMapping)
781     emitCounterRegionMapping(D);
782   if (PGOReader) {
783     SourceManager &SM = CGM.getContext().getSourceManager();
784     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
785     computeRegionCounts(D);
786     applyFunctionAttributes(PGOReader, Fn);
787   }
788 }
789
790 void CodeGenPGO::mapRegionCounters(const Decl *D) {
791   // Use the latest hash version when inserting instrumentation, but use the
792   // version in the indexed profile if we're reading PGO data.
793   PGOHashVersion HashVersion = PGO_HASH_LATEST;
794   if (auto *PGOReader = CGM.getPGOReader())
795     HashVersion = getPGOHashVersion(PGOReader, CGM);
796
797   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
798   MapRegionCounters Walker(HashVersion, *RegionCounterMap);
799   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
800     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
801   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
802     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
803   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
804     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
805   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
806     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
807   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
808   NumRegionCounters = Walker.NextCounter;
809   FunctionHash = Walker.Hash.finalize();
810 }
811
812 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
813   if (!D->getBody())
814     return true;
815
816   // Don't map the functions in system headers.
817   const auto &SM = CGM.getContext().getSourceManager();
818   auto Loc = D->getBody()->getLocStart();
819   return SM.isInSystemHeader(Loc);
820 }
821
822 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
823   if (skipRegionMappingForDecl(D))
824     return;
825
826   std::string CoverageMapping;
827   llvm::raw_string_ostream OS(CoverageMapping);
828   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
829                                 CGM.getContext().getSourceManager(),
830                                 CGM.getLangOpts(), RegionCounterMap.get());
831   MappingGen.emitCounterMapping(D, OS);
832   OS.flush();
833
834   if (CoverageMapping.empty())
835     return;
836
837   CGM.getCoverageMapping()->addFunctionMappingRecord(
838       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
839 }
840
841 void
842 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
843                                     llvm::GlobalValue::LinkageTypes Linkage) {
844   if (skipRegionMappingForDecl(D))
845     return;
846
847   std::string CoverageMapping;
848   llvm::raw_string_ostream OS(CoverageMapping);
849   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
850                                 CGM.getContext().getSourceManager(),
851                                 CGM.getLangOpts());
852   MappingGen.emitEmptyMapping(D, OS);
853   OS.flush();
854
855   if (CoverageMapping.empty())
856     return;
857
858   setFuncName(Name, Linkage);
859   CGM.getCoverageMapping()->addFunctionMappingRecord(
860       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
861 }
862
863 void CodeGenPGO::computeRegionCounts(const Decl *D) {
864   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
865   ComputeRegionCounts Walker(*StmtCountMap, *this);
866   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
867     Walker.VisitFunctionDecl(FD);
868   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
869     Walker.VisitObjCMethodDecl(MD);
870   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
871     Walker.VisitBlockDecl(BD);
872   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
873     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
874 }
875
876 void
877 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
878                                     llvm::Function *Fn) {
879   if (!haveRegionCounts())
880     return;
881
882   uint64_t FunctionCount = getRegionCount(nullptr);
883   Fn->setEntryCount(FunctionCount);
884 }
885
886 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
887                                       llvm::Value *StepV) {
888   if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
889     return;
890   if (!Builder.GetInsertBlock())
891     return;
892
893   unsigned Counter = (*RegionCounterMap)[S];
894   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
895
896   llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
897                          Builder.getInt64(FunctionHash),
898                          Builder.getInt32(NumRegionCounters),
899                          Builder.getInt32(Counter), StepV};
900   if (!StepV)
901     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
902                        makeArrayRef(Args, 4));
903   else
904     Builder.CreateCall(
905         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
906         makeArrayRef(Args));
907 }
908
909 // This method either inserts a call to the profile run-time during
910 // instrumentation or puts profile data into metadata for PGO use.
911 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
912     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
913
914   if (!EnableValueProfiling)
915     return;
916
917   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
918     return;
919
920   if (isa<llvm::Constant>(ValuePtr))
921     return;
922
923   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
924   if (InstrumentValueSites && RegionCounterMap) {
925     auto BuilderInsertPoint = Builder.saveIP();
926     Builder.SetInsertPoint(ValueSite);
927     llvm::Value *Args[5] = {
928         llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
929         Builder.getInt64(FunctionHash),
930         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
931         Builder.getInt32(ValueKind),
932         Builder.getInt32(NumValueSites[ValueKind]++)
933     };
934     Builder.CreateCall(
935         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
936     Builder.restoreIP(BuilderInsertPoint);
937     return;
938   }
939
940   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
941   if (PGOReader && haveRegionCounts()) {
942     // We record the top most called three functions at each call site.
943     // Profile metadata contains "VP" string identifying this metadata
944     // as value profiling data, then a uint32_t value for the value profiling
945     // kind, a uint64_t value for the total number of times the call is
946     // executed, followed by the function hash and execution count (uint64_t)
947     // pairs for each function.
948     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
949       return;
950
951     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
952                             (llvm::InstrProfValueKind)ValueKind,
953                             NumValueSites[ValueKind]);
954
955     NumValueSites[ValueKind]++;
956   }
957 }
958
959 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
960                                   bool IsInMainFile) {
961   CGM.getPGOStats().addVisited(IsInMainFile);
962   RegionCounts.clear();
963   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
964       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
965   if (auto E = RecordExpected.takeError()) {
966     auto IPE = llvm::InstrProfError::take(std::move(E));
967     if (IPE == llvm::instrprof_error::unknown_function)
968       CGM.getPGOStats().addMissing(IsInMainFile);
969     else if (IPE == llvm::instrprof_error::hash_mismatch)
970       CGM.getPGOStats().addMismatched(IsInMainFile);
971     else if (IPE == llvm::instrprof_error::malformed)
972       // TODO: Consider a more specific warning for this case.
973       CGM.getPGOStats().addMismatched(IsInMainFile);
974     return;
975   }
976   ProfRecord =
977       llvm::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
978   RegionCounts = ProfRecord->Counts;
979 }
980
981 /// Calculate what to divide by to scale weights.
982 ///
983 /// Given the maximum weight, calculate a divisor that will scale all the
984 /// weights to strictly less than UINT32_MAX.
985 static uint64_t calculateWeightScale(uint64_t MaxWeight) {
986   return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
987 }
988
989 /// Scale an individual branch weight (and add 1).
990 ///
991 /// Scale a 64-bit weight down to 32-bits using \c Scale.
992 ///
993 /// According to Laplace's Rule of Succession, it is better to compute the
994 /// weight based on the count plus 1, so universally add 1 to the value.
995 ///
996 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
997 /// greater than \c Weight.
998 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
999   assert(Scale && "scale by 0?");
1000   uint64_t Scaled = Weight / Scale + 1;
1001   assert(Scaled <= UINT32_MAX && "overflow 32-bits");
1002   return Scaled;
1003 }
1004
1005 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1006                                                     uint64_t FalseCount) {
1007   // Check for empty weights.
1008   if (!TrueCount && !FalseCount)
1009     return nullptr;
1010
1011   // Calculate how to scale down to 32-bits.
1012   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1013
1014   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1015   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1016                                       scaleBranchWeight(FalseCount, Scale));
1017 }
1018
1019 llvm::MDNode *
1020 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
1021   // We need at least two elements to create meaningful weights.
1022   if (Weights.size() < 2)
1023     return nullptr;
1024
1025   // Check for empty weights.
1026   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1027   if (MaxWeight == 0)
1028     return nullptr;
1029
1030   // Calculate how to scale down to 32-bits.
1031   uint64_t Scale = calculateWeightScale(MaxWeight);
1032
1033   SmallVector<uint32_t, 16> ScaledWeights;
1034   ScaledWeights.reserve(Weights.size());
1035   for (uint64_t W : Weights)
1036     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1037
1038   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1039   return MDHelper.createBranchWeights(ScaledWeights);
1040 }
1041
1042 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1043                                                            uint64_t LoopCount) {
1044   if (!PGO.haveRegionCounts())
1045     return nullptr;
1046   Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1047   assert(CondCount.hasValue() && "missing expected loop condition count");
1048   if (*CondCount == 0)
1049     return nullptr;
1050   return createProfileWeights(LoopCount,
1051                               std::max(*CondCount, LoopCount) - LoopCount);
1052 }