1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // Instrumentation-based profile-guided optimization
12 //===----------------------------------------------------------------------===//
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/MDBuilder.h"
19 #include "llvm/ProfileData/InstrProfReader.h"
20 #include "llvm/Support/Endian.h"
21 #include "llvm/Support/FileSystem.h"
22 #include "llvm/Support/MD5.h"
24 using namespace clang;
25 using namespace CodeGen;
27 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
28 RawFuncName = Fn->getName();
30 // Function names may be prefixed with a binary '1' to indicate
31 // that the backend should not modify the symbols due to any platform
32 // naming convention. Do not include that '1' in the PGO profile name.
33 if (RawFuncName[0] == '\1')
34 RawFuncName = RawFuncName.substr(1);
36 if (!Fn->hasLocalLinkage()) {
37 PrefixedFuncName.reset(new std::string(RawFuncName));
41 // For local symbols, prepend the main file name to distinguish them.
42 // Do not include the full path in the file name since there's no guarantee
43 // that it will stay the same, e.g., if the files are checked out from
44 // version control in different locations.
45 PrefixedFuncName.reset(new std::string(CGM.getCodeGenOpts().MainFileName));
46 if (PrefixedFuncName->empty())
47 PrefixedFuncName->assign("<unknown>");
48 PrefixedFuncName->append(":");
49 PrefixedFuncName->append(RawFuncName);
52 static llvm::Function *getRegisterFunc(CodeGenModule &CGM) {
53 return CGM.getModule().getFunction("__llvm_profile_register_functions");
56 static llvm::BasicBlock *getOrInsertRegisterBB(CodeGenModule &CGM) {
57 // Don't do this for Darwin. compiler-rt uses linker magic.
58 if (CGM.getTarget().getTriple().isOSDarwin())
61 // Only need to insert this once per module.
62 if (llvm::Function *RegisterF = getRegisterFunc(CGM))
63 return &RegisterF->getEntryBlock();
65 // Construct the function.
66 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
67 auto *RegisterFTy = llvm::FunctionType::get(VoidTy, false);
68 auto *RegisterF = llvm::Function::Create(RegisterFTy,
69 llvm::GlobalValue::InternalLinkage,
70 "__llvm_profile_register_functions",
72 RegisterF->setUnnamedAddr(true);
73 if (CGM.getCodeGenOpts().DisableRedZone)
74 RegisterF->addFnAttr(llvm::Attribute::NoRedZone);
76 // Construct and return the entry block.
77 auto *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", RegisterF);
78 CGBuilderTy Builder(BB);
79 Builder.CreateRetVoid();
83 static llvm::Constant *getOrInsertRuntimeRegister(CodeGenModule &CGM) {
84 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
85 auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
86 auto *RuntimeRegisterTy = llvm::FunctionType::get(VoidTy, VoidPtrTy, false);
87 return CGM.getModule().getOrInsertFunction("__llvm_profile_register_function",
91 static bool isMachO(const CodeGenModule &CGM) {
92 return CGM.getTarget().getTriple().isOSBinFormatMachO();
95 static StringRef getCountersSection(const CodeGenModule &CGM) {
96 return isMachO(CGM) ? "__DATA,__llvm_prf_cnts" : "__llvm_prf_cnts";
99 static StringRef getNameSection(const CodeGenModule &CGM) {
100 return isMachO(CGM) ? "__DATA,__llvm_prf_names" : "__llvm_prf_names";
103 static StringRef getDataSection(const CodeGenModule &CGM) {
104 return isMachO(CGM) ? "__DATA,__llvm_prf_data" : "__llvm_prf_data";
107 llvm::GlobalVariable *CodeGenPGO::buildDataVar() {
108 // Create name variable.
109 llvm::LLVMContext &Ctx = CGM.getLLVMContext();
110 auto *VarName = llvm::ConstantDataArray::getString(Ctx, getFuncName(),
112 auto *Name = new llvm::GlobalVariable(CGM.getModule(), VarName->getType(),
113 true, VarLinkage, VarName,
114 getFuncVarName("name"));
115 Name->setSection(getNameSection(CGM));
116 Name->setAlignment(1);
118 // Create data variable.
119 auto *Int32Ty = llvm::Type::getInt32Ty(Ctx);
120 auto *Int64Ty = llvm::Type::getInt64Ty(Ctx);
121 auto *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx);
122 auto *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx);
123 llvm::Type *DataTypes[] = {
124 Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int64PtrTy
126 auto *DataTy = llvm::StructType::get(Ctx, makeArrayRef(DataTypes));
127 llvm::Constant *DataVals[] = {
128 llvm::ConstantInt::get(Int32Ty, getFuncName().size()),
129 llvm::ConstantInt::get(Int32Ty, NumRegionCounters),
130 llvm::ConstantInt::get(Int64Ty, FunctionHash),
131 llvm::ConstantExpr::getBitCast(Name, Int8PtrTy),
132 llvm::ConstantExpr::getBitCast(RegionCounters, Int64PtrTy)
135 new llvm::GlobalVariable(CGM.getModule(), DataTy, true, VarLinkage,
136 llvm::ConstantStruct::get(DataTy, DataVals),
137 getFuncVarName("data"));
139 // All the data should be packed into an array in its own section.
140 Data->setSection(getDataSection(CGM));
141 Data->setAlignment(8);
143 // Hide all these symbols so that we correctly get a copy for each
144 // executable. The profile format expects names and counters to be
145 // contiguous, so references into shared objects would be invalid.
146 if (!llvm::GlobalValue::isLocalLinkage(VarLinkage)) {
147 Name->setVisibility(llvm::GlobalValue::HiddenVisibility);
148 Data->setVisibility(llvm::GlobalValue::HiddenVisibility);
149 RegionCounters->setVisibility(llvm::GlobalValue::HiddenVisibility);
152 // Make sure the data doesn't get deleted.
153 CGM.addUsedGlobal(Data);
157 void CodeGenPGO::emitInstrumentationData() {
162 auto *Data = buildDataVar();
164 // Register the data.
165 auto *RegisterBB = getOrInsertRegisterBB(CGM);
168 CGBuilderTy Builder(RegisterBB->getTerminator());
169 auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
170 Builder.CreateCall(getOrInsertRuntimeRegister(CGM),
171 Builder.CreateBitCast(Data, VoidPtrTy));
174 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) {
175 if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
178 assert(CGM.getModule().getFunction("__llvm_profile_init") == nullptr &&
179 "profile initialization already emitted");
181 // Get the function to call at initialization.
182 llvm::Constant *RegisterF = getRegisterFunc(CGM);
186 // Create the initialization function.
187 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
188 auto *F = llvm::Function::Create(llvm::FunctionType::get(VoidTy, false),
189 llvm::GlobalValue::InternalLinkage,
190 "__llvm_profile_init", &CGM.getModule());
191 F->setUnnamedAddr(true);
192 F->addFnAttr(llvm::Attribute::NoInline);
193 if (CGM.getCodeGenOpts().DisableRedZone)
194 F->addFnAttr(llvm::Attribute::NoRedZone);
196 // Add the basic block and the necessary calls.
197 CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F));
198 Builder.CreateCall(RegisterF);
199 Builder.CreateRetVoid();
// NOTE(review): the class header, the data members and most HashType
// enumerators (including LastHashType, used by the static_assert) are not
// visible in this listing -- confirm against the full file before editing.
205 /// \brief Stable hasher for PGO region counters.
207 /// PGOHash produces a stable hash of a given function's control flow.
209 /// Changing the output of this hash will invalidate all previously generated
210 /// profiles -- i.e., don't do it.
212 /// \note When this hash does eventually change (years?), we still need to
213 /// support old hashes. We'll need to pull in the version number from the
214 /// profile data format and use the matching hash function.
// Packing parameters: each hashed type occupies NumBitsPerType bits,
// NumTypesPerWord of them fit into one 64-bit word, and TooBig is the first
// value that no longer fits into NumBitsPerType bits.
220 static const int NumBitsPerType = 6;
221 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
222 static const unsigned TooBig = 1u << NumBitsPerType;
225 /// \brief Hash values for AST nodes.
227 /// Distinct values for AST nodes that have region counters attached.
229 /// These values must be stable. All new members must be added at the end,
230 /// and no members should be removed. Changing the enumeration value for an
231 /// AST node will affect the hash of every function that contains that node.
232 enum HashType : unsigned char {
239 ObjCForCollectionStmt,
249 BinaryConditionalOperator,
251 // Keep this last. It's for the static assert that follows.
// Compile-time check that the enumerators fit the bit budget above.
254 static_assert(LastHashType <= TooBig, "Too many types in HashType");
256 // TODO: When this format changes, take in a version number here, and use the
257 // old hash calculation for file formats that used the old hash.
// A fresh hasher starts with an empty working word and a zero type count.
258 PGOHash() : Working(0), Count(0) {}
// Mix one AST node type into the hash (defined out of line).
259 void combine(HashType Type);
// Out-of-class definitions for the PGOHash static constants; their values
// come from the in-class initializers.
262 const int PGOHash::NumBitsPerType;
263 const unsigned PGOHash::NumTypesPerWord;
264 const unsigned PGOHash::TooBig;
// NOTE(review): braces are unbalanced and several short statements (returns,
// switch labels, the hash member) are not visible in this listing -- confirm
// against the full file before editing.
266 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
267 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
268 /// The next counter value to assign.
269 unsigned NextCounter;
270 /// The function hash.
272 /// The map of statements to counters.
273 llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
// Counter numbering starts at 0; the map is owned by the caller.
275 MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
276 : NextCounter(0), CounterMap(CounterMap) {}
278 // Blocks and lambdas are handled as separate functions, so we need not
279 // traverse them in the parent context.
280 bool TraverseBlockExpr(BlockExpr *BE) { return true; }
281 bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
282 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
// For the declaration kinds listed below, the body of the declaration is
// given the next counter.
284 bool VisitDecl(const Decl *D) {
285 switch (D->getKind()) {
289 case Decl::CXXMethod:
290 case Decl::CXXConstructor:
291 case Decl::CXXDestructor:
292 case Decl::CXXConversion:
293 case Decl::ObjCMethod:
296 CounterMap[D->getBody()] = NextCounter++;
// Statements whose hash type is not PGOHash::None are given the next counter.
302 bool VisitStmt(const Stmt *S) {
303 auto Type = getHashType(S);
304 if (Type == PGOHash::None)
307 CounterMap[S] = NextCounter++;
// Map a statement class to its PGOHash::HashType. Logical and/or are the
// only binary operators with their own hash types; PGOHash::None is returned
// for statements that have no counter.
311 PGOHash::HashType getHashType(const Stmt *S) {
312 switch (S->getStmtClass()) {
315 case Stmt::LabelStmtClass:
316 return PGOHash::LabelStmt;
317 case Stmt::WhileStmtClass:
318 return PGOHash::WhileStmt;
319 case Stmt::DoStmtClass:
320 return PGOHash::DoStmt;
321 case Stmt::ForStmtClass:
322 return PGOHash::ForStmt;
323 case Stmt::CXXForRangeStmtClass:
324 return PGOHash::CXXForRangeStmt;
325 case Stmt::ObjCForCollectionStmtClass:
326 return PGOHash::ObjCForCollectionStmt;
327 case Stmt::SwitchStmtClass:
328 return PGOHash::SwitchStmt;
329 case Stmt::CaseStmtClass:
330 return PGOHash::CaseStmt;
331 case Stmt::DefaultStmtClass:
332 return PGOHash::DefaultStmt;
333 case Stmt::IfStmtClass:
334 return PGOHash::IfStmt;
335 case Stmt::CXXTryStmtClass:
336 return PGOHash::CXXTryStmt;
337 case Stmt::CXXCatchStmtClass:
338 return PGOHash::CXXCatchStmt;
339 case Stmt::ConditionalOperatorClass:
340 return PGOHash::ConditionalOperator;
341 case Stmt::BinaryConditionalOperatorClass:
342 return PGOHash::BinaryConditionalOperator;
343 case Stmt::BinaryOperatorClass: {
344 const BinaryOperator *BO = cast<BinaryOperator>(S);
345 if (BO->getOpcode() == BO_LAnd)
346 return PGOHash::BinaryOperatorLAnd;
347 if (BO->getOpcode() == BO_LOr)
348 return PGOHash::BinaryOperatorLOr;
352 return PGOHash::None;
// NOTE(review): braces are unbalanced and a number of short statements are
// not visible in this listing (for example the PGO member declaration and
// constructor parameter that the initializer list refers to) -- read the
// methods below as partial and confirm against the full file before editing.
356 /// A StmtVisitor that propagates the raw counts through the AST and
357 /// records the count at statements where the value may change.
358 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
362 /// A flag that is set when the current count should be recorded on the
363 /// next statement, such as at the exit of a loop.
364 bool RecordNextStmtCount;
366 /// The map of statements to count values.
367 llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
369 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
370 struct BreakContinue {
372 uint64_t ContinueCount;
373 BreakContinue() : BreakCount(0), ContinueCount(0) {}
375 SmallVector<BreakContinue, 8> BreakContinueStack;
377 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
379 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
// If a record was requested by the previous statement, store the current
// region count for S and clear the request.
381 void RecordStmtCount(const Stmt *S) {
382 if (RecordNextStmtCount) {
383 CountMap[S] = PGO.getCurrentRegionCount();
384 RecordNextStmtCount = false;
388 void VisitStmt(const Stmt *S) {
390 for (Stmt::const_child_range I = S->children(); I; ++I) {
396 void VisitFunctionDecl(const FunctionDecl *D) {
397 // Counter tracks entry to the function body.
398 RegionCounter Cnt(PGO, D->getBody());
400 CountMap[D->getBody()] = PGO.getCurrentRegionCount();
404 // Skip lambda expressions. We visit these as FunctionDecls when we're
405 // generating them and aren't interested in the body when generating a
407 void VisitLambdaExpr(const LambdaExpr *LE) {}
409 void VisitCapturedDecl(const CapturedDecl *D) {
410 // Counter tracks entry to the capture body.
411 RegionCounter Cnt(PGO, D->getBody());
413 CountMap[D->getBody()] = PGO.getCurrentRegionCount();
417 void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
418 // Counter tracks entry to the method body.
419 RegionCounter Cnt(PGO, D->getBody());
421 CountMap[D->getBody()] = PGO.getCurrentRegionCount();
425 void VisitBlockDecl(const BlockDecl *D) {
426 // Counter tracks entry to the block body.
427 RegionCounter Cnt(PGO, D->getBody());
429 CountMap[D->getBody()] = PGO.getCurrentRegionCount();
// After a return, the current region is marked unreachable and a count
// record is requested for the next statement visited.
433 void VisitReturnStmt(const ReturnStmt *S) {
435 if (S->getRetValue())
436 Visit(S->getRetValue());
437 PGO.setCurrentRegionUnreachable();
438 RecordNextStmtCount = true;
441 void VisitGotoStmt(const GotoStmt *S) {
443 PGO.setCurrentRegionUnreachable();
444 RecordNextStmtCount = true;
447 void VisitLabelStmt(const LabelStmt *S) {
448 RecordNextStmtCount = false;
449 // Counter tracks the block following the label.
450 RegionCounter Cnt(PGO, S);
452 CountMap[S] = PGO.getCurrentRegionCount();
453 Visit(S->getSubStmt());
// The current region count is added to the innermost BreakContinue entry
// before the region is marked unreachable.
456 void VisitBreakStmt(const BreakStmt *S) {
458 assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
459 BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
460 PGO.setCurrentRegionUnreachable();
461 RecordNextStmtCount = true;
464 void VisitContinueStmt(const ContinueStmt *S) {
466 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
467 BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
468 PGO.setCurrentRegionUnreachable();
469 RecordNextStmtCount = true;
472 void VisitWhileStmt(const WhileStmt *S) {
474 // Counter tracks the body of the loop.
475 RegionCounter Cnt(PGO, S);
476 BreakContinueStack.push_back(BreakContinue());
477 // Visit the body region first so the break/continue adjustments can be
478 // included when visiting the condition.
480 CountMap[S->getBody()] = PGO.getCurrentRegionCount();
482 Cnt.adjustForControlFlow();
484 // ...then go back and propagate counts through the condition. The count
485 // at the start of the condition is the sum of the incoming edges,
486 // the backedge from the end of the loop body, and the edges from
487 // continue statements.
488 BreakContinue BC = BreakContinueStack.pop_back_val();
489 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
490 Cnt.getAdjustedCount() + BC.ContinueCount);
491 CountMap[S->getCond()] = PGO.getCurrentRegionCount();
493 Cnt.adjustForControlFlow();
494 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
495 RecordNextStmtCount = true;
498 void VisitDoStmt(const DoStmt *S) {
500 // Counter tracks the body of the loop.
501 RegionCounter Cnt(PGO, S);
502 BreakContinueStack.push_back(BreakContinue());
503 Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
504 CountMap[S->getBody()] = PGO.getCurrentRegionCount();
506 Cnt.adjustForControlFlow();
508 BreakContinue BC = BreakContinueStack.pop_back_val();
509 // The count at the start of the condition is equal to the count at the
510 // end of the body. The adjusted count does not include either the
511 // fall-through count coming into the loop or the continue count, so add
512 // both of those separately. This is coincidentally the same equation as
513 // with while loops but for different reasons.
514 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
515 Cnt.getAdjustedCount() + BC.ContinueCount);
516 CountMap[S->getCond()] = PGO.getCurrentRegionCount();
518 Cnt.adjustForControlFlow();
519 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
520 RecordNextStmtCount = true;
523 void VisitForStmt(const ForStmt *S) {
527 // Counter tracks the body of the loop.
528 RegionCounter Cnt(PGO, S);
529 BreakContinueStack.push_back(BreakContinue());
530 // Visit the body region first. (This is basically the same as a while
531 // loop; see further comments in VisitWhileStmt.)
533 CountMap[S->getBody()] = PGO.getCurrentRegionCount();
535 Cnt.adjustForControlFlow();
537 // The increment is essentially part of the body but it needs to include
538 // the count for all the continue statements.
540 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
541 BreakContinueStack.back().ContinueCount);
542 CountMap[S->getInc()] = PGO.getCurrentRegionCount();
544 Cnt.adjustForControlFlow();
547 BreakContinue BC = BreakContinueStack.pop_back_val();
549 // ...then go back and propagate counts through the condition.
551 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
552 Cnt.getAdjustedCount() +
554 CountMap[S->getCond()] = PGO.getCurrentRegionCount();
556 Cnt.adjustForControlFlow();
558 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
559 RecordNextStmtCount = true;
562 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
564 Visit(S->getRangeStmt());
565 Visit(S->getBeginEndStmt());
566 // Counter tracks the body of the loop.
567 RegionCounter Cnt(PGO, S);
568 BreakContinueStack.push_back(BreakContinue());
569 // Visit the body region first. (This is basically the same as a while
570 // loop; see further comments in VisitWhileStmt.)
572 CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
573 Visit(S->getLoopVarStmt());
575 Cnt.adjustForControlFlow();
577 // The increment is essentially part of the body but it needs to include
578 // the count for all the continue statements.
579 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
580 BreakContinueStack.back().ContinueCount);
581 CountMap[S->getInc()] = PGO.getCurrentRegionCount();
583 Cnt.adjustForControlFlow();
585 BreakContinue BC = BreakContinueStack.pop_back_val();
587 // ...then go back and propagate counts through the condition.
588 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
589 Cnt.getAdjustedCount() +
591 CountMap[S->getCond()] = PGO.getCurrentRegionCount();
593 Cnt.adjustForControlFlow();
594 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
595 RecordNextStmtCount = true;
598 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
600 Visit(S->getElement());
601 // Counter tracks the body of the loop.
602 RegionCounter Cnt(PGO, S);
603 BreakContinueStack.push_back(BreakContinue());
605 CountMap[S->getBody()] = PGO.getCurrentRegionCount();
607 BreakContinue BC = BreakContinueStack.pop_back_val();
608 Cnt.adjustForControlFlow();
609 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
610 RecordNextStmtCount = true;
613 void VisitSwitchStmt(const SwitchStmt *S) {
616 PGO.setCurrentRegionUnreachable();
617 BreakContinueStack.push_back(BreakContinue());
619 // If the switch is inside a loop, add the continue counts.
620 BreakContinue BC = BreakContinueStack.pop_back_val();
621 if (!BreakContinueStack.empty())
622 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
623 // Counter tracks the exit block of the switch.
624 RegionCounter ExitCnt(PGO, S);
625 ExitCnt.beginRegion();
626 RecordNextStmtCount = true;
629 void VisitCaseStmt(const CaseStmt *S) {
630 RecordNextStmtCount = false;
631 // Counter for this particular case. This counts only jumps from the
632 // switch header and does not include fallthrough from the case before
634 RegionCounter Cnt(PGO, S);
635 Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
636 CountMap[S] = Cnt.getCount();
637 RecordNextStmtCount = true;
638 Visit(S->getSubStmt());
641 void VisitDefaultStmt(const DefaultStmt *S) {
642 RecordNextStmtCount = false;
643 // Counter for this default case. This does not include fallthrough from
644 // the previous case.
645 RegionCounter Cnt(PGO, S);
646 Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
647 CountMap[S] = Cnt.getCount();
648 RecordNextStmtCount = true;
649 Visit(S->getSubStmt());
652 void VisitIfStmt(const IfStmt *S) {
654 // Counter tracks the "then" part of an if statement. The count for
655 // the "else" part, if it exists, will be calculated from this counter.
656 RegionCounter Cnt(PGO, S);
660 CountMap[S->getThen()] = PGO.getCurrentRegionCount();
662 Cnt.adjustForControlFlow();
665 Cnt.beginElseRegion();
666 CountMap[S->getElse()] = PGO.getCurrentRegionCount();
668 Cnt.adjustForControlFlow();
670 Cnt.applyAdjustmentsToRegion(0);
671 RecordNextStmtCount = true;
674 void VisitCXXTryStmt(const CXXTryStmt *S) {
676 Visit(S->getTryBlock());
677 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
678 Visit(S->getHandler(I));
679 // Counter tracks the continuation block of the try statement.
680 RegionCounter Cnt(PGO, S);
682 RecordNextStmtCount = true;
685 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
686 RecordNextStmtCount = false;
687 // Counter tracks the catch statement's handler block.
688 RegionCounter Cnt(PGO, S);
690 CountMap[S] = PGO.getCurrentRegionCount();
691 Visit(S->getHandlerBlock());
694 void VisitAbstractConditionalOperator(
695 const AbstractConditionalOperator *E) {
697 // Counter tracks the "true" part of a conditional operator. The
698 // count in the "false" part will be calculated from this counter.
699 RegionCounter Cnt(PGO, E);
703 CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount();
704 Visit(E->getTrueExpr());
705 Cnt.adjustForControlFlow();
707 Cnt.beginElseRegion();
708 CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount();
709 Visit(E->getFalseExpr());
710 Cnt.adjustForControlFlow();
712 Cnt.applyAdjustmentsToRegion(0);
713 RecordNextStmtCount = true;
716 void VisitBinLAnd(const BinaryOperator *E) {
718 // Counter tracks the right hand side of a logical and operator.
719 RegionCounter Cnt(PGO, E);
722 CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
724 Cnt.adjustForControlFlow();
725 Cnt.applyAdjustmentsToRegion(0);
726 RecordNextStmtCount = true;
729 void VisitBinLOr(const BinaryOperator *E) {
731 // Counter tracks the right hand side of a logical or operator.
732 RegionCounter Cnt(PGO, E);
735 CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
737 Cnt.adjustForControlFlow();
738 Cnt.applyAdjustmentsToRegion(0);
739 RecordNextStmtCount = true;
// Mix one AST node type into the hash. Types are packed NumBitsPerType bits
// at a time into Working.
// NOTE(review): the end of the flush branch and the Count update are not
// visible in this listing -- confirm against the full file; the hash output
// must not change.
744 void PGOHash::combine(HashType Type) {
745 // Check that we never combine 0 and only have six bits.
746 assert(Type && "Hash is invalid: unexpected type 0");
747 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
749 // Pass through MD5 if enough work has built up.
// The working word is converted to little-endian byte order before it is fed
// to MD5.
750 if (Count && Count % NumTypesPerWord == 0) {
751 using namespace llvm::support;
752 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
753 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
757 // Accumulate the current type.
759 Working = Working << NumBitsPerType | Type;
// Produce the final 64-bit hash value.
// NOTE(review): the return of the short-hash path and the handling of the
// remaining Working bits are not visible in this listing -- confirm against
// the full file; the hash output must not change.
762 uint64_t PGOHash::finalize() {
763 // Use Working as the hash directly if we never used MD5.
764 if (Count <= NumTypesPerWord)
765 // No need to byte swap here, since none of the math was endian-dependent.
766 // This number will be byte-swapped as required on endianness transitions,
767 // so we will see the same value on the other side.
770 // Check for remaining work in Working.
774 // Finalize the MD5 and return the hash.
775 llvm::MD5::MD5Result Result;
// The hash is the leading 64 bits of the MD5 result, read as little-endian.
777 using namespace llvm::support;
778 return endian::read<uint64_t, little, unaligned>(Result);
781 static void emitRuntimeHook(CodeGenModule &CGM) {
782 const char *const RuntimeVarName = "__llvm_profile_runtime";
783 const char *const RuntimeUserName = "__llvm_profile_runtime_user";
784 if (CGM.getModule().getGlobalVariable(RuntimeVarName))
787 // Declare the runtime hook.
788 llvm::LLVMContext &Ctx = CGM.getLLVMContext();
789 auto *Int32Ty = llvm::Type::getInt32Ty(Ctx);
790 auto *Var = new llvm::GlobalVariable(CGM.getModule(), Int32Ty, false,
791 llvm::GlobalValue::ExternalLinkage,
792 nullptr, RuntimeVarName);
794 // Make a function that uses it.
795 auto *User = llvm::Function::Create(llvm::FunctionType::get(Int32Ty, false),
796 llvm::GlobalValue::LinkOnceODRLinkage,
797 RuntimeUserName, &CGM.getModule());
798 User->addFnAttr(llvm::Attribute::NoInline);
799 if (CGM.getCodeGenOpts().DisableRedZone)
800 User->addFnAttr(llvm::Attribute::NoRedZone);
801 CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", User));
802 auto *Load = Builder.CreateLoad(Var);
803 Builder.CreateRet(Load);
805 // Create a use of the function. Now the definition of the runtime variable
806 // should get pulled in, along with any static initializears.
807 CGM.addUsedGlobal(User);
// Entry point for per-function PGO setup: picks the linkage for the profile
// variables, maps region counters for D, and then emits instrumentation
// and/or loads and applies counts from the profile reader.
// NOTE(review): several short statements (early returns, breaks, closing
// braces, possibly further guards) are not visible in this listing -- confirm
// against the full file before editing.
810 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
811 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
812 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
// Neither instrumenting nor reading a profile.
813 if (!InstrumentRegions && !PGOReader)
819 // Set the linkage for variables based on the function linkage. Usually, we
820 // want to match it, but available_externally and extern_weak both have the
// (comment truncated in this listing) Those two linkages are replaced below:
// extern_weak by linkonce, available_externally by linkonce_odr.
822 VarLinkage = Fn->getLinkage();
823 switch (VarLinkage) {
824 case llvm::GlobalValue::ExternalWeakLinkage:
825 VarLinkage = llvm::GlobalValue::LinkOnceAnyLinkage;
827 case llvm::GlobalValue::AvailableExternallyLinkage:
828 VarLinkage = llvm::GlobalValue::LinkOnceODRLinkage;
834 mapRegionCounters(D);
// Instrumentation side: runtime hook plus the counter array.
835 if (InstrumentRegions) {
836 emitRuntimeHook(CGM);
837 emitCounterVariables();
// Profile-use side: load counts, propagate them through the AST and set
// function attributes from them.
840 SourceManager &SM = CGM.getContext().getSourceManager();
841 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
842 computeRegionCounts(D);
843 applyFunctionAttributes(PGOReader, Fn);
847 void CodeGenPGO::mapRegionCounters(const Decl *D) {
848 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
849 MapRegionCounters Walker(*RegionCounterMap);
850 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
851 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
852 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
853 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
854 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
855 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
856 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
857 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
858 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
859 NumRegionCounters = Walker.NextCounter;
860 FunctionHash = Walker.Hash.finalize();
863 void CodeGenPGO::computeRegionCounts(const Decl *D) {
864 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
865 ComputeRegionCounts Walker(*StmtCountMap, *this);
866 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
867 Walker.VisitFunctionDecl(FD);
868 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
869 Walker.VisitObjCMethodDecl(MD);
870 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
871 Walker.VisitBlockDecl(BD);
872 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
873 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
877 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
878 llvm::Function *Fn) {
879 if (!haveRegionCounts())
882 uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
883 uint64_t FunctionCount = getRegionCount(0);
884 if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
885 // Turn on InlineHint attribute for hot functions.
886 // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
887 Fn->addFnAttr(llvm::Attribute::InlineHint);
888 else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
889 // Turn on Cold attribute for cold functions.
890 // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
891 Fn->addFnAttr(llvm::Attribute::Cold);
894 void CodeGenPGO::emitCounterVariables() {
895 llvm::LLVMContext &Ctx = CGM.getLLVMContext();
896 llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx),
899 new llvm::GlobalVariable(CGM.getModule(), CounterTy, false, VarLinkage,
900 llvm::Constant::getNullValue(CounterTy),
901 getFuncVarName("counters"));
902 RegionCounters->setAlignment(8);
903 RegionCounters->setSection(getCountersSection(CGM));
906 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
910 Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter);
911 llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount");
912 Count = Builder.CreateAdd(Count, Builder.getInt64(1));
913 Builder.CreateStore(Count, Addr);
916 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
918 CGM.getPGOStats().addVisited(IsInMainFile);
919 RegionCounts.reset(new std::vector<uint64_t>);
921 if (PGOReader->getFunctionCounts(getFuncName(), Hash, *RegionCounts)) {
922 CGM.getPGOStats().addMissing(IsInMainFile);
923 RegionCounts.reset();
924 } else if (Hash != FunctionHash ||
925 RegionCounts->size() != NumRegionCounters) {
926 CGM.getPGOStats().addMismatched(IsInMainFile);
927 RegionCounts.reset();
931 void CodeGenPGO::destroyRegionCounters() {
932 RegionCounterMap.reset();
933 StmtCountMap.reset();
934 RegionCounts.reset();
935 RegionCounters = nullptr;
/// \brief Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
///
/// \returns 1 when \p MaxWeight is already below UINT32_MAX, otherwise
/// MaxWeight / UINT32_MAX + 1.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
}
/// \brief Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight.
///
/// \returns Weight / Scale + 1, narrowed to 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Scaled);
}
962 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
963 uint64_t FalseCount) {
964 // Check for empty weights.
965 if (!TrueCount && !FalseCount)
968 // Calculate how to scale down to 32-bits.
969 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
971 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
972 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
973 scaleBranchWeight(FalseCount, Scale));
976 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
977 // We need at least two elements to create meaningful weights.
978 if (Weights.size() < 2)
981 // Check for empty weights.
982 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
986 // Calculate how to scale down to 32-bits.
987 uint64_t Scale = calculateWeightScale(MaxWeight);
989 SmallVector<uint32_t, 16> ScaledWeights;
990 ScaledWeights.reserve(Weights.size());
991 for (uint64_t W : Weights)
992 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
994 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
995 return MDHelper.createBranchWeights(ScaledWeights);
// Create branch weights for a loop: the first weight is the loop counter's
// count, the second is the condition count minus the loop count, clamped so
// that it cannot go below zero.
// NOTE(review): the statements after the guards and the end of the function
// are not visible in this listing -- confirm against the full file.
998 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
999 RegionCounter &Cnt) {
1000 if (!haveRegionCounts())
1002 uint64_t LoopCount = Cnt.getCount();
1003 uint64_t CondCount = 0;
1004 bool Found = getStmtCount(Cond, CondCount);
1005 assert(Found && "missing expected loop condition count");
// std::max keeps the unsigned subtraction from wrapping when the loop count
// exceeds the condition count.
1009 return createBranchWeights(LoopCount,
1010 std::max(CondCount, LoopCount) - LoopCount);