]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - lib/CodeGen/CodeGenPGO.cpp
Vendor import of clang trunk r238337:
[FreeBSD/FreeBSD.git] / lib / CodeGen / CodeGenPGO.cpp
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/ProfileData/InstrProfReader.h"
22 #include "llvm/Support/Endian.h"
23 #include "llvm/Support/FileSystem.h"
24 #include "llvm/Support/MD5.h"
25
26 using namespace clang;
27 using namespace CodeGen;
28
29 void CodeGenPGO::setFuncName(StringRef Name,
30                              llvm::GlobalValue::LinkageTypes Linkage) {
31   StringRef RawFuncName = Name;
32
33   // Function names may be prefixed with a binary '1' to indicate
34   // that the backend should not modify the symbols due to any platform
35   // naming convention. Do not include that '1' in the PGO profile name.
36   if (RawFuncName[0] == '\1')
37     RawFuncName = RawFuncName.substr(1);
38
39   FuncName = RawFuncName;
40   if (llvm::GlobalValue::isLocalLinkage(Linkage)) {
41     // For local symbols, prepend the main file name to distinguish them.
42     // Do not include the full path in the file name since there's no guarantee
43     // that it will stay the same, e.g., if the files are checked out from
44     // version control in different locations.
45     if (CGM.getCodeGenOpts().MainFileName.empty())
46       FuncName = FuncName.insert(0, "<unknown>:");
47     else
48       FuncName = FuncName.insert(0, CGM.getCodeGenOpts().MainFileName + ":");
49   }
50
51   // If we're generating a profile, create a variable for the name.
52   if (CGM.getCodeGenOpts().ProfileInstrGenerate)
53     createFuncNameVar(Linkage);
54 }
55
56 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
57   setFuncName(Fn->getName(), Fn->getLinkage());
58 }
59
60 void CodeGenPGO::createFuncNameVar(llvm::GlobalValue::LinkageTypes Linkage) {
61   // We generally want to match the function's linkage, but available_externally
62   // and extern_weak both have the wrong semantics, and anything that doesn't
63   // need to link across compilation units doesn't need to be visible at all.
64   if (Linkage == llvm::GlobalValue::ExternalWeakLinkage)
65     Linkage = llvm::GlobalValue::LinkOnceAnyLinkage;
66   else if (Linkage == llvm::GlobalValue::AvailableExternallyLinkage)
67     Linkage = llvm::GlobalValue::LinkOnceODRLinkage;
68   else if (Linkage == llvm::GlobalValue::InternalLinkage ||
69            Linkage == llvm::GlobalValue::ExternalLinkage)
70     Linkage = llvm::GlobalValue::PrivateLinkage;
71
72   auto *Value =
73       llvm::ConstantDataArray::getString(CGM.getLLVMContext(), FuncName, false);
74   FuncNameVar =
75       new llvm::GlobalVariable(CGM.getModule(), Value->getType(), true, Linkage,
76                                Value, "__llvm_profile_name_" + FuncName);
77
78   // Hide the symbol so that we correctly get a copy for each executable.
79   if (!llvm::GlobalValue::isLocalLinkage(FuncNameVar->getLinkage()))
80     FuncNameVar->setVisibility(llvm::GlobalValue::HiddenVisibility);
81 }
82
83 namespace {
84 /// \brief Stable hasher for PGO region counters.
85 ///
86 /// PGOHash produces a stable hash of a given function's control flow.
87 ///
88 /// Changing the output of this hash will invalidate all previously generated
89 /// profiles -- i.e., don't do it.
90 ///
91 /// \note  When this hash does eventually change (years?), we still need to
92 /// support old hashes.  We'll need to pull in the version number from the
93 /// profile data format and use the matching hash function.
94 class PGOHash {
95   uint64_t Working;
96   unsigned Count;
97   llvm::MD5 MD5;
98
99   static const int NumBitsPerType = 6;
100   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
101   static const unsigned TooBig = 1u << NumBitsPerType;
102
103 public:
104   /// \brief Hash values for AST nodes.
105   ///
106   /// Distinct values for AST nodes that have region counters attached.
107   ///
108   /// These values must be stable.  All new members must be added at the end,
109   /// and no members should be removed.  Changing the enumeration value for an
110   /// AST node will affect the hash of every function that contains that node.
111   enum HashType : unsigned char {
112     None = 0,
113     LabelStmt = 1,
114     WhileStmt,
115     DoStmt,
116     ForStmt,
117     CXXForRangeStmt,
118     ObjCForCollectionStmt,
119     SwitchStmt,
120     CaseStmt,
121     DefaultStmt,
122     IfStmt,
123     CXXTryStmt,
124     CXXCatchStmt,
125     ConditionalOperator,
126     BinaryOperatorLAnd,
127     BinaryOperatorLOr,
128     BinaryConditionalOperator,
129
130     // Keep this last.  It's for the static assert that follows.
131     LastHashType
132   };
133   static_assert(LastHashType <= TooBig, "Too many types in HashType");
134
135   // TODO: When this format changes, take in a version number here, and use the
136   // old hash calculation for file formats that used the old hash.
137   PGOHash() : Working(0), Count(0) {}
138   void combine(HashType Type);
139   uint64_t finalize();
140 };
141 const int PGOHash::NumBitsPerType;
142 const unsigned PGOHash::NumTypesPerWord;
143 const unsigned PGOHash::TooBig;
144
145 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
146 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
147   /// The next counter value to assign.
148   unsigned NextCounter;
149   /// The function hash.
150   PGOHash Hash;
151   /// The map of statements to counters.
152   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
153
154   MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
155       : NextCounter(0), CounterMap(CounterMap) {}
156
157   // Blocks and lambdas are handled as separate functions, so we need not
158   // traverse them in the parent context.
159   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
160   bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
161   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
162
163   bool VisitDecl(const Decl *D) {
164     switch (D->getKind()) {
165     default:
166       break;
167     case Decl::Function:
168     case Decl::CXXMethod:
169     case Decl::CXXConstructor:
170     case Decl::CXXDestructor:
171     case Decl::CXXConversion:
172     case Decl::ObjCMethod:
173     case Decl::Block:
174     case Decl::Captured:
175       CounterMap[D->getBody()] = NextCounter++;
176       break;
177     }
178     return true;
179   }
180
181   bool VisitStmt(const Stmt *S) {
182     auto Type = getHashType(S);
183     if (Type == PGOHash::None)
184       return true;
185
186     CounterMap[S] = NextCounter++;
187     Hash.combine(Type);
188     return true;
189   }
190   PGOHash::HashType getHashType(const Stmt *S) {
191     switch (S->getStmtClass()) {
192     default:
193       break;
194     case Stmt::LabelStmtClass:
195       return PGOHash::LabelStmt;
196     case Stmt::WhileStmtClass:
197       return PGOHash::WhileStmt;
198     case Stmt::DoStmtClass:
199       return PGOHash::DoStmt;
200     case Stmt::ForStmtClass:
201       return PGOHash::ForStmt;
202     case Stmt::CXXForRangeStmtClass:
203       return PGOHash::CXXForRangeStmt;
204     case Stmt::ObjCForCollectionStmtClass:
205       return PGOHash::ObjCForCollectionStmt;
206     case Stmt::SwitchStmtClass:
207       return PGOHash::SwitchStmt;
208     case Stmt::CaseStmtClass:
209       return PGOHash::CaseStmt;
210     case Stmt::DefaultStmtClass:
211       return PGOHash::DefaultStmt;
212     case Stmt::IfStmtClass:
213       return PGOHash::IfStmt;
214     case Stmt::CXXTryStmtClass:
215       return PGOHash::CXXTryStmt;
216     case Stmt::CXXCatchStmtClass:
217       return PGOHash::CXXCatchStmt;
218     case Stmt::ConditionalOperatorClass:
219       return PGOHash::ConditionalOperator;
220     case Stmt::BinaryConditionalOperatorClass:
221       return PGOHash::BinaryConditionalOperator;
222     case Stmt::BinaryOperatorClass: {
223       const BinaryOperator *BO = cast<BinaryOperator>(S);
224       if (BO->getOpcode() == BO_LAnd)
225         return PGOHash::BinaryOperatorLAnd;
226       if (BO->getOpcode() == BO_LOr)
227         return PGOHash::BinaryOperatorLOr;
228       break;
229     }
230     }
231     return PGOHash::None;
232   }
233 };
234
235 /// A StmtVisitor that propagates the raw counts through the AST and
236 /// records the count at statements where the value may change.
237 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
238   /// PGO state.
239   CodeGenPGO &PGO;
240
241   /// A flag that is set when the current count should be recorded on the
242   /// next statement, such as at the exit of a loop.
243   bool RecordNextStmtCount;
244
245   /// The count at the current location in the traversal.
246   uint64_t CurrentCount;
247
248   /// The map of statements to count values.
249   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
250
251   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
252   struct BreakContinue {
253     uint64_t BreakCount;
254     uint64_t ContinueCount;
255     BreakContinue() : BreakCount(0), ContinueCount(0) {}
256   };
257   SmallVector<BreakContinue, 8> BreakContinueStack;
258
259   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
260                       CodeGenPGO &PGO)
261       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
262
263   void RecordStmtCount(const Stmt *S) {
264     if (RecordNextStmtCount) {
265       CountMap[S] = CurrentCount;
266       RecordNextStmtCount = false;
267     }
268   }
269
270   /// Set and return the current count.
271   uint64_t setCount(uint64_t Count) {
272     CurrentCount = Count;
273     return Count;
274   }
275
276   void VisitStmt(const Stmt *S) {
277     RecordStmtCount(S);
278     for (Stmt::const_child_range I = S->children(); I; ++I) {
279       if (*I)
280         this->Visit(*I);
281     }
282   }
283
284   void VisitFunctionDecl(const FunctionDecl *D) {
285     // Counter tracks entry to the function body.
286     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
287     CountMap[D->getBody()] = BodyCount;
288     Visit(D->getBody());
289   }
290
291   // Skip lambda expressions. We visit these as FunctionDecls when we're
292   // generating them and aren't interested in the body when generating a
293   // parent context.
294   void VisitLambdaExpr(const LambdaExpr *LE) {}
295
296   void VisitCapturedDecl(const CapturedDecl *D) {
297     // Counter tracks entry to the capture body.
298     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
299     CountMap[D->getBody()] = BodyCount;
300     Visit(D->getBody());
301   }
302
303   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
304     // Counter tracks entry to the method body.
305     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
306     CountMap[D->getBody()] = BodyCount;
307     Visit(D->getBody());
308   }
309
310   void VisitBlockDecl(const BlockDecl *D) {
311     // Counter tracks entry to the block body.
312     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
313     CountMap[D->getBody()] = BodyCount;
314     Visit(D->getBody());
315   }
316
317   void VisitReturnStmt(const ReturnStmt *S) {
318     RecordStmtCount(S);
319     if (S->getRetValue())
320       Visit(S->getRetValue());
321     CurrentCount = 0;
322     RecordNextStmtCount = true;
323   }
324
325   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
326     RecordStmtCount(E);
327     if (E->getSubExpr())
328       Visit(E->getSubExpr());
329     CurrentCount = 0;
330     RecordNextStmtCount = true;
331   }
332
333   void VisitGotoStmt(const GotoStmt *S) {
334     RecordStmtCount(S);
335     CurrentCount = 0;
336     RecordNextStmtCount = true;
337   }
338
339   void VisitLabelStmt(const LabelStmt *S) {
340     RecordNextStmtCount = false;
341     // Counter tracks the block following the label.
342     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
343     CountMap[S] = BlockCount;
344     Visit(S->getSubStmt());
345   }
346
347   void VisitBreakStmt(const BreakStmt *S) {
348     RecordStmtCount(S);
349     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
350     BreakContinueStack.back().BreakCount += CurrentCount;
351     CurrentCount = 0;
352     RecordNextStmtCount = true;
353   }
354
355   void VisitContinueStmt(const ContinueStmt *S) {
356     RecordStmtCount(S);
357     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
358     BreakContinueStack.back().ContinueCount += CurrentCount;
359     CurrentCount = 0;
360     RecordNextStmtCount = true;
361   }
362
363   void VisitWhileStmt(const WhileStmt *S) {
364     RecordStmtCount(S);
365     uint64_t ParentCount = CurrentCount;
366
367     BreakContinueStack.push_back(BreakContinue());
368     // Visit the body region first so the break/continue adjustments can be
369     // included when visiting the condition.
370     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
371     CountMap[S->getBody()] = CurrentCount;
372     Visit(S->getBody());
373     uint64_t BackedgeCount = CurrentCount;
374
375     // ...then go back and propagate counts through the condition. The count
376     // at the start of the condition is the sum of the incoming edges,
377     // the backedge from the end of the loop body, and the edges from
378     // continue statements.
379     BreakContinue BC = BreakContinueStack.pop_back_val();
380     uint64_t CondCount =
381         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
382     CountMap[S->getCond()] = CondCount;
383     Visit(S->getCond());
384     setCount(BC.BreakCount + CondCount - BodyCount);
385     RecordNextStmtCount = true;
386   }
387
388   void VisitDoStmt(const DoStmt *S) {
389     RecordStmtCount(S);
390     uint64_t LoopCount = PGO.getRegionCount(S);
391
392     BreakContinueStack.push_back(BreakContinue());
393     // The count doesn't include the fallthrough from the parent scope. Add it.
394     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
395     CountMap[S->getBody()] = BodyCount;
396     Visit(S->getBody());
397     uint64_t BackedgeCount = CurrentCount;
398
399     BreakContinue BC = BreakContinueStack.pop_back_val();
400     // The count at the start of the condition is equal to the count at the
401     // end of the body, plus any continues.
402     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
403     CountMap[S->getCond()] = CondCount;
404     Visit(S->getCond());
405     setCount(BC.BreakCount + CondCount - LoopCount);
406     RecordNextStmtCount = true;
407   }
408
409   void VisitForStmt(const ForStmt *S) {
410     RecordStmtCount(S);
411     if (S->getInit())
412       Visit(S->getInit());
413
414     uint64_t ParentCount = CurrentCount;
415
416     BreakContinueStack.push_back(BreakContinue());
417     // Visit the body region first. (This is basically the same as a while
418     // loop; see further comments in VisitWhileStmt.)
419     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
420     CountMap[S->getBody()] = BodyCount;
421     Visit(S->getBody());
422     uint64_t BackedgeCount = CurrentCount;
423     BreakContinue BC = BreakContinueStack.pop_back_val();
424
425     // The increment is essentially part of the body but it needs to include
426     // the count for all the continue statements.
427     if (S->getInc()) {
428       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
429       CountMap[S->getInc()] = IncCount;
430       Visit(S->getInc());
431     }
432
433     // ...then go back and propagate counts through the condition.
434     uint64_t CondCount =
435         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
436     if (S->getCond()) {
437       CountMap[S->getCond()] = CondCount;
438       Visit(S->getCond());
439     }
440     setCount(BC.BreakCount + CondCount - BodyCount);
441     RecordNextStmtCount = true;
442   }
443
444   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
445     RecordStmtCount(S);
446     Visit(S->getLoopVarStmt());
447     Visit(S->getRangeStmt());
448     Visit(S->getBeginEndStmt());
449
450     uint64_t ParentCount = CurrentCount;
451     BreakContinueStack.push_back(BreakContinue());
452     // Visit the body region first. (This is basically the same as a while
453     // loop; see further comments in VisitWhileStmt.)
454     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
455     CountMap[S->getBody()] = BodyCount;
456     Visit(S->getBody());
457     uint64_t BackedgeCount = CurrentCount;
458     BreakContinue BC = BreakContinueStack.pop_back_val();
459
460     // The increment is essentially part of the body but it needs to include
461     // the count for all the continue statements.
462     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
463     CountMap[S->getInc()] = IncCount;
464     Visit(S->getInc());
465
466     // ...then go back and propagate counts through the condition.
467     uint64_t CondCount =
468         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
469     CountMap[S->getCond()] = CondCount;
470     Visit(S->getCond());
471     setCount(BC.BreakCount + CondCount - BodyCount);
472     RecordNextStmtCount = true;
473   }
474
475   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
476     RecordStmtCount(S);
477     Visit(S->getElement());
478     uint64_t ParentCount = CurrentCount;
479     BreakContinueStack.push_back(BreakContinue());
480     // Counter tracks the body of the loop.
481     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
482     CountMap[S->getBody()] = BodyCount;
483     Visit(S->getBody());
484     uint64_t BackedgeCount = CurrentCount;
485     BreakContinue BC = BreakContinueStack.pop_back_val();
486
487     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
488              BodyCount);
489     RecordNextStmtCount = true;
490   }
491
492   void VisitSwitchStmt(const SwitchStmt *S) {
493     RecordStmtCount(S);
494     Visit(S->getCond());
495     CurrentCount = 0;
496     BreakContinueStack.push_back(BreakContinue());
497     Visit(S->getBody());
498     // If the switch is inside a loop, add the continue counts.
499     BreakContinue BC = BreakContinueStack.pop_back_val();
500     if (!BreakContinueStack.empty())
501       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
502     // Counter tracks the exit block of the switch.
503     setCount(PGO.getRegionCount(S));
504     RecordNextStmtCount = true;
505   }
506
507   void VisitSwitchCase(const SwitchCase *S) {
508     RecordNextStmtCount = false;
509     // Counter for this particular case. This counts only jumps from the
510     // switch header and does not include fallthrough from the case before
511     // this one.
512     uint64_t CaseCount = PGO.getRegionCount(S);
513     setCount(CurrentCount + CaseCount);
514     // We need the count without fallthrough in the mapping, so it's more useful
515     // for branch probabilities.
516     CountMap[S] = CaseCount;
517     RecordNextStmtCount = true;
518     Visit(S->getSubStmt());
519   }
520
521   void VisitIfStmt(const IfStmt *S) {
522     RecordStmtCount(S);
523     uint64_t ParentCount = CurrentCount;
524     Visit(S->getCond());
525
526     // Counter tracks the "then" part of an if statement. The count for
527     // the "else" part, if it exists, will be calculated from this counter.
528     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
529     CountMap[S->getThen()] = ThenCount;
530     Visit(S->getThen());
531     uint64_t OutCount = CurrentCount;
532
533     uint64_t ElseCount = ParentCount - ThenCount;
534     if (S->getElse()) {
535       setCount(ElseCount);
536       CountMap[S->getElse()] = ElseCount;
537       Visit(S->getElse());
538       OutCount += CurrentCount;
539     } else
540       OutCount += ElseCount;
541     setCount(OutCount);
542     RecordNextStmtCount = true;
543   }
544
545   void VisitCXXTryStmt(const CXXTryStmt *S) {
546     RecordStmtCount(S);
547     Visit(S->getTryBlock());
548     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
549       Visit(S->getHandler(I));
550     // Counter tracks the continuation block of the try statement.
551     setCount(PGO.getRegionCount(S));
552     RecordNextStmtCount = true;
553   }
554
555   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
556     RecordNextStmtCount = false;
557     // Counter tracks the catch statement's handler block.
558     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
559     CountMap[S] = CatchCount;
560     Visit(S->getHandlerBlock());
561   }
562
563   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
564     RecordStmtCount(E);
565     uint64_t ParentCount = CurrentCount;
566     Visit(E->getCond());
567
568     // Counter tracks the "true" part of a conditional operator. The
569     // count in the "false" part will be calculated from this counter.
570     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
571     CountMap[E->getTrueExpr()] = TrueCount;
572     Visit(E->getTrueExpr());
573     uint64_t OutCount = CurrentCount;
574
575     uint64_t FalseCount = setCount(ParentCount - TrueCount);
576     CountMap[E->getFalseExpr()] = FalseCount;
577     Visit(E->getFalseExpr());
578     OutCount += CurrentCount;
579
580     setCount(OutCount);
581     RecordNextStmtCount = true;
582   }
583
584   void VisitBinLAnd(const BinaryOperator *E) {
585     RecordStmtCount(E);
586     uint64_t ParentCount = CurrentCount;
587     Visit(E->getLHS());
588     // Counter tracks the right hand side of a logical and operator.
589     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
590     CountMap[E->getRHS()] = RHSCount;
591     Visit(E->getRHS());
592     setCount(ParentCount + RHSCount - CurrentCount);
593     RecordNextStmtCount = true;
594   }
595
596   void VisitBinLOr(const BinaryOperator *E) {
597     RecordStmtCount(E);
598     uint64_t ParentCount = CurrentCount;
599     Visit(E->getLHS());
600     // Counter tracks the right hand side of a logical or operator.
601     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
602     CountMap[E->getRHS()] = RHSCount;
603     Visit(E->getRHS());
604     setCount(ParentCount + RHSCount - CurrentCount);
605     RecordNextStmtCount = true;
606   }
607 };
608 }
609
610 void PGOHash::combine(HashType Type) {
611   // Check that we never combine 0 and only have six bits.
612   assert(Type && "Hash is invalid: unexpected type 0");
613   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
614
615   // Pass through MD5 if enough work has built up.
616   if (Count && Count % NumTypesPerWord == 0) {
617     using namespace llvm::support;
618     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
619     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
620     Working = 0;
621   }
622
623   // Accumulate the current type.
624   ++Count;
625   Working = Working << NumBitsPerType | Type;
626 }
627
628 uint64_t PGOHash::finalize() {
629   // Use Working as the hash directly if we never used MD5.
630   if (Count <= NumTypesPerWord)
631     // No need to byte swap here, since none of the math was endian-dependent.
632     // This number will be byte-swapped as required on endianness transitions,
633     // so we will see the same value on the other side.
634     return Working;
635
636   // Check for remaining work in Working.
637   if (Working)
638     MD5.update(Working);
639
640   // Finalize the MD5 and return the hash.
641   llvm::MD5::MD5Result Result;
642   MD5.final(Result);
643   using namespace llvm::support;
644   return endian::read<uint64_t, little, unaligned>(Result);
645 }
646
647 void CodeGenPGO::checkGlobalDecl(GlobalDecl GD) {
648   // Make sure we only emit coverage mapping for one constructor/destructor.
649   // Clang emits several functions for the constructor and the destructor of
650   // a class. Every function is instrumented, but we only want to provide
651   // coverage for one of them. Because of that we only emit the coverage mapping
652   // for the base constructor/destructor.
653   if ((isa<CXXConstructorDecl>(GD.getDecl()) &&
654        GD.getCtorType() != Ctor_Base) ||
655       (isa<CXXDestructorDecl>(GD.getDecl()) &&
656        GD.getDtorType() != Dtor_Base)) {
657     SkipCoverageMapping = true;
658   }
659 }
660
661 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
662   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
663   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
664   if (!InstrumentRegions && !PGOReader)
665     return;
666   if (D->isImplicit())
667     return;
668   CGM.ClearUnusedCoverageMapping(D);
669   setFuncName(Fn);
670
671   mapRegionCounters(D);
672   if (CGM.getCodeGenOpts().CoverageMapping)
673     emitCounterRegionMapping(D);
674   if (PGOReader) {
675     SourceManager &SM = CGM.getContext().getSourceManager();
676     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
677     computeRegionCounts(D);
678     applyFunctionAttributes(PGOReader, Fn);
679   }
680 }
681
682 void CodeGenPGO::mapRegionCounters(const Decl *D) {
683   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
684   MapRegionCounters Walker(*RegionCounterMap);
685   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
686     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
687   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
688     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
689   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
690     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
691   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
692     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
693   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
694   NumRegionCounters = Walker.NextCounter;
695   FunctionHash = Walker.Hash.finalize();
696 }
697
698 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
699   if (SkipCoverageMapping)
700     return;
701   // Don't map the functions inside the system headers
702   auto Loc = D->getBody()->getLocStart();
703   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
704     return;
705
706   std::string CoverageMapping;
707   llvm::raw_string_ostream OS(CoverageMapping);
708   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
709                                 CGM.getContext().getSourceManager(),
710                                 CGM.getLangOpts(), RegionCounterMap.get());
711   MappingGen.emitCounterMapping(D, OS);
712   OS.flush();
713
714   if (CoverageMapping.empty())
715     return;
716
717   CGM.getCoverageMapping()->addFunctionMappingRecord(
718       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
719 }
720
721 void
722 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
723                                     llvm::GlobalValue::LinkageTypes Linkage) {
724   if (SkipCoverageMapping)
725     return;
726   // Don't map the functions inside the system headers
727   auto Loc = D->getBody()->getLocStart();
728   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
729     return;
730
731   std::string CoverageMapping;
732   llvm::raw_string_ostream OS(CoverageMapping);
733   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
734                                 CGM.getContext().getSourceManager(),
735                                 CGM.getLangOpts());
736   MappingGen.emitEmptyMapping(D, OS);
737   OS.flush();
738
739   if (CoverageMapping.empty())
740     return;
741
742   setFuncName(Name, Linkage);
743   CGM.getCoverageMapping()->addFunctionMappingRecord(
744       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
745 }
746
747 void CodeGenPGO::computeRegionCounts(const Decl *D) {
748   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
749   ComputeRegionCounts Walker(*StmtCountMap, *this);
750   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
751     Walker.VisitFunctionDecl(FD);
752   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
753     Walker.VisitObjCMethodDecl(MD);
754   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
755     Walker.VisitBlockDecl(BD);
756   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
757     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
758 }
759
760 void
761 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
762                                     llvm::Function *Fn) {
763   if (!haveRegionCounts())
764     return;
765
766   uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
767   uint64_t FunctionCount = getRegionCount(0);
768   if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
769     // Turn on InlineHint attribute for hot functions.
770     // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
771     Fn->addFnAttr(llvm::Attribute::InlineHint);
772   else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
773     // Turn on Cold attribute for cold functions.
774     // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
775     Fn->addFnAttr(llvm::Attribute::Cold);
776 }
777
778 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S) {
779   if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap)
780     return;
781   if (!Builder.GetInsertPoint())
782     return;
783
784   unsigned Counter = (*RegionCounterMap)[S];
785   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
786   Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
787                      {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
788                       Builder.getInt64(FunctionHash),
789                       Builder.getInt32(NumRegionCounters),
790                       Builder.getInt32(Counter)});
791 }
792
793 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
794                                   bool IsInMainFile) {
795   CGM.getPGOStats().addVisited(IsInMainFile);
796   RegionCounts.clear();
797   if (std::error_code EC =
798           PGOReader->getFunctionCounts(FuncName, FunctionHash, RegionCounts)) {
799     if (EC == llvm::instrprof_error::unknown_function)
800       CGM.getPGOStats().addMissing(IsInMainFile);
801     else if (EC == llvm::instrprof_error::hash_mismatch)
802       CGM.getPGOStats().addMismatched(IsInMainFile);
803     else if (EC == llvm::instrprof_error::malformed)
804       // TODO: Consider a more specific warning for this case.
805       CGM.getPGOStats().addMismatched(IsInMainFile);
806     RegionCounts.clear();
807   }
808 }
809
810 /// \brief Calculate what to divide by to scale weights.
811 ///
812 /// Given the maximum weight, calculate a divisor that will scale all the
813 /// weights to strictly less than UINT32_MAX.
814 static uint64_t calculateWeightScale(uint64_t MaxWeight) {
815   return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
816 }
817
818 /// \brief Scale an individual branch weight (and add 1).
819 ///
820 /// Scale a 64-bit weight down to 32-bits using \c Scale.
821 ///
822 /// According to Laplace's Rule of Succession, it is better to compute the
823 /// weight based on the count plus 1, so universally add 1 to the value.
824 ///
825 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
826 /// greater than \c Weight.
827 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
828   assert(Scale && "scale by 0?");
829   uint64_t Scaled = Weight / Scale + 1;
830   assert(Scaled <= UINT32_MAX && "overflow 32-bits");
831   return Scaled;
832 }
833
834 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
835                                                     uint64_t FalseCount) {
836   // Check for empty weights.
837   if (!TrueCount && !FalseCount)
838     return nullptr;
839
840   // Calculate how to scale down to 32-bits.
841   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
842
843   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
844   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
845                                       scaleBranchWeight(FalseCount, Scale));
846 }
847
848 llvm::MDNode *
849 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
850   // We need at least two elements to create meaningful weights.
851   if (Weights.size() < 2)
852     return nullptr;
853
854   // Check for empty weights.
855   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
856   if (MaxWeight == 0)
857     return nullptr;
858
859   // Calculate how to scale down to 32-bits.
860   uint64_t Scale = calculateWeightScale(MaxWeight);
861
862   SmallVector<uint32_t, 16> ScaledWeights;
863   ScaledWeights.reserve(Weights.size());
864   for (uint64_t W : Weights)
865     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
866
867   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
868   return MDHelper.createBranchWeights(ScaledWeights);
869 }
870
871 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
872                                                            uint64_t LoopCount) {
873   if (!PGO.haveRegionCounts())
874     return nullptr;
875   Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
876   assert(CondCount.hasValue() && "missing expected loop condition count");
877   if (*CondCount == 0)
878     return nullptr;
879   return createProfileWeights(LoopCount,
880                               std::max(*CondCount, LoopCount) - LoopCount);
881 }