]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - contrib/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp
Merge ^/head r274961 through r276472.
[FreeBSD/FreeBSD.git] / contrib / llvm / tools / clang / lib / CodeGen / CodeGenPGO.cpp
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/MDBuilder.h"
19 #include "llvm/ProfileData/InstrProfReader.h"
20 #include "llvm/Support/Endian.h"
21 #include "llvm/Support/FileSystem.h"
22 #include "llvm/Support/MD5.h"
23
24 using namespace clang;
25 using namespace CodeGen;
26
27 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
28   RawFuncName = Fn->getName();
29
30   // Function names may be prefixed with a binary '1' to indicate
31   // that the backend should not modify the symbols due to any platform
32   // naming convention. Do not include that '1' in the PGO profile name.
33   if (RawFuncName[0] == '\1')
34     RawFuncName = RawFuncName.substr(1);
35
36   if (!Fn->hasLocalLinkage()) {
37     PrefixedFuncName.reset(new std::string(RawFuncName));
38     return;
39   }
40
41   // For local symbols, prepend the main file name to distinguish them.
42   // Do not include the full path in the file name since there's no guarantee
43   // that it will stay the same, e.g., if the files are checked out from
44   // version control in different locations.
45   PrefixedFuncName.reset(new std::string(CGM.getCodeGenOpts().MainFileName));
46   if (PrefixedFuncName->empty())
47     PrefixedFuncName->assign("<unknown>");
48   PrefixedFuncName->append(":");
49   PrefixedFuncName->append(RawFuncName);
50 }
51
52 static llvm::Function *getRegisterFunc(CodeGenModule &CGM) {
53   return CGM.getModule().getFunction("__llvm_profile_register_functions");
54 }
55
56 static llvm::BasicBlock *getOrInsertRegisterBB(CodeGenModule &CGM) {
57   // Don't do this for Darwin.  compiler-rt uses linker magic.
58   if (CGM.getTarget().getTriple().isOSDarwin())
59     return nullptr;
60
61   // Only need to insert this once per module.
62   if (llvm::Function *RegisterF = getRegisterFunc(CGM))
63     return &RegisterF->getEntryBlock();
64
65   // Construct the function.
66   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
67   auto *RegisterFTy = llvm::FunctionType::get(VoidTy, false);
68   auto *RegisterF = llvm::Function::Create(RegisterFTy,
69                                            llvm::GlobalValue::InternalLinkage,
70                                            "__llvm_profile_register_functions",
71                                            &CGM.getModule());
72   RegisterF->setUnnamedAddr(true);
73   if (CGM.getCodeGenOpts().DisableRedZone)
74     RegisterF->addFnAttr(llvm::Attribute::NoRedZone);
75
76   // Construct and return the entry block.
77   auto *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", RegisterF);
78   CGBuilderTy Builder(BB);
79   Builder.CreateRetVoid();
80   return BB;
81 }
82
83 static llvm::Constant *getOrInsertRuntimeRegister(CodeGenModule &CGM) {
84   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
85   auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
86   auto *RuntimeRegisterTy = llvm::FunctionType::get(VoidTy, VoidPtrTy, false);
87   return CGM.getModule().getOrInsertFunction("__llvm_profile_register_function",
88                                              RuntimeRegisterTy);
89 }
90
91 static bool isMachO(const CodeGenModule &CGM) {
92   return CGM.getTarget().getTriple().isOSBinFormatMachO();
93 }
94
95 static StringRef getCountersSection(const CodeGenModule &CGM) {
96   return isMachO(CGM) ? "__DATA,__llvm_prf_cnts" : "__llvm_prf_cnts";
97 }
98
99 static StringRef getNameSection(const CodeGenModule &CGM) {
100   return isMachO(CGM) ? "__DATA,__llvm_prf_names" : "__llvm_prf_names";
101 }
102
103 static StringRef getDataSection(const CodeGenModule &CGM) {
104   return isMachO(CGM) ? "__DATA,__llvm_prf_data" : "__llvm_prf_data";
105 }
106
107 llvm::GlobalVariable *CodeGenPGO::buildDataVar() {
108   // Create name variable.
109   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
110   auto *VarName = llvm::ConstantDataArray::getString(Ctx, getFuncName(),
111                                                      false);
112   auto *Name = new llvm::GlobalVariable(CGM.getModule(), VarName->getType(),
113                                         true, VarLinkage, VarName,
114                                         getFuncVarName("name"));
115   Name->setSection(getNameSection(CGM));
116   Name->setAlignment(1);
117
118   // Create data variable.
119   auto *Int32Ty = llvm::Type::getInt32Ty(Ctx);
120   auto *Int64Ty = llvm::Type::getInt64Ty(Ctx);
121   auto *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx);
122   auto *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx);
123   llvm::Type *DataTypes[] = {
124     Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int64PtrTy
125   };
126   auto *DataTy = llvm::StructType::get(Ctx, makeArrayRef(DataTypes));
127   llvm::Constant *DataVals[] = {
128     llvm::ConstantInt::get(Int32Ty, getFuncName().size()),
129     llvm::ConstantInt::get(Int32Ty, NumRegionCounters),
130     llvm::ConstantInt::get(Int64Ty, FunctionHash),
131     llvm::ConstantExpr::getBitCast(Name, Int8PtrTy),
132     llvm::ConstantExpr::getBitCast(RegionCounters, Int64PtrTy)
133   };
134   auto *Data =
135     new llvm::GlobalVariable(CGM.getModule(), DataTy, true, VarLinkage,
136                              llvm::ConstantStruct::get(DataTy, DataVals),
137                              getFuncVarName("data"));
138
139   // All the data should be packed into an array in its own section.
140   Data->setSection(getDataSection(CGM));
141   Data->setAlignment(8);
142
143   // Hide all these symbols so that we correctly get a copy for each
144   // executable.  The profile format expects names and counters to be
145   // contiguous, so references into shared objects would be invalid.
146   if (!llvm::GlobalValue::isLocalLinkage(VarLinkage)) {
147     Name->setVisibility(llvm::GlobalValue::HiddenVisibility);
148     Data->setVisibility(llvm::GlobalValue::HiddenVisibility);
149     RegionCounters->setVisibility(llvm::GlobalValue::HiddenVisibility);
150   }
151
152   // Make sure the data doesn't get deleted.
153   CGM.addUsedGlobal(Data);
154   return Data;
155 }
156
157 void CodeGenPGO::emitInstrumentationData() {
158   if (!RegionCounters)
159     return;
160
161   // Build the data.
162   auto *Data = buildDataVar();
163
164   // Register the data.
165   auto *RegisterBB = getOrInsertRegisterBB(CGM);
166   if (!RegisterBB)
167     return;
168   CGBuilderTy Builder(RegisterBB->getTerminator());
169   auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
170   Builder.CreateCall(getOrInsertRuntimeRegister(CGM),
171                      Builder.CreateBitCast(Data, VoidPtrTy));
172 }
173
174 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) {
175   if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
176     return nullptr;
177
178   assert(CGM.getModule().getFunction("__llvm_profile_init") == nullptr &&
179          "profile initialization already emitted");
180
181   // Get the function to call at initialization.
182   llvm::Constant *RegisterF = getRegisterFunc(CGM);
183   if (!RegisterF)
184     return nullptr;
185
186   // Create the initialization function.
187   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
188   auto *F = llvm::Function::Create(llvm::FunctionType::get(VoidTy, false),
189                                    llvm::GlobalValue::InternalLinkage,
190                                    "__llvm_profile_init", &CGM.getModule());
191   F->setUnnamedAddr(true);
192   F->addFnAttr(llvm::Attribute::NoInline);
193   if (CGM.getCodeGenOpts().DisableRedZone)
194     F->addFnAttr(llvm::Attribute::NoRedZone);
195
196   // Add the basic block and the necessary calls.
197   CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F));
198   Builder.CreateCall(RegisterF);
199   Builder.CreateRetVoid();
200
201   return F;
202 }
203
204 namespace {
205 /// \brief Stable hasher for PGO region counters.
206 ///
207 /// PGOHash produces a stable hash of a given function's control flow.
208 ///
209 /// Changing the output of this hash will invalidate all previously generated
210 /// profiles -- i.e., don't do it.
211 ///
212 /// \note  When this hash does eventually change (years?), we still need to
213 /// support old hashes.  We'll need to pull in the version number from the
214 /// profile data format and use the matching hash function.
215 class PGOHash {
216   uint64_t Working;
217   unsigned Count;
218   llvm::MD5 MD5;
219
220   static const int NumBitsPerType = 6;
221   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
222   static const unsigned TooBig = 1u << NumBitsPerType;
223
224 public:
225   /// \brief Hash values for AST nodes.
226   ///
227   /// Distinct values for AST nodes that have region counters attached.
228   ///
229   /// These values must be stable.  All new members must be added at the end,
230   /// and no members should be removed.  Changing the enumeration value for an
231   /// AST node will affect the hash of every function that contains that node.
232   enum HashType : unsigned char {
233     None = 0,
234     LabelStmt = 1,
235     WhileStmt,
236     DoStmt,
237     ForStmt,
238     CXXForRangeStmt,
239     ObjCForCollectionStmt,
240     SwitchStmt,
241     CaseStmt,
242     DefaultStmt,
243     IfStmt,
244     CXXTryStmt,
245     CXXCatchStmt,
246     ConditionalOperator,
247     BinaryOperatorLAnd,
248     BinaryOperatorLOr,
249     BinaryConditionalOperator,
250
251     // Keep this last.  It's for the static assert that follows.
252     LastHashType
253   };
254   static_assert(LastHashType <= TooBig, "Too many types in HashType");
255
256   // TODO: When this format changes, take in a version number here, and use the
257   // old hash calculation for file formats that used the old hash.
258   PGOHash() : Working(0), Count(0) {}
259   void combine(HashType Type);
260   uint64_t finalize();
261 };
262 const int PGOHash::NumBitsPerType;
263 const unsigned PGOHash::NumTypesPerWord;
264 const unsigned PGOHash::TooBig;
265
266   /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
267   struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
268     /// The next counter value to assign.
269     unsigned NextCounter;
270     /// The function hash.
271     PGOHash Hash;
272     /// The map of statements to counters.
273     llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
274
275     MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
276         : NextCounter(0), CounterMap(CounterMap) {}
277
278     // Blocks and lambdas are handled as separate functions, so we need not
279     // traverse them in the parent context.
280     bool TraverseBlockExpr(BlockExpr *BE) { return true; }
281     bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
282     bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
283
284     bool VisitDecl(const Decl *D) {
285       switch (D->getKind()) {
286       default:
287         break;
288       case Decl::Function:
289       case Decl::CXXMethod:
290       case Decl::CXXConstructor:
291       case Decl::CXXDestructor:
292       case Decl::CXXConversion:
293       case Decl::ObjCMethod:
294       case Decl::Block:
295       case Decl::Captured:
296         CounterMap[D->getBody()] = NextCounter++;
297         break;
298       }
299       return true;
300     }
301
302     bool VisitStmt(const Stmt *S) {
303       auto Type = getHashType(S);
304       if (Type == PGOHash::None)
305         return true;
306
307       CounterMap[S] = NextCounter++;
308       Hash.combine(Type);
309       return true;
310     }
311     PGOHash::HashType getHashType(const Stmt *S) {
312       switch (S->getStmtClass()) {
313       default:
314         break;
315       case Stmt::LabelStmtClass:
316         return PGOHash::LabelStmt;
317       case Stmt::WhileStmtClass:
318         return PGOHash::WhileStmt;
319       case Stmt::DoStmtClass:
320         return PGOHash::DoStmt;
321       case Stmt::ForStmtClass:
322         return PGOHash::ForStmt;
323       case Stmt::CXXForRangeStmtClass:
324         return PGOHash::CXXForRangeStmt;
325       case Stmt::ObjCForCollectionStmtClass:
326         return PGOHash::ObjCForCollectionStmt;
327       case Stmt::SwitchStmtClass:
328         return PGOHash::SwitchStmt;
329       case Stmt::CaseStmtClass:
330         return PGOHash::CaseStmt;
331       case Stmt::DefaultStmtClass:
332         return PGOHash::DefaultStmt;
333       case Stmt::IfStmtClass:
334         return PGOHash::IfStmt;
335       case Stmt::CXXTryStmtClass:
336         return PGOHash::CXXTryStmt;
337       case Stmt::CXXCatchStmtClass:
338         return PGOHash::CXXCatchStmt;
339       case Stmt::ConditionalOperatorClass:
340         return PGOHash::ConditionalOperator;
341       case Stmt::BinaryConditionalOperatorClass:
342         return PGOHash::BinaryConditionalOperator;
343       case Stmt::BinaryOperatorClass: {
344         const BinaryOperator *BO = cast<BinaryOperator>(S);
345         if (BO->getOpcode() == BO_LAnd)
346           return PGOHash::BinaryOperatorLAnd;
347         if (BO->getOpcode() == BO_LOr)
348           return PGOHash::BinaryOperatorLOr;
349         break;
350       }
351       }
352       return PGOHash::None;
353     }
354   };
355
356   /// A StmtVisitor that propagates the raw counts through the AST and
357   /// records the count at statements where the value may change.
358   struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
359     /// PGO state.
360     CodeGenPGO &PGO;
361
362     /// A flag that is set when the current count should be recorded on the
363     /// next statement, such as at the exit of a loop.
364     bool RecordNextStmtCount;
365
366     /// The map of statements to count values.
367     llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
368
369     /// BreakContinueStack - Keep counts of breaks and continues inside loops.
370     struct BreakContinue {
371       uint64_t BreakCount;
372       uint64_t ContinueCount;
373       BreakContinue() : BreakCount(0), ContinueCount(0) {}
374     };
375     SmallVector<BreakContinue, 8> BreakContinueStack;
376
377     ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
378                         CodeGenPGO &PGO)
379         : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
380
381     void RecordStmtCount(const Stmt *S) {
382       if (RecordNextStmtCount) {
383         CountMap[S] = PGO.getCurrentRegionCount();
384         RecordNextStmtCount = false;
385       }
386     }
387
388     void VisitStmt(const Stmt *S) {
389       RecordStmtCount(S);
390       for (Stmt::const_child_range I = S->children(); I; ++I) {
391         if (*I)
392          this->Visit(*I);
393       }
394     }
395
396     void VisitFunctionDecl(const FunctionDecl *D) {
397       // Counter tracks entry to the function body.
398       RegionCounter Cnt(PGO, D->getBody());
399       Cnt.beginRegion();
400       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
401       Visit(D->getBody());
402     }
403
404     // Skip lambda expressions. We visit these as FunctionDecls when we're
405     // generating them and aren't interested in the body when generating a
406     // parent context.
407     void VisitLambdaExpr(const LambdaExpr *LE) {}
408
409     void VisitCapturedDecl(const CapturedDecl *D) {
410       // Counter tracks entry to the capture body.
411       RegionCounter Cnt(PGO, D->getBody());
412       Cnt.beginRegion();
413       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
414       Visit(D->getBody());
415     }
416
417     void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
418       // Counter tracks entry to the method body.
419       RegionCounter Cnt(PGO, D->getBody());
420       Cnt.beginRegion();
421       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
422       Visit(D->getBody());
423     }
424
425     void VisitBlockDecl(const BlockDecl *D) {
426       // Counter tracks entry to the block body.
427       RegionCounter Cnt(PGO, D->getBody());
428       Cnt.beginRegion();
429       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
430       Visit(D->getBody());
431     }
432
433     void VisitReturnStmt(const ReturnStmt *S) {
434       RecordStmtCount(S);
435       if (S->getRetValue())
436         Visit(S->getRetValue());
437       PGO.setCurrentRegionUnreachable();
438       RecordNextStmtCount = true;
439     }
440
441     void VisitGotoStmt(const GotoStmt *S) {
442       RecordStmtCount(S);
443       PGO.setCurrentRegionUnreachable();
444       RecordNextStmtCount = true;
445     }
446
447     void VisitLabelStmt(const LabelStmt *S) {
448       RecordNextStmtCount = false;
449       // Counter tracks the block following the label.
450       RegionCounter Cnt(PGO, S);
451       Cnt.beginRegion();
452       CountMap[S] = PGO.getCurrentRegionCount();
453       Visit(S->getSubStmt());
454     }
455
456     void VisitBreakStmt(const BreakStmt *S) {
457       RecordStmtCount(S);
458       assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
459       BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
460       PGO.setCurrentRegionUnreachable();
461       RecordNextStmtCount = true;
462     }
463
464     void VisitContinueStmt(const ContinueStmt *S) {
465       RecordStmtCount(S);
466       assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
467       BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
468       PGO.setCurrentRegionUnreachable();
469       RecordNextStmtCount = true;
470     }
471
472     void VisitWhileStmt(const WhileStmt *S) {
473       RecordStmtCount(S);
474       // Counter tracks the body of the loop.
475       RegionCounter Cnt(PGO, S);
476       BreakContinueStack.push_back(BreakContinue());
477       // Visit the body region first so the break/continue adjustments can be
478       // included when visiting the condition.
479       Cnt.beginRegion();
480       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
481       Visit(S->getBody());
482       Cnt.adjustForControlFlow();
483
484       // ...then go back and propagate counts through the condition. The count
485       // at the start of the condition is the sum of the incoming edges,
486       // the backedge from the end of the loop body, and the edges from
487       // continue statements.
488       BreakContinue BC = BreakContinueStack.pop_back_val();
489       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
490                                 Cnt.getAdjustedCount() + BC.ContinueCount);
491       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
492       Visit(S->getCond());
493       Cnt.adjustForControlFlow();
494       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
495       RecordNextStmtCount = true;
496     }
497
498     void VisitDoStmt(const DoStmt *S) {
499       RecordStmtCount(S);
500       // Counter tracks the body of the loop.
501       RegionCounter Cnt(PGO, S);
502       BreakContinueStack.push_back(BreakContinue());
503       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
504       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
505       Visit(S->getBody());
506       Cnt.adjustForControlFlow();
507
508       BreakContinue BC = BreakContinueStack.pop_back_val();
509       // The count at the start of the condition is equal to the count at the
510       // end of the body. The adjusted count does not include either the
511       // fall-through count coming into the loop or the continue count, so add
512       // both of those separately. This is coincidentally the same equation as
513       // with while loops but for different reasons.
514       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
515                                 Cnt.getAdjustedCount() + BC.ContinueCount);
516       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
517       Visit(S->getCond());
518       Cnt.adjustForControlFlow();
519       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
520       RecordNextStmtCount = true;
521     }
522
523     void VisitForStmt(const ForStmt *S) {
524       RecordStmtCount(S);
525       if (S->getInit())
526         Visit(S->getInit());
527       // Counter tracks the body of the loop.
528       RegionCounter Cnt(PGO, S);
529       BreakContinueStack.push_back(BreakContinue());
530       // Visit the body region first. (This is basically the same as a while
531       // loop; see further comments in VisitWhileStmt.)
532       Cnt.beginRegion();
533       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
534       Visit(S->getBody());
535       Cnt.adjustForControlFlow();
536
537       // The increment is essentially part of the body but it needs to include
538       // the count for all the continue statements.
539       if (S->getInc()) {
540         Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
541                                   BreakContinueStack.back().ContinueCount);
542         CountMap[S->getInc()] = PGO.getCurrentRegionCount();
543         Visit(S->getInc());
544         Cnt.adjustForControlFlow();
545       }
546
547       BreakContinue BC = BreakContinueStack.pop_back_val();
548
549       // ...then go back and propagate counts through the condition.
550       if (S->getCond()) {
551         Cnt.setCurrentRegionCount(Cnt.getParentCount() +
552                                   Cnt.getAdjustedCount() +
553                                   BC.ContinueCount);
554         CountMap[S->getCond()] = PGO.getCurrentRegionCount();
555         Visit(S->getCond());
556         Cnt.adjustForControlFlow();
557       }
558       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
559       RecordNextStmtCount = true;
560     }
561
562     void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
563       RecordStmtCount(S);
564       Visit(S->getRangeStmt());
565       Visit(S->getBeginEndStmt());
566       // Counter tracks the body of the loop.
567       RegionCounter Cnt(PGO, S);
568       BreakContinueStack.push_back(BreakContinue());
569       // Visit the body region first. (This is basically the same as a while
570       // loop; see further comments in VisitWhileStmt.)
571       Cnt.beginRegion();
572       CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
573       Visit(S->getLoopVarStmt());
574       Visit(S->getBody());
575       Cnt.adjustForControlFlow();
576
577       // The increment is essentially part of the body but it needs to include
578       // the count for all the continue statements.
579       Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
580                                 BreakContinueStack.back().ContinueCount);
581       CountMap[S->getInc()] = PGO.getCurrentRegionCount();
582       Visit(S->getInc());
583       Cnt.adjustForControlFlow();
584
585       BreakContinue BC = BreakContinueStack.pop_back_val();
586
587       // ...then go back and propagate counts through the condition.
588       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
589                                 Cnt.getAdjustedCount() +
590                                 BC.ContinueCount);
591       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
592       Visit(S->getCond());
593       Cnt.adjustForControlFlow();
594       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
595       RecordNextStmtCount = true;
596     }
597
598     void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
599       RecordStmtCount(S);
600       Visit(S->getElement());
601       // Counter tracks the body of the loop.
602       RegionCounter Cnt(PGO, S);
603       BreakContinueStack.push_back(BreakContinue());
604       Cnt.beginRegion();
605       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
606       Visit(S->getBody());
607       BreakContinue BC = BreakContinueStack.pop_back_val();
608       Cnt.adjustForControlFlow();
609       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
610       RecordNextStmtCount = true;
611     }
612
613     void VisitSwitchStmt(const SwitchStmt *S) {
614       RecordStmtCount(S);
615       Visit(S->getCond());
616       PGO.setCurrentRegionUnreachable();
617       BreakContinueStack.push_back(BreakContinue());
618       Visit(S->getBody());
619       // If the switch is inside a loop, add the continue counts.
620       BreakContinue BC = BreakContinueStack.pop_back_val();
621       if (!BreakContinueStack.empty())
622         BreakContinueStack.back().ContinueCount += BC.ContinueCount;
623       // Counter tracks the exit block of the switch.
624       RegionCounter ExitCnt(PGO, S);
625       ExitCnt.beginRegion();
626       RecordNextStmtCount = true;
627     }
628
629     void VisitCaseStmt(const CaseStmt *S) {
630       RecordNextStmtCount = false;
631       // Counter for this particular case. This counts only jumps from the
632       // switch header and does not include fallthrough from the case before
633       // this one.
634       RegionCounter Cnt(PGO, S);
635       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
636       CountMap[S] = Cnt.getCount();
637       RecordNextStmtCount = true;
638       Visit(S->getSubStmt());
639     }
640
641     void VisitDefaultStmt(const DefaultStmt *S) {
642       RecordNextStmtCount = false;
643       // Counter for this default case. This does not include fallthrough from
644       // the previous case.
645       RegionCounter Cnt(PGO, S);
646       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
647       CountMap[S] = Cnt.getCount();
648       RecordNextStmtCount = true;
649       Visit(S->getSubStmt());
650     }
651
652     void VisitIfStmt(const IfStmt *S) {
653       RecordStmtCount(S);
654       // Counter tracks the "then" part of an if statement. The count for
655       // the "else" part, if it exists, will be calculated from this counter.
656       RegionCounter Cnt(PGO, S);
657       Visit(S->getCond());
658
659       Cnt.beginRegion();
660       CountMap[S->getThen()] = PGO.getCurrentRegionCount();
661       Visit(S->getThen());
662       Cnt.adjustForControlFlow();
663
664       if (S->getElse()) {
665         Cnt.beginElseRegion();
666         CountMap[S->getElse()] = PGO.getCurrentRegionCount();
667         Visit(S->getElse());
668         Cnt.adjustForControlFlow();
669       }
670       Cnt.applyAdjustmentsToRegion(0);
671       RecordNextStmtCount = true;
672     }
673
674     void VisitCXXTryStmt(const CXXTryStmt *S) {
675       RecordStmtCount(S);
676       Visit(S->getTryBlock());
677       for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
678         Visit(S->getHandler(I));
679       // Counter tracks the continuation block of the try statement.
680       RegionCounter Cnt(PGO, S);
681       Cnt.beginRegion();
682       RecordNextStmtCount = true;
683     }
684
685     void VisitCXXCatchStmt(const CXXCatchStmt *S) {
686       RecordNextStmtCount = false;
687       // Counter tracks the catch statement's handler block.
688       RegionCounter Cnt(PGO, S);
689       Cnt.beginRegion();
690       CountMap[S] = PGO.getCurrentRegionCount();
691       Visit(S->getHandlerBlock());
692     }
693
694     void VisitAbstractConditionalOperator(
695         const AbstractConditionalOperator *E) {
696       RecordStmtCount(E);
697       // Counter tracks the "true" part of a conditional operator. The
698       // count in the "false" part will be calculated from this counter.
699       RegionCounter Cnt(PGO, E);
700       Visit(E->getCond());
701
702       Cnt.beginRegion();
703       CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount();
704       Visit(E->getTrueExpr());
705       Cnt.adjustForControlFlow();
706
707       Cnt.beginElseRegion();
708       CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount();
709       Visit(E->getFalseExpr());
710       Cnt.adjustForControlFlow();
711
712       Cnt.applyAdjustmentsToRegion(0);
713       RecordNextStmtCount = true;
714     }
715
716     void VisitBinLAnd(const BinaryOperator *E) {
717       RecordStmtCount(E);
718       // Counter tracks the right hand side of a logical and operator.
719       RegionCounter Cnt(PGO, E);
720       Visit(E->getLHS());
721       Cnt.beginRegion();
722       CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
723       Visit(E->getRHS());
724       Cnt.adjustForControlFlow();
725       Cnt.applyAdjustmentsToRegion(0);
726       RecordNextStmtCount = true;
727     }
728
729     void VisitBinLOr(const BinaryOperator *E) {
730       RecordStmtCount(E);
731       // Counter tracks the right hand side of a logical or operator.
732       RegionCounter Cnt(PGO, E);
733       Visit(E->getLHS());
734       Cnt.beginRegion();
735       CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
736       Visit(E->getRHS());
737       Cnt.adjustForControlFlow();
738       Cnt.applyAdjustmentsToRegion(0);
739       RecordNextStmtCount = true;
740     }
741   };
742 }
743
744 void PGOHash::combine(HashType Type) {
745   // Check that we never combine 0 and only have six bits.
746   assert(Type && "Hash is invalid: unexpected type 0");
747   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
748
749   // Pass through MD5 if enough work has built up.
750   if (Count && Count % NumTypesPerWord == 0) {
751     using namespace llvm::support;
752     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
753     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
754     Working = 0;
755   }
756
757   // Accumulate the current type.
758   ++Count;
759   Working = Working << NumBitsPerType | Type;
760 }
761
762 uint64_t PGOHash::finalize() {
763   // Use Working as the hash directly if we never used MD5.
764   if (Count <= NumTypesPerWord)
765     // No need to byte swap here, since none of the math was endian-dependent.
766     // This number will be byte-swapped as required on endianness transitions,
767     // so we will see the same value on the other side.
768     return Working;
769
770   // Check for remaining work in Working.
771   if (Working)
772     MD5.update(Working);
773
774   // Finalize the MD5 and return the hash.
775   llvm::MD5::MD5Result Result;
776   MD5.final(Result);
777   using namespace llvm::support;
778   return endian::read<uint64_t, little, unaligned>(Result);
779 }
780
781 static void emitRuntimeHook(CodeGenModule &CGM) {
782   const char *const RuntimeVarName = "__llvm_profile_runtime";
783   const char *const RuntimeUserName = "__llvm_profile_runtime_user";
784   if (CGM.getModule().getGlobalVariable(RuntimeVarName))
785     return;
786
787   // Declare the runtime hook.
788   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
789   auto *Int32Ty = llvm::Type::getInt32Ty(Ctx);
790   auto *Var = new llvm::GlobalVariable(CGM.getModule(), Int32Ty, false,
791                                        llvm::GlobalValue::ExternalLinkage,
792                                        nullptr, RuntimeVarName);
793
794   // Make a function that uses it.
795   auto *User = llvm::Function::Create(llvm::FunctionType::get(Int32Ty, false),
796                                       llvm::GlobalValue::LinkOnceODRLinkage,
797                                       RuntimeUserName, &CGM.getModule());
798   User->addFnAttr(llvm::Attribute::NoInline);
799   if (CGM.getCodeGenOpts().DisableRedZone)
800     User->addFnAttr(llvm::Attribute::NoRedZone);
801   CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", User));
802   auto *Load = Builder.CreateLoad(Var);
803   Builder.CreateRet(Load);
804
805   // Create a use of the function.  Now the definition of the runtime variable
806   // should get pulled in, along with any static initializears.
807   CGM.addUsedGlobal(User);
808 }
809
810 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
811   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
812   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
813   if (!InstrumentRegions && !PGOReader)
814     return;
815   if (D->isImplicit())
816     return;
817   setFuncName(Fn);
818
819   // Set the linkage for variables based on the function linkage.  Usually, we
820   // want to match it, but available_externally and extern_weak both have the
821   // wrong semantics.
822   VarLinkage = Fn->getLinkage();
823   switch (VarLinkage) {
824   case llvm::GlobalValue::ExternalWeakLinkage:
825     VarLinkage = llvm::GlobalValue::LinkOnceAnyLinkage;
826     break;
827   case llvm::GlobalValue::AvailableExternallyLinkage:
828     VarLinkage = llvm::GlobalValue::LinkOnceODRLinkage;
829     break;
830   default:
831     break;
832   }
833
834   mapRegionCounters(D);
835   if (InstrumentRegions) {
836     emitRuntimeHook(CGM);
837     emitCounterVariables();
838   }
839   if (PGOReader) {
840     SourceManager &SM = CGM.getContext().getSourceManager();
841     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
842     computeRegionCounts(D);
843     applyFunctionAttributes(PGOReader, Fn);
844   }
845 }
846
847 void CodeGenPGO::mapRegionCounters(const Decl *D) {
848   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
849   MapRegionCounters Walker(*RegionCounterMap);
850   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
851     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
852   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
853     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
854   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
855     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
856   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
857     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
858   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
859   NumRegionCounters = Walker.NextCounter;
860   FunctionHash = Walker.Hash.finalize();
861 }
862
863 void CodeGenPGO::computeRegionCounts(const Decl *D) {
864   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
865   ComputeRegionCounts Walker(*StmtCountMap, *this);
866   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
867     Walker.VisitFunctionDecl(FD);
868   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
869     Walker.VisitObjCMethodDecl(MD);
870   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
871     Walker.VisitBlockDecl(BD);
872   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
873     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
874 }
875
876 void
877 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
878                                     llvm::Function *Fn) {
879   if (!haveRegionCounts())
880     return;
881
882   uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
883   uint64_t FunctionCount = getRegionCount(0);
884   if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
885     // Turn on InlineHint attribute for hot functions.
886     // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
887     Fn->addFnAttr(llvm::Attribute::InlineHint);
888   else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
889     // Turn on Cold attribute for cold functions.
890     // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
891     Fn->addFnAttr(llvm::Attribute::Cold);
892 }
893
894 void CodeGenPGO::emitCounterVariables() {
895   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
896   llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx),
897                                                     NumRegionCounters);
898   RegionCounters =
899     new llvm::GlobalVariable(CGM.getModule(), CounterTy, false, VarLinkage,
900                              llvm::Constant::getNullValue(CounterTy),
901                              getFuncVarName("counters"));
902   RegionCounters->setAlignment(8);
903   RegionCounters->setSection(getCountersSection(CGM));
904 }
905
906 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
907   if (!RegionCounters)
908     return;
909   llvm::Value *Addr =
910     Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter);
911   llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount");
912   Count = Builder.CreateAdd(Count, Builder.getInt64(1));
913   Builder.CreateStore(Count, Addr);
914 }
915
916 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
917                                   bool IsInMainFile) {
918   CGM.getPGOStats().addVisited(IsInMainFile);
919   RegionCounts.reset(new std::vector<uint64_t>);
920   uint64_t Hash;
921   if (PGOReader->getFunctionCounts(getFuncName(), Hash, *RegionCounts)) {
922     CGM.getPGOStats().addMissing(IsInMainFile);
923     RegionCounts.reset();
924   } else if (Hash != FunctionHash ||
925              RegionCounts->size() != NumRegionCounters) {
926     CGM.getPGOStats().addMismatched(IsInMainFile);
927     RegionCounts.reset();
928   }
929 }
930
931 void CodeGenPGO::destroyRegionCounters() {
932   RegionCounterMap.reset();
933   StmtCountMap.reset();
934   RegionCounts.reset();
935   RegionCounters = nullptr;
936 }
937
938 /// \brief Calculate what to divide by to scale weights.
939 ///
940 /// Given the maximum weight, calculate a divisor that will scale all the
941 /// weights to strictly less than UINT32_MAX.
942 static uint64_t calculateWeightScale(uint64_t MaxWeight) {
943   return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
944 }
945
946 /// \brief Scale an individual branch weight (and add 1).
947 ///
948 /// Scale a 64-bit weight down to 32-bits using \c Scale.
949 ///
950 /// According to Laplace's Rule of Succession, it is better to compute the
951 /// weight based on the count plus 1, so universally add 1 to the value.
952 ///
953 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
954 /// greater than \c Weight.
955 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
956   assert(Scale && "scale by 0?");
957   uint64_t Scaled = Weight / Scale + 1;
958   assert(Scaled <= UINT32_MAX && "overflow 32-bits");
959   return Scaled;
960 }
961
962 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
963                                               uint64_t FalseCount) {
964   // Check for empty weights.
965   if (!TrueCount && !FalseCount)
966     return nullptr;
967
968   // Calculate how to scale down to 32-bits.
969   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
970
971   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
972   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
973                                       scaleBranchWeight(FalseCount, Scale));
974 }
975
976 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
977   // We need at least two elements to create meaningful weights.
978   if (Weights.size() < 2)
979     return nullptr;
980
981   // Check for empty weights.
982   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
983   if (MaxWeight == 0)
984     return nullptr;
985
986   // Calculate how to scale down to 32-bits.
987   uint64_t Scale = calculateWeightScale(MaxWeight);
988
989   SmallVector<uint32_t, 16> ScaledWeights;
990   ScaledWeights.reserve(Weights.size());
991   for (uint64_t W : Weights)
992     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
993
994   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
995   return MDHelper.createBranchWeights(ScaledWeights);
996 }
997
998 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
999                                             RegionCounter &Cnt) {
1000   if (!haveRegionCounts())
1001     return nullptr;
1002   uint64_t LoopCount = Cnt.getCount();
1003   uint64_t CondCount = 0;
1004   bool Found = getStmtCount(Cond, CondCount);
1005   assert(Found && "missing expected loop condition count");
1006   (void)Found;
1007   if (CondCount == 0)
1008     return nullptr;
1009   return createBranchWeights(LoopCount,
1010                              std::max(CondCount, LoopCount) - LoopCount);
1011 }