]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - contrib/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp
Integrate tools/regression/acltools into tests/sys/acl
[FreeBSD/FreeBSD.git] / contrib / llvm / tools / clang / lib / CodeGen / CodeGenPGO.cpp
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/ProfileData/InstrProfReader.h"
22 #include "llvm/Support/Endian.h"
23 #include "llvm/Support/FileSystem.h"
24 #include "llvm/Support/MD5.h"
25
26 using namespace clang;
27 using namespace CodeGen;
28
29 void CodeGenPGO::setFuncName(StringRef Name,
30                              llvm::GlobalValue::LinkageTypes Linkage) {
31   StringRef RawFuncName = Name;
32
33   // Function names may be prefixed with a binary '1' to indicate
34   // that the backend should not modify the symbols due to any platform
35   // naming convention. Do not include that '1' in the PGO profile name.
36   if (RawFuncName[0] == '\1')
37     RawFuncName = RawFuncName.substr(1);
38
39   FuncName = RawFuncName;
40   if (llvm::GlobalValue::isLocalLinkage(Linkage)) {
41     // For local symbols, prepend the main file name to distinguish them.
42     // Do not include the full path in the file name since there's no guarantee
43     // that it will stay the same, e.g., if the files are checked out from
44     // version control in different locations.
45     if (CGM.getCodeGenOpts().MainFileName.empty())
46       FuncName = FuncName.insert(0, "<unknown>:");
47     else
48       FuncName = FuncName.insert(0, CGM.getCodeGenOpts().MainFileName + ":");
49   }
50
51   // If we're generating a profile, create a variable for the name.
52   if (CGM.getCodeGenOpts().ProfileInstrGenerate)
53     createFuncNameVar(Linkage);
54 }
55
56 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
57   setFuncName(Fn->getName(), Fn->getLinkage());
58 }
59
60 void CodeGenPGO::createFuncNameVar(llvm::GlobalValue::LinkageTypes Linkage) {
61   // We generally want to match the function's linkage, but available_externally
62   // and extern_weak both have the wrong semantics, and anything that doesn't
63   // need to link across compilation units doesn't need to be visible at all.
64   if (Linkage == llvm::GlobalValue::ExternalWeakLinkage)
65     Linkage = llvm::GlobalValue::LinkOnceAnyLinkage;
66   else if (Linkage == llvm::GlobalValue::AvailableExternallyLinkage)
67     Linkage = llvm::GlobalValue::LinkOnceODRLinkage;
68   else if (Linkage == llvm::GlobalValue::InternalLinkage ||
69            Linkage == llvm::GlobalValue::ExternalLinkage)
70     Linkage = llvm::GlobalValue::PrivateLinkage;
71
72   auto *Value =
73       llvm::ConstantDataArray::getString(CGM.getLLVMContext(), FuncName, false);
74   FuncNameVar =
75       new llvm::GlobalVariable(CGM.getModule(), Value->getType(), true, Linkage,
76                                Value, "__llvm_profile_name_" + FuncName);
77
78   // Hide the symbol so that we correctly get a copy for each executable.
79   if (!llvm::GlobalValue::isLocalLinkage(FuncNameVar->getLinkage()))
80     FuncNameVar->setVisibility(llvm::GlobalValue::HiddenVisibility);
81 }
82
83 namespace {
84 /// \brief Stable hasher for PGO region counters.
85 ///
86 /// PGOHash produces a stable hash of a given function's control flow.
87 ///
88 /// Changing the output of this hash will invalidate all previously generated
89 /// profiles -- i.e., don't do it.
90 ///
91 /// \note  When this hash does eventually change (years?), we still need to
92 /// support old hashes.  We'll need to pull in the version number from the
93 /// profile data format and use the matching hash function.
94 class PGOHash {
95   uint64_t Working;
96   unsigned Count;
97   llvm::MD5 MD5;
98
99   static const int NumBitsPerType = 6;
100   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
101   static const unsigned TooBig = 1u << NumBitsPerType;
102
103 public:
104   /// \brief Hash values for AST nodes.
105   ///
106   /// Distinct values for AST nodes that have region counters attached.
107   ///
108   /// These values must be stable.  All new members must be added at the end,
109   /// and no members should be removed.  Changing the enumeration value for an
110   /// AST node will affect the hash of every function that contains that node.
111   enum HashType : unsigned char {
112     None = 0,
113     LabelStmt = 1,
114     WhileStmt,
115     DoStmt,
116     ForStmt,
117     CXXForRangeStmt,
118     ObjCForCollectionStmt,
119     SwitchStmt,
120     CaseStmt,
121     DefaultStmt,
122     IfStmt,
123     CXXTryStmt,
124     CXXCatchStmt,
125     ConditionalOperator,
126     BinaryOperatorLAnd,
127     BinaryOperatorLOr,
128     BinaryConditionalOperator,
129
130     // Keep this last.  It's for the static assert that follows.
131     LastHashType
132   };
133   static_assert(LastHashType <= TooBig, "Too many types in HashType");
134
135   // TODO: When this format changes, take in a version number here, and use the
136   // old hash calculation for file formats that used the old hash.
137   PGOHash() : Working(0), Count(0) {}
138   void combine(HashType Type);
139   uint64_t finalize();
140 };
141 const int PGOHash::NumBitsPerType;
142 const unsigned PGOHash::NumTypesPerWord;
143 const unsigned PGOHash::TooBig;
144
145 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
146 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
147   /// The next counter value to assign.
148   unsigned NextCounter;
149   /// The function hash.
150   PGOHash Hash;
151   /// The map of statements to counters.
152   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
153
154   MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
155       : NextCounter(0), CounterMap(CounterMap) {}
156
157   // Blocks and lambdas are handled as separate functions, so we need not
158   // traverse them in the parent context.
159   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
160   bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
161   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
162
163   bool VisitDecl(const Decl *D) {
164     switch (D->getKind()) {
165     default:
166       break;
167     case Decl::Function:
168     case Decl::CXXMethod:
169     case Decl::CXXConstructor:
170     case Decl::CXXDestructor:
171     case Decl::CXXConversion:
172     case Decl::ObjCMethod:
173     case Decl::Block:
174     case Decl::Captured:
175       CounterMap[D->getBody()] = NextCounter++;
176       break;
177     }
178     return true;
179   }
180
181   bool VisitStmt(const Stmt *S) {
182     auto Type = getHashType(S);
183     if (Type == PGOHash::None)
184       return true;
185
186     CounterMap[S] = NextCounter++;
187     Hash.combine(Type);
188     return true;
189   }
190   PGOHash::HashType getHashType(const Stmt *S) {
191     switch (S->getStmtClass()) {
192     default:
193       break;
194     case Stmt::LabelStmtClass:
195       return PGOHash::LabelStmt;
196     case Stmt::WhileStmtClass:
197       return PGOHash::WhileStmt;
198     case Stmt::DoStmtClass:
199       return PGOHash::DoStmt;
200     case Stmt::ForStmtClass:
201       return PGOHash::ForStmt;
202     case Stmt::CXXForRangeStmtClass:
203       return PGOHash::CXXForRangeStmt;
204     case Stmt::ObjCForCollectionStmtClass:
205       return PGOHash::ObjCForCollectionStmt;
206     case Stmt::SwitchStmtClass:
207       return PGOHash::SwitchStmt;
208     case Stmt::CaseStmtClass:
209       return PGOHash::CaseStmt;
210     case Stmt::DefaultStmtClass:
211       return PGOHash::DefaultStmt;
212     case Stmt::IfStmtClass:
213       return PGOHash::IfStmt;
214     case Stmt::CXXTryStmtClass:
215       return PGOHash::CXXTryStmt;
216     case Stmt::CXXCatchStmtClass:
217       return PGOHash::CXXCatchStmt;
218     case Stmt::ConditionalOperatorClass:
219       return PGOHash::ConditionalOperator;
220     case Stmt::BinaryConditionalOperatorClass:
221       return PGOHash::BinaryConditionalOperator;
222     case Stmt::BinaryOperatorClass: {
223       const BinaryOperator *BO = cast<BinaryOperator>(S);
224       if (BO->getOpcode() == BO_LAnd)
225         return PGOHash::BinaryOperatorLAnd;
226       if (BO->getOpcode() == BO_LOr)
227         return PGOHash::BinaryOperatorLOr;
228       break;
229     }
230     }
231     return PGOHash::None;
232   }
233 };
234
235 /// A StmtVisitor that propagates the raw counts through the AST and
236 /// records the count at statements where the value may change.
237 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
238   /// PGO state.
239   CodeGenPGO &PGO;
240
241   /// A flag that is set when the current count should be recorded on the
242   /// next statement, such as at the exit of a loop.
243   bool RecordNextStmtCount;
244
245   /// The count at the current location in the traversal.
246   uint64_t CurrentCount;
247
248   /// The map of statements to count values.
249   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
250
251   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
252   struct BreakContinue {
253     uint64_t BreakCount;
254     uint64_t ContinueCount;
255     BreakContinue() : BreakCount(0), ContinueCount(0) {}
256   };
257   SmallVector<BreakContinue, 8> BreakContinueStack;
258
259   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
260                       CodeGenPGO &PGO)
261       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
262
263   void RecordStmtCount(const Stmt *S) {
264     if (RecordNextStmtCount) {
265       CountMap[S] = CurrentCount;
266       RecordNextStmtCount = false;
267     }
268   }
269
270   /// Set and return the current count.
271   uint64_t setCount(uint64_t Count) {
272     CurrentCount = Count;
273     return Count;
274   }
275
276   void VisitStmt(const Stmt *S) {
277     RecordStmtCount(S);
278     for (const Stmt *Child : S->children())
279       if (Child)
280         this->Visit(Child);
281   }
282
283   void VisitFunctionDecl(const FunctionDecl *D) {
284     // Counter tracks entry to the function body.
285     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
286     CountMap[D->getBody()] = BodyCount;
287     Visit(D->getBody());
288   }
289
290   // Skip lambda expressions. We visit these as FunctionDecls when we're
291   // generating them and aren't interested in the body when generating a
292   // parent context.
293   void VisitLambdaExpr(const LambdaExpr *LE) {}
294
295   void VisitCapturedDecl(const CapturedDecl *D) {
296     // Counter tracks entry to the capture body.
297     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
298     CountMap[D->getBody()] = BodyCount;
299     Visit(D->getBody());
300   }
301
302   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
303     // Counter tracks entry to the method body.
304     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
305     CountMap[D->getBody()] = BodyCount;
306     Visit(D->getBody());
307   }
308
309   void VisitBlockDecl(const BlockDecl *D) {
310     // Counter tracks entry to the block body.
311     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
312     CountMap[D->getBody()] = BodyCount;
313     Visit(D->getBody());
314   }
315
316   void VisitReturnStmt(const ReturnStmt *S) {
317     RecordStmtCount(S);
318     if (S->getRetValue())
319       Visit(S->getRetValue());
320     CurrentCount = 0;
321     RecordNextStmtCount = true;
322   }
323
324   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
325     RecordStmtCount(E);
326     if (E->getSubExpr())
327       Visit(E->getSubExpr());
328     CurrentCount = 0;
329     RecordNextStmtCount = true;
330   }
331
332   void VisitGotoStmt(const GotoStmt *S) {
333     RecordStmtCount(S);
334     CurrentCount = 0;
335     RecordNextStmtCount = true;
336   }
337
338   void VisitLabelStmt(const LabelStmt *S) {
339     RecordNextStmtCount = false;
340     // Counter tracks the block following the label.
341     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
342     CountMap[S] = BlockCount;
343     Visit(S->getSubStmt());
344   }
345
346   void VisitBreakStmt(const BreakStmt *S) {
347     RecordStmtCount(S);
348     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
349     BreakContinueStack.back().BreakCount += CurrentCount;
350     CurrentCount = 0;
351     RecordNextStmtCount = true;
352   }
353
354   void VisitContinueStmt(const ContinueStmt *S) {
355     RecordStmtCount(S);
356     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
357     BreakContinueStack.back().ContinueCount += CurrentCount;
358     CurrentCount = 0;
359     RecordNextStmtCount = true;
360   }
361
362   void VisitWhileStmt(const WhileStmt *S) {
363     RecordStmtCount(S);
364     uint64_t ParentCount = CurrentCount;
365
366     BreakContinueStack.push_back(BreakContinue());
367     // Visit the body region first so the break/continue adjustments can be
368     // included when visiting the condition.
369     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
370     CountMap[S->getBody()] = CurrentCount;
371     Visit(S->getBody());
372     uint64_t BackedgeCount = CurrentCount;
373
374     // ...then go back and propagate counts through the condition. The count
375     // at the start of the condition is the sum of the incoming edges,
376     // the backedge from the end of the loop body, and the edges from
377     // continue statements.
378     BreakContinue BC = BreakContinueStack.pop_back_val();
379     uint64_t CondCount =
380         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
381     CountMap[S->getCond()] = CondCount;
382     Visit(S->getCond());
383     setCount(BC.BreakCount + CondCount - BodyCount);
384     RecordNextStmtCount = true;
385   }
386
387   void VisitDoStmt(const DoStmt *S) {
388     RecordStmtCount(S);
389     uint64_t LoopCount = PGO.getRegionCount(S);
390
391     BreakContinueStack.push_back(BreakContinue());
392     // The count doesn't include the fallthrough from the parent scope. Add it.
393     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
394     CountMap[S->getBody()] = BodyCount;
395     Visit(S->getBody());
396     uint64_t BackedgeCount = CurrentCount;
397
398     BreakContinue BC = BreakContinueStack.pop_back_val();
399     // The count at the start of the condition is equal to the count at the
400     // end of the body, plus any continues.
401     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
402     CountMap[S->getCond()] = CondCount;
403     Visit(S->getCond());
404     setCount(BC.BreakCount + CondCount - LoopCount);
405     RecordNextStmtCount = true;
406   }
407
408   void VisitForStmt(const ForStmt *S) {
409     RecordStmtCount(S);
410     if (S->getInit())
411       Visit(S->getInit());
412
413     uint64_t ParentCount = CurrentCount;
414
415     BreakContinueStack.push_back(BreakContinue());
416     // Visit the body region first. (This is basically the same as a while
417     // loop; see further comments in VisitWhileStmt.)
418     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
419     CountMap[S->getBody()] = BodyCount;
420     Visit(S->getBody());
421     uint64_t BackedgeCount = CurrentCount;
422     BreakContinue BC = BreakContinueStack.pop_back_val();
423
424     // The increment is essentially part of the body but it needs to include
425     // the count for all the continue statements.
426     if (S->getInc()) {
427       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
428       CountMap[S->getInc()] = IncCount;
429       Visit(S->getInc());
430     }
431
432     // ...then go back and propagate counts through the condition.
433     uint64_t CondCount =
434         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
435     if (S->getCond()) {
436       CountMap[S->getCond()] = CondCount;
437       Visit(S->getCond());
438     }
439     setCount(BC.BreakCount + CondCount - BodyCount);
440     RecordNextStmtCount = true;
441   }
442
443   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
444     RecordStmtCount(S);
445     Visit(S->getLoopVarStmt());
446     Visit(S->getRangeStmt());
447     Visit(S->getBeginEndStmt());
448
449     uint64_t ParentCount = CurrentCount;
450     BreakContinueStack.push_back(BreakContinue());
451     // Visit the body region first. (This is basically the same as a while
452     // loop; see further comments in VisitWhileStmt.)
453     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
454     CountMap[S->getBody()] = BodyCount;
455     Visit(S->getBody());
456     uint64_t BackedgeCount = CurrentCount;
457     BreakContinue BC = BreakContinueStack.pop_back_val();
458
459     // The increment is essentially part of the body but it needs to include
460     // the count for all the continue statements.
461     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
462     CountMap[S->getInc()] = IncCount;
463     Visit(S->getInc());
464
465     // ...then go back and propagate counts through the condition.
466     uint64_t CondCount =
467         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
468     CountMap[S->getCond()] = CondCount;
469     Visit(S->getCond());
470     setCount(BC.BreakCount + CondCount - BodyCount);
471     RecordNextStmtCount = true;
472   }
473
474   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
475     RecordStmtCount(S);
476     Visit(S->getElement());
477     uint64_t ParentCount = CurrentCount;
478     BreakContinueStack.push_back(BreakContinue());
479     // Counter tracks the body of the loop.
480     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
481     CountMap[S->getBody()] = BodyCount;
482     Visit(S->getBody());
483     uint64_t BackedgeCount = CurrentCount;
484     BreakContinue BC = BreakContinueStack.pop_back_val();
485
486     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
487              BodyCount);
488     RecordNextStmtCount = true;
489   }
490
491   void VisitSwitchStmt(const SwitchStmt *S) {
492     RecordStmtCount(S);
493     Visit(S->getCond());
494     CurrentCount = 0;
495     BreakContinueStack.push_back(BreakContinue());
496     Visit(S->getBody());
497     // If the switch is inside a loop, add the continue counts.
498     BreakContinue BC = BreakContinueStack.pop_back_val();
499     if (!BreakContinueStack.empty())
500       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
501     // Counter tracks the exit block of the switch.
502     setCount(PGO.getRegionCount(S));
503     RecordNextStmtCount = true;
504   }
505
506   void VisitSwitchCase(const SwitchCase *S) {
507     RecordNextStmtCount = false;
508     // Counter for this particular case. This counts only jumps from the
509     // switch header and does not include fallthrough from the case before
510     // this one.
511     uint64_t CaseCount = PGO.getRegionCount(S);
512     setCount(CurrentCount + CaseCount);
513     // We need the count without fallthrough in the mapping, so it's more useful
514     // for branch probabilities.
515     CountMap[S] = CaseCount;
516     RecordNextStmtCount = true;
517     Visit(S->getSubStmt());
518   }
519
520   void VisitIfStmt(const IfStmt *S) {
521     RecordStmtCount(S);
522     uint64_t ParentCount = CurrentCount;
523     Visit(S->getCond());
524
525     // Counter tracks the "then" part of an if statement. The count for
526     // the "else" part, if it exists, will be calculated from this counter.
527     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
528     CountMap[S->getThen()] = ThenCount;
529     Visit(S->getThen());
530     uint64_t OutCount = CurrentCount;
531
532     uint64_t ElseCount = ParentCount - ThenCount;
533     if (S->getElse()) {
534       setCount(ElseCount);
535       CountMap[S->getElse()] = ElseCount;
536       Visit(S->getElse());
537       OutCount += CurrentCount;
538     } else
539       OutCount += ElseCount;
540     setCount(OutCount);
541     RecordNextStmtCount = true;
542   }
543
544   void VisitCXXTryStmt(const CXXTryStmt *S) {
545     RecordStmtCount(S);
546     Visit(S->getTryBlock());
547     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
548       Visit(S->getHandler(I));
549     // Counter tracks the continuation block of the try statement.
550     setCount(PGO.getRegionCount(S));
551     RecordNextStmtCount = true;
552   }
553
554   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
555     RecordNextStmtCount = false;
556     // Counter tracks the catch statement's handler block.
557     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
558     CountMap[S] = CatchCount;
559     Visit(S->getHandlerBlock());
560   }
561
562   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
563     RecordStmtCount(E);
564     uint64_t ParentCount = CurrentCount;
565     Visit(E->getCond());
566
567     // Counter tracks the "true" part of a conditional operator. The
568     // count in the "false" part will be calculated from this counter.
569     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
570     CountMap[E->getTrueExpr()] = TrueCount;
571     Visit(E->getTrueExpr());
572     uint64_t OutCount = CurrentCount;
573
574     uint64_t FalseCount = setCount(ParentCount - TrueCount);
575     CountMap[E->getFalseExpr()] = FalseCount;
576     Visit(E->getFalseExpr());
577     OutCount += CurrentCount;
578
579     setCount(OutCount);
580     RecordNextStmtCount = true;
581   }
582
583   void VisitBinLAnd(const BinaryOperator *E) {
584     RecordStmtCount(E);
585     uint64_t ParentCount = CurrentCount;
586     Visit(E->getLHS());
587     // Counter tracks the right hand side of a logical and operator.
588     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
589     CountMap[E->getRHS()] = RHSCount;
590     Visit(E->getRHS());
591     setCount(ParentCount + RHSCount - CurrentCount);
592     RecordNextStmtCount = true;
593   }
594
595   void VisitBinLOr(const BinaryOperator *E) {
596     RecordStmtCount(E);
597     uint64_t ParentCount = CurrentCount;
598     Visit(E->getLHS());
599     // Counter tracks the right hand side of a logical or operator.
600     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
601     CountMap[E->getRHS()] = RHSCount;
602     Visit(E->getRHS());
603     setCount(ParentCount + RHSCount - CurrentCount);
604     RecordNextStmtCount = true;
605   }
606 };
607 }
608
609 void PGOHash::combine(HashType Type) {
610   // Check that we never combine 0 and only have six bits.
611   assert(Type && "Hash is invalid: unexpected type 0");
612   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
613
614   // Pass through MD5 if enough work has built up.
615   if (Count && Count % NumTypesPerWord == 0) {
616     using namespace llvm::support;
617     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
618     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
619     Working = 0;
620   }
621
622   // Accumulate the current type.
623   ++Count;
624   Working = Working << NumBitsPerType | Type;
625 }
626
627 uint64_t PGOHash::finalize() {
628   // Use Working as the hash directly if we never used MD5.
629   if (Count <= NumTypesPerWord)
630     // No need to byte swap here, since none of the math was endian-dependent.
631     // This number will be byte-swapped as required on endianness transitions,
632     // so we will see the same value on the other side.
633     return Working;
634
635   // Check for remaining work in Working.
636   if (Working)
637     MD5.update(Working);
638
639   // Finalize the MD5 and return the hash.
640   llvm::MD5::MD5Result Result;
641   MD5.final(Result);
642   using namespace llvm::support;
643   return endian::read<uint64_t, little, unaligned>(Result);
644 }
645
646 void CodeGenPGO::checkGlobalDecl(GlobalDecl GD) {
647   // Make sure we only emit coverage mapping for one constructor/destructor.
648   // Clang emits several functions for the constructor and the destructor of
649   // a class. Every function is instrumented, but we only want to provide
650   // coverage for one of them. Because of that we only emit the coverage mapping
651   // for the base constructor/destructor.
652   if ((isa<CXXConstructorDecl>(GD.getDecl()) &&
653        GD.getCtorType() != Ctor_Base) ||
654       (isa<CXXDestructorDecl>(GD.getDecl()) &&
655        GD.getDtorType() != Dtor_Base)) {
656     SkipCoverageMapping = true;
657   }
658 }
659
660 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
661   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
662   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
663   if (!InstrumentRegions && !PGOReader)
664     return;
665   if (D->isImplicit())
666     return;
667   CGM.ClearUnusedCoverageMapping(D);
668   setFuncName(Fn);
669
670   mapRegionCounters(D);
671   if (CGM.getCodeGenOpts().CoverageMapping)
672     emitCounterRegionMapping(D);
673   if (PGOReader) {
674     SourceManager &SM = CGM.getContext().getSourceManager();
675     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
676     computeRegionCounts(D);
677     applyFunctionAttributes(PGOReader, Fn);
678   }
679 }
680
681 void CodeGenPGO::mapRegionCounters(const Decl *D) {
682   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
683   MapRegionCounters Walker(*RegionCounterMap);
684   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
685     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
686   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
687     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
688   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
689     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
690   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
691     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
692   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
693   NumRegionCounters = Walker.NextCounter;
694   FunctionHash = Walker.Hash.finalize();
695 }
696
697 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
698   if (SkipCoverageMapping)
699     return;
700   // Don't map the functions inside the system headers
701   auto Loc = D->getBody()->getLocStart();
702   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
703     return;
704
705   std::string CoverageMapping;
706   llvm::raw_string_ostream OS(CoverageMapping);
707   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
708                                 CGM.getContext().getSourceManager(),
709                                 CGM.getLangOpts(), RegionCounterMap.get());
710   MappingGen.emitCounterMapping(D, OS);
711   OS.flush();
712
713   if (CoverageMapping.empty())
714     return;
715
716   CGM.getCoverageMapping()->addFunctionMappingRecord(
717       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
718 }
719
720 void
721 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
722                                     llvm::GlobalValue::LinkageTypes Linkage) {
723   if (SkipCoverageMapping)
724     return;
725   // Don't map the functions inside the system headers
726   auto Loc = D->getBody()->getLocStart();
727   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
728     return;
729
730   std::string CoverageMapping;
731   llvm::raw_string_ostream OS(CoverageMapping);
732   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
733                                 CGM.getContext().getSourceManager(),
734                                 CGM.getLangOpts());
735   MappingGen.emitEmptyMapping(D, OS);
736   OS.flush();
737
738   if (CoverageMapping.empty())
739     return;
740
741   setFuncName(Name, Linkage);
742   CGM.getCoverageMapping()->addFunctionMappingRecord(
743       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
744 }
745
746 void CodeGenPGO::computeRegionCounts(const Decl *D) {
747   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
748   ComputeRegionCounts Walker(*StmtCountMap, *this);
749   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
750     Walker.VisitFunctionDecl(FD);
751   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
752     Walker.VisitObjCMethodDecl(MD);
753   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
754     Walker.VisitBlockDecl(BD);
755   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
756     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
757 }
758
759 void
760 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
761                                     llvm::Function *Fn) {
762   if (!haveRegionCounts())
763     return;
764
765   uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
766   uint64_t FunctionCount = getRegionCount(0);
767   if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
768     // Turn on InlineHint attribute for hot functions.
769     // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
770     Fn->addFnAttr(llvm::Attribute::InlineHint);
771   else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
772     // Turn on Cold attribute for cold functions.
773     // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
774     Fn->addFnAttr(llvm::Attribute::Cold);
775
776   Fn->setEntryCount(FunctionCount);
777 }
778
779 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S) {
780   if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap)
781     return;
782   if (!Builder.GetInsertPoint())
783     return;
784
785   unsigned Counter = (*RegionCounterMap)[S];
786   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
787   Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
788                      {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
789                       Builder.getInt64(FunctionHash),
790                       Builder.getInt32(NumRegionCounters),
791                       Builder.getInt32(Counter)});
792 }
793
794 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
795                                   bool IsInMainFile) {
796   CGM.getPGOStats().addVisited(IsInMainFile);
797   RegionCounts.clear();
798   if (std::error_code EC =
799           PGOReader->getFunctionCounts(FuncName, FunctionHash, RegionCounts)) {
800     if (EC == llvm::instrprof_error::unknown_function)
801       CGM.getPGOStats().addMissing(IsInMainFile);
802     else if (EC == llvm::instrprof_error::hash_mismatch)
803       CGM.getPGOStats().addMismatched(IsInMainFile);
804     else if (EC == llvm::instrprof_error::malformed)
805       // TODO: Consider a more specific warning for this case.
806       CGM.getPGOStats().addMismatched(IsInMainFile);
807     RegionCounts.clear();
808   }
809 }
810
811 /// \brief Calculate what to divide by to scale weights.
812 ///
813 /// Given the maximum weight, calculate a divisor that will scale all the
814 /// weights to strictly less than UINT32_MAX.
815 static uint64_t calculateWeightScale(uint64_t MaxWeight) {
816   return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
817 }
818
819 /// \brief Scale an individual branch weight (and add 1).
820 ///
821 /// Scale a 64-bit weight down to 32-bits using \c Scale.
822 ///
823 /// According to Laplace's Rule of Succession, it is better to compute the
824 /// weight based on the count plus 1, so universally add 1 to the value.
825 ///
826 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
827 /// greater than \c Weight.
828 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
829   assert(Scale && "scale by 0?");
830   uint64_t Scaled = Weight / Scale + 1;
831   assert(Scaled <= UINT32_MAX && "overflow 32-bits");
832   return Scaled;
833 }
834
835 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
836                                                     uint64_t FalseCount) {
837   // Check for empty weights.
838   if (!TrueCount && !FalseCount)
839     return nullptr;
840
841   // Calculate how to scale down to 32-bits.
842   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
843
844   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
845   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
846                                       scaleBranchWeight(FalseCount, Scale));
847 }
848
849 llvm::MDNode *
850 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
851   // We need at least two elements to create meaningful weights.
852   if (Weights.size() < 2)
853     return nullptr;
854
855   // Check for empty weights.
856   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
857   if (MaxWeight == 0)
858     return nullptr;
859
860   // Calculate how to scale down to 32-bits.
861   uint64_t Scale = calculateWeightScale(MaxWeight);
862
863   SmallVector<uint32_t, 16> ScaledWeights;
864   ScaledWeights.reserve(Weights.size());
865   for (uint64_t W : Weights)
866     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
867
868   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
869   return MDHelper.createBranchWeights(ScaledWeights);
870 }
871
872 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
873                                                            uint64_t LoopCount) {
874   if (!PGO.haveRegionCounts())
875     return nullptr;
876   Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
877   assert(CondCount.hasValue() && "missing expected loop condition count");
878   if (*CondCount == 0)
879     return nullptr;
880   return createProfileWeights(LoopCount,
881                               std::max(*CondCount, LoopCount) - LoopCount);
882 }