1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // Instrumentation-based profile-guided optimization
12 //===----------------------------------------------------------------------===//
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
// Hidden developer flag (-enable-value-profiling, default false).
// CodeGenPGO::valueProfile() tests this flag before any instrumentation or
// profile-use work for value sites.
25 static llvm::cl::opt<bool>
26 EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
27 llvm::cl::desc("Enable value profiling"),
28 llvm::cl::Hidden, llvm::cl::init(false));
30 using namespace clang;
31 using namespace CodeGen;
/// Compute FuncName, the PGO name for a function called \p Name with linkage
/// \p Linkage. The name also depends on the main file name and on a profile
/// format version. If an indexed profile is loaded, its version is used;
/// otherwise the current llvm::IndexedInstrProf::Version is used.
///
/// When clang instrumentation is enabled, this also creates FuncNameVar, the
/// name variable that the emitted instrprof intrinsics and coverage mapping
/// records refer to.
33 void CodeGenPGO::setFuncName(StringRef Name,
34 llvm::GlobalValue::LinkageTypes Linkage) {
35 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36 FuncName = llvm::getPGOFuncName(
37 Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
40 // If we're generating a profile, create a variable for the name.
41 if (CGM.getCodeGenOpts().hasProfileClangInstr())
42 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
/// Set FuncName (and, when instrumenting, FuncNameVar) from \p Fn's IR name
/// and linkage. Then attach the resulting PGO name to \p Fn as PGOFuncName
/// metadata.
45 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46 setFuncName(Fn->getName(), Fn->getLinkage());
47 // Create PGOFuncName meta data.
48 llvm::createPGOFuncNameMetadata(*Fn, FuncName);
51 /// The version of the PGO hash algorithm.
///
/// mapRegionCounters() hashes with PGO_HASH_LATEST when inserting
/// instrumentation. When a profile is being read, it uses the version
/// derived from that indexed profile via getPGOHashVersion().
52 enum PGOHashVersion : unsigned {
56 // Keep this set to the latest hash version.
57 PGO_HASH_LATEST = PGO_HASH_V2
61 /// Stable hasher for PGO region counters.
63 /// PGOHash produces a stable hash of a given function's control flow.
65 /// Changing the output of this hash will invalidate all previously generated
66 /// profiles -- i.e., don't do it.
68 /// \note When this hash does eventually change (years?), we still need to
69 /// support old hashes. We'll need to pull in the version number from the
70 /// profile data format and use the matching hash function.
74 PGOHashVersion HashVersion;
// Packing parameters. Each HashType value occupies NumBitsPerType bits of the
// 64-bit Working word, so NumTypesPerWord (64 / 6 == 10) types fit before
// combine() feeds the word to MD5. TooBig (1 << 6 == 64) is the exclusive
// upper bound for a HashType value that fits in those bits.
77 static const int NumBitsPerType = 6;
78 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
79 static const unsigned TooBig = 1u << NumBitsPerType;
82 /// Hash values for AST nodes.
84 /// Distinct values for AST nodes that have region counters attached.
86 /// These values must be stable. All new members must be added at the end,
87 /// and no members should be removed. Changing the enumeration value for an
88 /// AST node will affect the hash of every function that contains that node.
89 enum HashType : unsigned char {
96 ObjCForCollectionStmt,
106 BinaryConditionalOperator,
107 // The preceding values are available with PGO_HASH_V1.
125 // The preceding values are available with PGO_HASH_V2.
127 // Keep this last. It's for the static assert that follows.
130 static_assert(LastHashType <= TooBig, "Too many types in HashType");
/// Create a hasher tagged with \p HashVersion. The visible hash arithmetic in
/// combine() and finalize() does not consult the version. MapRegionCounters
/// reads it through getHashVersion() to decide which node types and scope
/// markers to combine.
132 PGOHash(PGOHashVersion HashVersion)
133 : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
/// Append one node type to the running hash.
134 void combine(HashType Type);
136 PGOHashVersion getHashVersion() const { return HashVersion; }
// Out-of-line definitions for PGOHash's in-class-initialized static const
// data members. They are required whenever the members are odr-used (e.g.
// bound to a const reference). The initializers stay inside the class.
138 const int PGOHash::NumBitsPerType;
139 const unsigned PGOHash::NumTypesPerWord;
140 const unsigned PGOHash::TooBig;
142 /// Get the PGO hash version used in the given indexed profile.
///
/// \p PGOReader is dereferenced unconditionally and must be non-null. The
/// visible caller, mapRegionCounters(), only calls this with a non-null
/// reader.
/// NOTE(review): \p CGM is not used by the visible lines. The return
/// statements after the version test are not visible in this listing.
/// Presumably indexed format versions <= 4 map to PGO_HASH_V1 and newer ones
/// to PGO_HASH_V2 — confirm against the full source.
143 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
144 CodeGenModule &CGM) {
145 if (PGOReader->getVersion() <= 4)
150 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
///
/// Besides numbering counters, the walker includes the node types it
/// classifies in Hash. The hash version chosen at construction decides
/// which node types and scope markers are included.
151 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
152 using Base = RecursiveASTVisitor<MapRegionCounters>;
154 /// The next counter value to assign.
155 unsigned NextCounter;
156 /// The function hash.
158 /// The map of statements to counters.
159 llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
161 MapRegionCounters(PGOHashVersion HashVersion,
162 llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
163 : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}
165 // Blocks and lambdas are handled as separate functions, so we need not
166 // traverse them in the parent context.
167 bool TraverseBlockExpr(BlockExpr *BE) { return true; }
168 bool TraverseLambdaExpr(LambdaExpr *LE) {
169 // Traverse the captures, but not the body.
170 for (const auto &C : zip(LE->captures(), LE->capture_inits()))
171 TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
// Captured statements are skipped too. Their CapturedDecl is mapped on its
// own, since mapRegionCounters() accepts a CapturedDecl as a root.
174 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
// Give the body of a function-like declaration an entry counter. Because
// NextCounter starts at 0, the first such body receives counter 0.
176 bool VisitDecl(const Decl *D) {
177 switch (D->getKind()) {
181 case Decl::CXXMethod:
182 case Decl::CXXConstructor:
183 case Decl::CXXDestructor:
184 case Decl::CXXConversion:
185 case Decl::ObjCMethod:
188 CounterMap[D->getBody()] = NextCounter++;
194 /// If \p S gets a fresh counter, update the counter mappings. Return the
/// NOTE(review): the end of this sentence and the return statement are not
/// visible in this listing. Presumably the V1 hash type of \p S is returned
/// — confirm.
196 PGOHash::HashType updateCounterMappings(Stmt *S) {
// Counters are always assigned from the V1 classification, whatever hash
// version is in use. Counter numbering therefore does not depend on the
// hash version.
197 auto Type = getHashType(PGO_HASH_V1, S);
198 if (Type != PGOHash::None)
199 CounterMap[S] = NextCounter++;
203 /// Include \p S in the function hash.
204 bool VisitStmt(Stmt *S) {
205 auto Type = updateCounterMappings(S);
// Newer hash versions may classify S differently (e.g. comparisons and
// jumps). That classification affects only the hash, not counter numbering.
206 if (Hash.getHashVersion() != PGO_HASH_V1)
207 Type = getHashType(Hash.getHashVersion(), S);
208 if (Type != PGOHash::None)
/// For hash versions after V1, combine IfThenBranch / IfElseBranch when
/// reaching the then / else children and EndOfScope at the end. This lets
/// the hash distinguish code in the then-branch from code in the else-branch.
213 bool TraverseIfStmt(IfStmt *If) {
214 // If we used the V1 hash, use the default traversal.
215 if (Hash.getHashVersion() == PGO_HASH_V1)
216 return Base::TraverseIfStmt(If);
218 // Otherwise, keep track of which branch we're in while traversing.
220 for (Stmt *CS : If->children()) {
223 if (CS == If->getThen())
224 Hash.combine(PGOHash::IfThenBranch);
225 else if (CS == If->getElse())
226 Hash.combine(PGOHash::IfElseBranch);
229 Hash.combine(PGOHash::EndOfScope);
233 // If the statement type \p N is nestable, and its nesting impacts profile
234 // stability, define a custom traversal which tracks the end of the statement
235 // in the hash (provided we're not using the V1 hash).
236 #define DEFINE_NESTABLE_TRAVERSAL(N) \
237 bool Traverse##N(N *S) { \
238 Base::Traverse##N(S); \
239 if (Hash.getHashVersion() != PGO_HASH_V1) \
240 Hash.combine(PGOHash::EndOfScope); \
244 DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
245 DEFINE_NESTABLE_TRAVERSAL(DoStmt)
246 DEFINE_NESTABLE_TRAVERSAL(ForStmt)
247 DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
248 DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
249 DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
250 DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
252 /// Get version \p HashVersion of the PGO hash for \p S.
///
/// Labels, loops, switch/case/default, if, try/catch, the conditional
/// operators and logical && / || are classified for every version. The
/// following are classified only when \p HashVersion is PGO_HASH_V2:
/// comparison operators (<, >, <=, >=, ==, !=), goto, indirect goto, break,
/// continue, return, throw and logical-not. Anything else yields
/// PGOHash::None.
253 PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
254 switch (S->getStmtClass()) {
257 case Stmt::LabelStmtClass:
258 return PGOHash::LabelStmt;
259 case Stmt::WhileStmtClass:
260 return PGOHash::WhileStmt;
261 case Stmt::DoStmtClass:
262 return PGOHash::DoStmt;
263 case Stmt::ForStmtClass:
264 return PGOHash::ForStmt;
265 case Stmt::CXXForRangeStmtClass:
266 return PGOHash::CXXForRangeStmt;
267 case Stmt::ObjCForCollectionStmtClass:
268 return PGOHash::ObjCForCollectionStmt;
269 case Stmt::SwitchStmtClass:
270 return PGOHash::SwitchStmt;
271 case Stmt::CaseStmtClass:
272 return PGOHash::CaseStmt;
273 case Stmt::DefaultStmtClass:
274 return PGOHash::DefaultStmt;
275 case Stmt::IfStmtClass:
276 return PGOHash::IfStmt;
277 case Stmt::CXXTryStmtClass:
278 return PGOHash::CXXTryStmt;
279 case Stmt::CXXCatchStmtClass:
280 return PGOHash::CXXCatchStmt;
281 case Stmt::ConditionalOperatorClass:
282 return PGOHash::ConditionalOperator;
283 case Stmt::BinaryConditionalOperatorClass:
284 return PGOHash::BinaryConditionalOperator;
285 case Stmt::BinaryOperatorClass: {
286 const BinaryOperator *BO = cast<BinaryOperator>(S);
287 if (BO->getOpcode() == BO_LAnd)
288 return PGOHash::BinaryOperatorLAnd;
289 if (BO->getOpcode() == BO_LOr)
290 return PGOHash::BinaryOperatorLOr;
291 if (HashVersion == PGO_HASH_V2) {
292 switch (BO->getOpcode()) {
296 return PGOHash::BinaryOperatorLT;
298 return PGOHash::BinaryOperatorGT;
300 return PGOHash::BinaryOperatorLE;
302 return PGOHash::BinaryOperatorGE;
304 return PGOHash::BinaryOperatorEQ;
306 return PGOHash::BinaryOperatorNE;
// Jump-like statements and logical-not contribute only with the V2 hash.
313 if (HashVersion == PGO_HASH_V2) {
314 switch (S->getStmtClass()) {
317 case Stmt::GotoStmtClass:
318 return PGOHash::GotoStmt;
319 case Stmt::IndirectGotoStmtClass:
320 return PGOHash::IndirectGotoStmt;
321 case Stmt::BreakStmtClass:
322 return PGOHash::BreakStmt;
323 case Stmt::ContinueStmtClass:
324 return PGOHash::ContinueStmt;
325 case Stmt::ReturnStmtClass:
326 return PGOHash::ReturnStmt;
327 case Stmt::CXXThrowExprClass:
328 return PGOHash::ThrowExpr;
329 case Stmt::UnaryOperatorClass: {
330 const UnaryOperator *UO = cast<UnaryOperator>(S);
331 if (UO->getOpcode() == UO_LNot)
332 return PGOHash::UnaryOperatorLNot;
338 return PGOHash::None;
342 /// A StmtVisitor that propagates the raw counts through the AST and
343 /// records the count at statements where the value may change.
///
/// Region counters only count entries into specific regions: function
/// bodies, loop bodies, then-branches, case labels, catch handlers, and so
/// on. Counts elsewhere are derived from them during the walk; for example,
/// an if's else-count is ParentCount - ThenCount.
344 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
348 /// A flag that is set when the current count should be recorded on the
349 /// next statement, such as at the exit of a loop.
350 bool RecordNextStmtCount;
352 /// The count at the current location in the traversal.
353 uint64_t CurrentCount;
355 /// The map of statements to count values.
356 llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
358 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
359 struct BreakContinue {
361 uint64_t ContinueCount;
362 BreakContinue() : BreakCount(0), ContinueCount(0) {}
364 SmallVector<BreakContinue, 8> BreakContinueStack;
// NOTE(review): CurrentCount is not initialized here. Each Visit*Decl entry
// point (function, captured, ObjC method, block) assigns it through
// setCount() before any statement is visited.
366 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
368 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
/// If a previous statement requested it, record the current count for \p S
/// and clear the request. Requests come from return, throw, goto, break and
/// continue, and from the end of a loop, switch, if, try, ?:, && or ||.
370 void RecordStmtCount(const Stmt *S) {
371 if (RecordNextStmtCount) {
372 CountMap[S] = CurrentCount;
373 RecordNextStmtCount = false;
377 /// Set and return the current count.
378 uint64_t setCount(uint64_t Count) {
379 CurrentCount = Count;
/// Fallback for statements without a dedicated handler. It iterates over the
/// children of \p S in order.
383 void VisitStmt(const Stmt *S) {
385 for (const Stmt *Child : S->children())
390 void VisitFunctionDecl(const FunctionDecl *D) {
391 // Counter tracks entry to the function body.
392 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
393 CountMap[D->getBody()] = BodyCount;
397 // Skip lambda expressions. We visit these as FunctionDecls when we're
398 // generating them and aren't interested in the body when generating a
400 void VisitLambdaExpr(const LambdaExpr *LE) {}
402 void VisitCapturedDecl(const CapturedDecl *D) {
403 // Counter tracks entry to the capture body.
404 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
405 CountMap[D->getBody()] = BodyCount;
409 void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
410 // Counter tracks entry to the method body.
411 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
412 CountMap[D->getBody()] = BodyCount;
416 void VisitBlockDecl(const BlockDecl *D) {
417 // Counter tracks entry to the block body.
418 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
419 CountMap[D->getBody()] = BodyCount;
// Jumps (return, throw, goto, break, continue) request that the count of the
// next statement be recorded, since control does not fall through to it.
423 void VisitReturnStmt(const ReturnStmt *S) {
425 if (S->getRetValue())
426 Visit(S->getRetValue());
428 RecordNextStmtCount = true;
431 void VisitCXXThrowExpr(const CXXThrowExpr *E) {
434 Visit(E->getSubExpr());
436 RecordNextStmtCount = true;
439 void VisitGotoStmt(const GotoStmt *S) {
442 RecordNextStmtCount = true;
445 void VisitLabelStmt(const LabelStmt *S) {
446 RecordNextStmtCount = false;
447 // Counter tracks the block following the label.
448 uint64_t BlockCount = setCount(PGO.getRegionCount(S));
449 CountMap[S] = BlockCount;
450 Visit(S->getSubStmt());
453 void VisitBreakStmt(const BreakStmt *S) {
455 assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
456 BreakContinueStack.back().BreakCount += CurrentCount;
458 RecordNextStmtCount = true;
461 void VisitContinueStmt(const ContinueStmt *S) {
463 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
464 BreakContinueStack.back().ContinueCount += CurrentCount;
466 RecordNextStmtCount = true;
/// While loop: the region counter for \p S counts entries into the body.
/// The condition count is ParentCount + BackedgeCount + continues. The exit
/// count is BreakCount + CondCount - BodyCount.
/// NOTE(review): these are uint64_t subtractions and wrap if the loaded
/// counts are inconsistent (BodyCount > CondCount). Nothing clamps here,
/// whereas createProfileWeightsForLoop clamps a similar difference with
/// std::max.
469 void VisitWhileStmt(const WhileStmt *S) {
471 uint64_t ParentCount = CurrentCount;
473 BreakContinueStack.push_back(BreakContinue());
474 // Visit the body region first so the break/continue adjustments can be
475 // included when visiting the condition.
476 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
477 CountMap[S->getBody()] = CurrentCount;
479 uint64_t BackedgeCount = CurrentCount;
481 // ...then go back and propagate counts through the condition. The count
482 // at the start of the condition is the sum of the incoming edges,
483 // the backedge from the end of the loop body, and the edges from
484 // continue statements.
485 BreakContinue BC = BreakContinueStack.pop_back_val();
487 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
488 CountMap[S->getCond()] = CondCount;
490 setCount(BC.BreakCount + CondCount - BodyCount);
491 RecordNextStmtCount = true;
/// Do loop: the body count is LoopCount (the region counter) plus the
/// fall-in from the parent. The condition count is the end-of-body count
/// plus continues. The exit count is BreakCount + CondCount - LoopCount.
494 void VisitDoStmt(const DoStmt *S) {
496 uint64_t LoopCount = PGO.getRegionCount(S);
498 BreakContinueStack.push_back(BreakContinue());
499 // The count doesn't include the fallthrough from the parent scope. Add it.
500 uint64_t BodyCount = setCount(LoopCount + CurrentCount);
501 CountMap[S->getBody()] = BodyCount;
503 uint64_t BackedgeCount = CurrentCount;
505 BreakContinue BC = BreakContinueStack.pop_back_val();
506 // The count at the start of the condition is equal to the count at the
507 // end of the body, plus any continues.
508 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
509 CountMap[S->getCond()] = CondCount;
511 setCount(BC.BreakCount + CondCount - LoopCount);
512 RecordNextStmtCount = true;
515 void VisitForStmt(const ForStmt *S) {
520 uint64_t ParentCount = CurrentCount;
522 BreakContinueStack.push_back(BreakContinue());
523 // Visit the body region first. (This is basically the same as a while
524 // loop; see further comments in VisitWhileStmt.)
525 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
526 CountMap[S->getBody()] = BodyCount;
528 uint64_t BackedgeCount = CurrentCount;
529 BreakContinue BC = BreakContinueStack.pop_back_val();
531 // The increment is essentially part of the body but it needs to include
532 // the count for all the continue statements.
534 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
535 CountMap[S->getInc()] = IncCount;
539 // ...then go back and propagate counts through the condition.
541 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
543 CountMap[S->getCond()] = CondCount;
546 setCount(BC.BreakCount + CondCount - BodyCount);
547 RecordNextStmtCount = true;
550 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
554 Visit(S->getLoopVarStmt());
555 Visit(S->getRangeStmt());
556 Visit(S->getBeginStmt());
557 Visit(S->getEndStmt());
559 uint64_t ParentCount = CurrentCount;
560 BreakContinueStack.push_back(BreakContinue());
561 // Visit the body region first. (This is basically the same as a while
562 // loop; see further comments in VisitWhileStmt.)
563 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
564 CountMap[S->getBody()] = BodyCount;
566 uint64_t BackedgeCount = CurrentCount;
567 BreakContinue BC = BreakContinueStack.pop_back_val();
569 // The increment is essentially part of the body but it needs to include
570 // the count for all the continue statements.
571 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
572 CountMap[S->getInc()] = IncCount;
575 // ...then go back and propagate counts through the condition.
577 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
578 CountMap[S->getCond()] = CondCount;
580 setCount(BC.BreakCount + CondCount - BodyCount);
581 RecordNextStmtCount = true;
584 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
586 Visit(S->getElement());
587 uint64_t ParentCount = CurrentCount;
588 BreakContinueStack.push_back(BreakContinue());
589 // Counter tracks the body of the loop.
590 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
591 CountMap[S->getBody()] = BodyCount;
593 uint64_t BackedgeCount = CurrentCount;
594 BreakContinue BC = BreakContinueStack.pop_back_val();
596 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
598 RecordNextStmtCount = true;
601 void VisitSwitchStmt(const SwitchStmt *S) {
607 BreakContinueStack.push_back(BreakContinue());
609 // If the switch is inside a loop, add the continue counts.
610 BreakContinue BC = BreakContinueStack.pop_back_val();
611 if (!BreakContinueStack.empty())
612 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
613 // Counter tracks the exit block of the switch.
614 setCount(PGO.getRegionCount(S));
615 RecordNextStmtCount = true;
618 void VisitSwitchCase(const SwitchCase *S) {
619 RecordNextStmtCount = false;
620 // Counter for this particular case. This counts only jumps from the
621 // switch header and does not include fallthrough from the case before
623 uint64_t CaseCount = PGO.getRegionCount(S);
624 setCount(CurrentCount + CaseCount);
625 // We need the count without fallthrough in the mapping, so it's more useful
626 // for branch probabilities.
627 CountMap[S] = CaseCount;
628 RecordNextStmtCount = true;
629 Visit(S->getSubStmt());
/// If statement: the region counter for \p S counts entries into the then
/// branch. The else path gets ParentCount - ThenCount. The count after the
/// if is the sum of the counts leaving both paths.
/// NOTE(review): ParentCount - ThenCount is unsigned and wraps if the loaded
/// counts are inconsistent (ThenCount > ParentCount).
632 void VisitIfStmt(const IfStmt *S) {
634 uint64_t ParentCount = CurrentCount;
639 // Counter tracks the "then" part of an if statement. The count for
640 // the "else" part, if it exists, will be calculated from this counter.
641 uint64_t ThenCount = setCount(PGO.getRegionCount(S));
642 CountMap[S->getThen()] = ThenCount;
644 uint64_t OutCount = CurrentCount;
646 uint64_t ElseCount = ParentCount - ThenCount;
649 CountMap[S->getElse()] = ElseCount;
651 OutCount += CurrentCount;
653 OutCount += ElseCount;
655 RecordNextStmtCount = true;
658 void VisitCXXTryStmt(const CXXTryStmt *S) {
660 Visit(S->getTryBlock());
661 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
662 Visit(S->getHandler(I));
663 // Counter tracks the continuation block of the try statement.
664 setCount(PGO.getRegionCount(S));
665 RecordNextStmtCount = true;
668 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
669 RecordNextStmtCount = false;
670 // Counter tracks the catch statement's handler block.
671 uint64_t CatchCount = setCount(PGO.getRegionCount(S));
672 CountMap[S] = CatchCount;
673 Visit(S->getHandlerBlock());
/// ?: and binary ?: operators: the region counter counts evaluations of the
/// true expression, and the false expression gets ParentCount - TrueCount.
/// The result count is the sum of the counts leaving both expressions.
676 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
678 uint64_t ParentCount = CurrentCount;
681 // Counter tracks the "true" part of a conditional operator. The
682 // count in the "false" part will be calculated from this counter.
683 uint64_t TrueCount = setCount(PGO.getRegionCount(E));
684 CountMap[E->getTrueExpr()] = TrueCount;
685 Visit(E->getTrueExpr());
686 uint64_t OutCount = CurrentCount;
688 uint64_t FalseCount = setCount(ParentCount - TrueCount);
689 CountMap[E->getFalseExpr()] = FalseCount;
690 Visit(E->getFalseExpr());
691 OutCount += CurrentCount;
694 RecordNextStmtCount = true;
/// Logical &&: the region counter counts evaluations of the RHS. Afterwards
/// the count returns to ParentCount, corrected by RHSCount - CurrentCount
/// for any change in count inside the RHS.
697 void VisitBinLAnd(const BinaryOperator *E) {
699 uint64_t ParentCount = CurrentCount;
701 // Counter tracks the right hand side of a logical and operator.
702 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
703 CountMap[E->getRHS()] = RHSCount;
705 setCount(ParentCount + RHSCount - CurrentCount);
706 RecordNextStmtCount = true;
/// Logical ||: uses the same count propagation as VisitBinLAnd.
709 void VisitBinLOr(const BinaryOperator *E) {
711 uint64_t ParentCount = CurrentCount;
713 // Counter tracks the right hand side of a logical or operator.
714 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
715 CountMap[E->getRHS()] = RHSCount;
717 setCount(ParentCount + RHSCount - CurrentCount);
718 RecordNextStmtCount = true;
721 } // end anonymous namespace
/// Append one node type to the hash.
///
/// Types are packed into Working, NumBitsPerType bits at a time. When Count
/// reaches a non-zero multiple of NumTypesPerWord, the accumulated word is
/// first fed to MD5 in little-endian byte order, before the new type is
/// packed. This keeps the digest independent of host endianness.
723 void PGOHash::combine(HashType Type) {
724 // Check that we never combine 0 and only have six bits.
725 assert(Type && "Hash is invalid: unexpected type 0");
726 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
728 // Pass through MD5 if enough work has built up.
729 if (Count && Count % NumTypesPerWord == 0) {
730 using namespace llvm::support;
731 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
732 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
736 // Accumulate the current type.
738 Working = Working << NumBitsPerType | Type;
/// Return the final 64-bit function hash. If at most NumTypesPerWord types
/// were combined, the packed Working word is the hash. Otherwise the
/// remaining work is flushed into MD5 and the hash is derived from the MD5
/// result.
741 uint64_t PGOHash::finalize() {
742 // Use Working as the hash directly if we never used MD5.
743 if (Count <= NumTypesPerWord)
744 // No need to byte swap here, since none of the math was endian-dependent.
745 // This number will be byte-swapped as required on endianness transitions,
746 // so we will see the same value on the other side.
749 // Check for remaining work in Working.
753 // Finalize the MD5 and return the hash.
754 llvm::MD5::MD5Result Result;
756 using namespace llvm::support;
/// Set up PGO for the function emitted for \p GD:
///   - map its region counters and compute its hash;
///   - emit its coverage mapping if coverage mapping is enabled;
///   - if an indexed profile is loaded, load and propagate its counts and
///     set \p Fn's entry count.
/// The initial test bails out when neither clang instrumentation nor a
/// profile reader is active.
///
/// NOTE(review): several early-return and guard lines are not visible in
/// this listing. loadRegionCounts() dereferences PGOReader, so the last four
/// calls are presumably guarded by a PGOReader check — confirm.
760 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
761 const Decl *D = GD.getDecl();
765 bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
766 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
767 if (!InstrumentRegions && !PGOReader)
771 // Constructors and destructors may be represented by several functions in IR.
772 // If so, instrument only base variant, others are implemented by delegation
773 // to the base one, it would be counted twice otherwise.
774 if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
775 if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
778 if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
779 if (GD.getCtorType() != Ctor_Base &&
780 CodeGenFunction::IsConstructorDelegationValid(CCD))
783 CGM.ClearUnusedCoverageMapping(D);
786 mapRegionCounters(D);
787 if (CGM.getCodeGenOpts().CoverageMapping)
788 emitCounterRegionMapping(D);
790 SourceManager &SM = CGM.getContext().getSourceManager();
791 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
792 computeRegionCounts(D);
793 applyFunctionAttributes(PGOReader, Fn);
/// Number the counted regions of \p D and compute FunctionHash. This fills
/// RegionCounterMap and sets NumRegionCounters. \p D may be a FunctionDecl,
/// ObjCMethodDecl, BlockDecl or CapturedDecl. At least one counter (the
/// entry counter) is asserted to have been assigned.
797 void CodeGenPGO::mapRegionCounters(const Decl *D) {
798 // Use the latest hash version when inserting instrumentation, but use the
799 // version in the indexed profile if we're reading PGO data.
800 PGOHashVersion HashVersion = PGO_HASH_LATEST;
801 if (auto *PGOReader = CGM.getPGOReader())
802 HashVersion = getPGOHashVersion(PGOReader, CGM);
804 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
805 MapRegionCounters Walker(HashVersion, *RegionCounterMap);
// RecursiveASTVisitor::TraverseDecl takes a non-const Decl, hence the
// const_casts below.
806 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
807 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
808 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
809 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
810 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
811 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
812 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
813 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
814 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
815 NumRegionCounters = Walker.NextCounter;
816 FunctionHash = Walker.Hash.finalize();
/// Return true if no coverage mapping should be emitted for \p D. Functions
/// whose body begins in a system header are skipped.
/// NOTE(review): D->getBody() is dereferenced below. The preceding lines are
/// not visible in this listing and presumably handle other cases, such as
/// declarations without a body — confirm.
819 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
823 // Don't map the functions in system headers.
824 const auto &SM = CGM.getContext().getSourceManager();
825 auto Loc = D->getBody()->getBeginLoc();
826 return SM.isInSystemHeader(Loc);
/// Generate the coverage mapping for \p D's region counters. Register it
/// with the module's coverage mapping under FuncNameVar / FuncName /
/// FunctionHash. Skipped decls and empty mappings are tested for first; the
/// early returns are not visible in this listing.
/// NOTE(review): raw_string_ostream buffers its output, so OS must be
/// flushed before CoverageMapping is inspected. The flush is not visible in
/// this listing — confirm it is present.
829 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
830 if (skipRegionMappingForDecl(D))
833 std::string CoverageMapping;
834 llvm::raw_string_ostream OS(CoverageMapping);
835 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
836 CGM.getContext().getSourceManager(),
837 CGM.getLangOpts(), RegionCounterMap.get());
838 MappingGen.emitCounterMapping(D, OS);
841 if (CoverageMapping.empty())
844 CGM.getCoverageMapping()->addFunctionMappingRecord(
845 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
/// Emit a coverage mapping record with an empty mapping for \p D. The
/// function name is computed from \p Name and \p Linkage via setFuncName().
/// Unlike emitCounterRegionMapping(), the record is added with a trailing
/// `false` argument (presumably marking the function as unused — confirm).
/// NOTE(review): the same raw_string_ostream flush concern applies as in
/// emitCounterRegionMapping(); no flush is visible in this listing.
849 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
850 llvm::GlobalValue::LinkageTypes Linkage) {
851 if (skipRegionMappingForDecl(D))
854 std::string CoverageMapping;
855 llvm::raw_string_ostream OS(CoverageMapping);
856 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
857 CGM.getContext().getSourceManager(),
859 MappingGen.emitEmptyMapping(D, OS);
862 if (CoverageMapping.empty())
865 setFuncName(Name, Linkage);
866 CGM.getCoverageMapping()->addFunctionMappingRecord(
867 FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
/// Propagate the loaded region counts through \p D's body, filling a fresh
/// StmtCountMap. \p D may be a FunctionDecl, ObjCMethodDecl, BlockDecl or
/// CapturedDecl; any other decl leaves the new map empty.
/// NOTE(review): the const_cast on the CapturedDecl path is unnecessary,
/// because ComputeRegionCounts::VisitCapturedDecl takes a
/// const CapturedDecl *.
870 void CodeGenPGO::computeRegionCounts(const Decl *D) {
871 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
872 ComputeRegionCounts Walker(*StmtCountMap, *this);
873 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
874 Walker.VisitFunctionDecl(FD);
875 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
876 Walker.VisitObjCMethodDecl(MD);
877 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
878 Walker.VisitBlockDecl(BD);
879 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
880 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
/// Set \p Fn's entry count to getRegionCount(nullptr) from the loaded
/// profile. Guarded by haveRegionCounts(); the early return itself is not
/// visible in this listing.
/// NOTE(review): \p PGOReader is not used by the visible lines.
884 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
885 llvm::Function *Fn) {
886 if (!haveRegionCounts())
889 uint64_t FunctionCount = getRegionCount(nullptr);
890 Fn->setEntryCount(FunctionCount);
/// Emit an llvm.instrprof.increment call for \p S's region counter, or an
/// llvm.instrprof.increment.step call carrying \p StepV. Presumably the step
/// form is chosen when \p StepV is non-null; the selecting lines are not
/// visible in this listing. Nothing is emitted unless clang instrumentation
/// is enabled, counters have been mapped, and the builder has an insertion
/// block.
/// NOTE(review): (*RegionCounterMap)[S] default-inserts 0 for an unmapped
/// \p S. That silently bumps counter 0, which MapRegionCounters assigns to
/// the function body. Using find() plus an assert would catch such misuse.
893 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
894 llvm::Value *StepV) {
895 if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
897 if (!Builder.GetInsertBlock())
900 unsigned Counter = (*RegionCounterMap)[S];
901 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
903 llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
904 Builder.getInt64(FunctionHash),
905 Builder.getInt32(NumRegionCounters),
906 Builder.getInt32(Counter), StepV};
// The plain increment intrinsic takes only the first four operands; StepV
// is dropped.
908 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
909 makeArrayRef(Args, 4));
912 CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
916 // This method either inserts a call to the profile run-time during
917 // instrumentation or puts profile data into metadata for PGO use.
//
// Both the instrumentation path and the profile-use path advance
// NumValueSites[ValueKind], so value sites of each kind are numbered in
// emission order. The profile-use path annotates ValueSite with the profile
// data for its index. The checks at the top cover these cases:
// -enable-value-profiling unset, a null or constant ValuePtr, a null
// ValueSite, and a builder without an insertion block.
918 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
919 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
921 if (!EnableValueProfiling)
924 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
927 if (isa<llvm::Constant>(ValuePtr))
930 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
931 if (InstrumentValueSites && RegionCounterMap) {
// Insert the intrinsic call at ValueSite, then restore the caller's
// insertion point.
932 auto BuilderInsertPoint = Builder.saveIP();
933 Builder.SetInsertPoint(ValueSite);
934 llvm::Value *Args[5] = {
935 llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
936 Builder.getInt64(FunctionHash),
937 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
938 Builder.getInt32(ValueKind),
939 Builder.getInt32(NumValueSites[ValueKind]++)
942 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
943 Builder.restoreIP(BuilderInsertPoint);
947 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
948 if (PGOReader && haveRegionCounts()) {
949 // We record the top most called three functions at each call site.
950 // Profile metadata contains "VP" string identifying this metadata
951 // as value profiling data, then a uint32_t value for the value profiling
952 // kind, a uint64_t value for the total number of times the call is
953 // executed, followed by the function hash and execution count (uint64_t)
954 // pairs for each function.
// Guard against site indices beyond those recorded in the profile record.
// The branch body is not visible in this listing.
955 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
958 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
959 (llvm::InstrProfValueKind)ValueKind,
960 NumValueSites[ValueKind]);
962 NumValueSites[ValueKind]++;
/// Look up the profile record for FuncName / FunctionHash in \p PGOReader.
/// On success, keep the record in ProfRecord (valueProfile() reads its value
/// sites) and copy its counters into RegionCounts.
///
/// Statistics: every call is counted as visited in the module's PGO
/// statistics. Unknown functions are counted as missing. Hash mismatches and
/// malformed records are counted as mismatched. \p IsInMainFile is passed to
/// every statistics update.
///
/// RegionCounts is cleared first, so a failed lookup presumably leaves it
/// empty (the early return on error is not visible in this listing).
966 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
968 CGM.getPGOStats().addVisited(IsInMainFile);
969 RegionCounts.clear();
970 llvm::Expected<llvm::InstrProfRecord> RecordExpected =
971 PGOReader->getInstrProfRecord(FuncName, FunctionHash);
972 if (auto E = RecordExpected.takeError()) {
973 auto IPE = llvm::InstrProfError::take(std::move(E));
// Error kinds other than the three below are not recorded in the statistics.
974 if (IPE == llvm::instrprof_error::unknown_function)
975 CGM.getPGOStats().addMissing(IsInMainFile);
976 else if (IPE == llvm::instrprof_error::hash_mismatch)
977 CGM.getPGOStats().addMismatched(IsInMainFile);
978 else if (IPE == llvm::instrprof_error::malformed)
979 // TODO: Consider a more specific warning for this case.
980 CGM.getPGOStats().addMismatched(IsInMainFile);
984 llvm::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
985 RegionCounts = ProfRecord->Counts;
988 /// Calculate what to divide by to scale weights.
990 /// Given the maximum weight, calculate a divisor that will scale all the
991 /// weights to strictly less than UINT32_MAX.
///
/// The comparison is strict because scaleBranchWeight() adds 1 after
/// dividing. With MaxWeight == UINT32_MAX a scale of 1 would overflow 32
/// bits, so the scale becomes 2 instead.
992 static uint64_t calculateWeightScale(uint64_t MaxWeight) {
993 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
996 /// Scale an individual branch weight (and add 1).
998 /// Scale a 64-bit weight down to 32-bits using \c Scale.
1000 /// According to Laplace's Rule of Succession, it is better to compute the
1001 /// weight based on the count plus 1, so universally add 1 to the value.
1003 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
1004 /// less than \c Weight.
///
/// A Scale computed from a smaller maximum could leave Weight / Scale + 1
/// above UINT32_MAX, which the overflow assert below catches.
1005 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
1006 assert(Scale && "scale by 0?");
1007 uint64_t Scaled = Weight / Scale + 1;
1008 assert(Scaled <= UINT32_MAX && "overflow 32-bits");
/// Build two-way branch-weight metadata from the 64-bit \p TrueCount and
/// \p FalseCount. Both counts are scaled by the same factor into 32-bit
/// weights, with 1 added to each. When both counts are zero no weights are
/// built (the early return is not visible in this listing; presumably it
/// returns nullptr).
1012 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1013 uint64_t FalseCount) {
1014 // Check for empty weights.
1015 if (!TrueCount && !FalseCount)
1018 // Calculate how to scale down to 32-bits.
1019 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1021 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1022 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1023 scaleBranchWeight(FalseCount, Scale));
/// Build n-way branch-weight metadata from 64-bit \p Weights. All weights
/// are scaled by one common factor (derived from the largest weight) into
/// 32-bit weights, with 1 added to each. Fewer than two weights, or all-zero
/// weights, produce no metadata (the early returns are not visible in this
/// listing; presumably they return nullptr).
1027 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
1028 // We need at least two elements to create meaningful weights.
1029 if (Weights.size() < 2)
1032 // Check for empty weights.
// Dereferencing max_element is safe here: the size check above leaves at
// least two weights.
1033 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1037 // Calculate how to scale down to 32-bits.
1038 uint64_t Scale = calculateWeightScale(MaxWeight);
1040 SmallVector<uint32_t, 16> ScaledWeights;
1041 ScaledWeights.reserve(Weights.size());
1042 for (uint64_t W : Weights)
1043 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1045 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1046 return MDHelper.createBranchWeights(ScaledWeights);
/// Branch weights for a loop condition \p Cond. The first (true) weight is
/// \p LoopCount. The second (exit) weight is the recorded count of \p Cond
/// minus LoopCount, clamped at zero with std::max so an inconsistent
/// profile cannot underflow.
/// When no region counts are loaded, or the condition count is zero, the
/// function returns early (presumably nullptr; those return lines are not
/// visible in this listing). A missing condition count is asserted against.
1049 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1050 uint64_t LoopCount) {
1051 if (!PGO.haveRegionCounts())
1053 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1054 assert(CondCount.hasValue() && "missing expected loop condition count");
1055 if (*CondCount == 0)
1057 return createProfileWeights(LoopCount,
1058 std::max(*CondCount, LoopCount) - LoopCount);