1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // Instrumentation-based profile-guided optimization
12 //===----------------------------------------------------------------------===//
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
25 static llvm::cl::opt<bool>
26 EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
27 llvm::cl::desc("Enable value profiling"),
28 llvm::cl::Hidden, llvm::cl::init(false));
30 using namespace clang;
31 using namespace CodeGen;
33 void CodeGenPGO::setFuncName(StringRef Name,
34 llvm::GlobalValue::LinkageTypes Linkage) {
35 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36 FuncName = llvm::getPGOFuncName(
37 Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
40 // If we're generating a profile, create a variable for the name.
41 if (CGM.getCodeGenOpts().hasProfileClangInstr())
42 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
45 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46 setFuncName(Fn->getName(), Fn->getLinkage());
47 // Create PGOFuncName meta data.
48 llvm::createPGOFuncNameMetadata(*Fn, FuncName);
51 /// The version of the PGO hash algorithm.
52 enum PGOHashVersion : unsigned {
56 // Keep this set to the latest hash version.
57 PGO_HASH_LATEST = PGO_HASH_V2
61 /// \brief Stable hasher for PGO region counters.
63 /// PGOHash produces a stable hash of a given function's control flow.
65 /// Changing the output of this hash will invalidate all previously generated
66 /// profiles -- i.e., don't do it.
68 /// \note When this hash does eventually change (years?), we still need to
69 /// support old hashes. We'll need to pull in the version number from the
70 /// profile data format and use the matching hash function.
74 PGOHashVersion HashVersion;
77 static const int NumBitsPerType = 6;
78 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
79 static const unsigned TooBig = 1u << NumBitsPerType;
82 /// \brief Hash values for AST nodes.
84 /// Distinct values for AST nodes that have region counters attached.
86 /// These values must be stable. All new members must be added at the end,
87 /// and no members should be removed. Changing the enumeration value for an
88 /// AST node will affect the hash of every function that contains that node.
89 enum HashType : unsigned char {
96 ObjCForCollectionStmt,
106 BinaryConditionalOperator,
107 // The preceding values are available with PGO_HASH_V1.
125 // The preceding values are available with PGO_HASH_V2.
127 // Keep this last. It's for the static assert that follows.
130 static_assert(LastHashType <= TooBig, "Too many types in HashType");
132 PGOHash(PGOHashVersion HashVersion)
133 : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
134 void combine(HashType Type);
136 PGOHashVersion getHashVersion() const { return HashVersion; }
138 const int PGOHash::NumBitsPerType;
139 const unsigned PGOHash::NumTypesPerWord;
140 const unsigned PGOHash::TooBig;
142 /// Get the PGO hash version used in the given indexed profile.
143 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
144 CodeGenModule &CGM) {
145 if (PGOReader->getVersion() <= 4)
150 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
151 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
152 using Base = RecursiveASTVisitor<MapRegionCounters>;
154 /// The next counter value to assign.
155 unsigned NextCounter;
156 /// The function hash.
158 /// The map of statements to counters.
159 llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
161 MapRegionCounters(PGOHashVersion HashVersion,
162 llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
163 : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}
165 // Blocks and lambdas are handled as separate functions, so we need not
166 // traverse them in the parent context.
167 bool TraverseBlockExpr(BlockExpr *BE) { return true; }
168 bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
169 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
171 bool VisitDecl(const Decl *D) {
172 switch (D->getKind()) {
176 case Decl::CXXMethod:
177 case Decl::CXXConstructor:
178 case Decl::CXXDestructor:
179 case Decl::CXXConversion:
180 case Decl::ObjCMethod:
183 CounterMap[D->getBody()] = NextCounter++;
189 /// If \p S gets a fresh counter, update the counter mappings. Return the
191 PGOHash::HashType updateCounterMappings(Stmt *S) {
192 auto Type = getHashType(PGO_HASH_V1, S);
193 if (Type != PGOHash::None)
194 CounterMap[S] = NextCounter++;
198 /// Include \p S in the function hash.
199 bool VisitStmt(Stmt *S) {
200 auto Type = updateCounterMappings(S);
201 if (Hash.getHashVersion() != PGO_HASH_V1)
202 Type = getHashType(Hash.getHashVersion(), S);
203 if (Type != PGOHash::None)
208 bool TraverseIfStmt(IfStmt *If) {
209 // If we used the V1 hash, use the default traversal.
210 if (Hash.getHashVersion() == PGO_HASH_V1)
211 return Base::TraverseIfStmt(If);
213 // Otherwise, keep track of which branch we're in while traversing.
215 for (Stmt *CS : If->children()) {
218 if (CS == If->getThen())
219 Hash.combine(PGOHash::IfThenBranch);
220 else if (CS == If->getElse())
221 Hash.combine(PGOHash::IfElseBranch);
224 Hash.combine(PGOHash::EndOfScope);
228 // If the statement type \p N is nestable, and its nesting impacts profile
229 // stability, define a custom traversal which tracks the end of the statement
230 // in the hash (provided we're not using the V1 hash).
231 #define DEFINE_NESTABLE_TRAVERSAL(N) \
232 bool Traverse##N(N *S) { \
233 Base::Traverse##N(S); \
234 if (Hash.getHashVersion() != PGO_HASH_V1) \
235 Hash.combine(PGOHash::EndOfScope); \
239 DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
240 DEFINE_NESTABLE_TRAVERSAL(DoStmt)
241 DEFINE_NESTABLE_TRAVERSAL(ForStmt)
242 DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
243 DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
244 DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
245 DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
247 /// Get version \p HashVersion of the PGO hash for \p S.
248 PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
249 switch (S->getStmtClass()) {
252 case Stmt::LabelStmtClass:
253 return PGOHash::LabelStmt;
254 case Stmt::WhileStmtClass:
255 return PGOHash::WhileStmt;
256 case Stmt::DoStmtClass:
257 return PGOHash::DoStmt;
258 case Stmt::ForStmtClass:
259 return PGOHash::ForStmt;
260 case Stmt::CXXForRangeStmtClass:
261 return PGOHash::CXXForRangeStmt;
262 case Stmt::ObjCForCollectionStmtClass:
263 return PGOHash::ObjCForCollectionStmt;
264 case Stmt::SwitchStmtClass:
265 return PGOHash::SwitchStmt;
266 case Stmt::CaseStmtClass:
267 return PGOHash::CaseStmt;
268 case Stmt::DefaultStmtClass:
269 return PGOHash::DefaultStmt;
270 case Stmt::IfStmtClass:
271 return PGOHash::IfStmt;
272 case Stmt::CXXTryStmtClass:
273 return PGOHash::CXXTryStmt;
274 case Stmt::CXXCatchStmtClass:
275 return PGOHash::CXXCatchStmt;
276 case Stmt::ConditionalOperatorClass:
277 return PGOHash::ConditionalOperator;
278 case Stmt::BinaryConditionalOperatorClass:
279 return PGOHash::BinaryConditionalOperator;
280 case Stmt::BinaryOperatorClass: {
281 const BinaryOperator *BO = cast<BinaryOperator>(S);
282 if (BO->getOpcode() == BO_LAnd)
283 return PGOHash::BinaryOperatorLAnd;
284 if (BO->getOpcode() == BO_LOr)
285 return PGOHash::BinaryOperatorLOr;
286 if (HashVersion == PGO_HASH_V2) {
287 switch (BO->getOpcode()) {
291 return PGOHash::BinaryOperatorLT;
293 return PGOHash::BinaryOperatorGT;
295 return PGOHash::BinaryOperatorLE;
297 return PGOHash::BinaryOperatorGE;
299 return PGOHash::BinaryOperatorEQ;
301 return PGOHash::BinaryOperatorNE;
308 if (HashVersion == PGO_HASH_V2) {
309 switch (S->getStmtClass()) {
312 case Stmt::GotoStmtClass:
313 return PGOHash::GotoStmt;
314 case Stmt::IndirectGotoStmtClass:
315 return PGOHash::IndirectGotoStmt;
316 case Stmt::BreakStmtClass:
317 return PGOHash::BreakStmt;
318 case Stmt::ContinueStmtClass:
319 return PGOHash::ContinueStmt;
320 case Stmt::ReturnStmtClass:
321 return PGOHash::ReturnStmt;
322 case Stmt::CXXThrowExprClass:
323 return PGOHash::ThrowExpr;
324 case Stmt::UnaryOperatorClass: {
325 const UnaryOperator *UO = cast<UnaryOperator>(S);
326 if (UO->getOpcode() == UO_LNot)
327 return PGOHash::UnaryOperatorLNot;
333 return PGOHash::None;
337 /// A StmtVisitor that propagates the raw counts through the AST and
338 /// records the count at statements where the value may change.
339 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
343 /// A flag that is set when the current count should be recorded on the
344 /// next statement, such as at the exit of a loop.
345 bool RecordNextStmtCount;
347 /// The count at the current location in the traversal.
348 uint64_t CurrentCount;
350 /// The map of statements to count values.
351 llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
353 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
354 struct BreakContinue {
356 uint64_t ContinueCount;
357 BreakContinue() : BreakCount(0), ContinueCount(0) {}
359 SmallVector<BreakContinue, 8> BreakContinueStack;
361 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
363 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
365 void RecordStmtCount(const Stmt *S) {
366 if (RecordNextStmtCount) {
367 CountMap[S] = CurrentCount;
368 RecordNextStmtCount = false;
372 /// Set and return the current count.
373 uint64_t setCount(uint64_t Count) {
374 CurrentCount = Count;
378 void VisitStmt(const Stmt *S) {
380 for (const Stmt *Child : S->children())
385 void VisitFunctionDecl(const FunctionDecl *D) {
386 // Counter tracks entry to the function body.
387 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
388 CountMap[D->getBody()] = BodyCount;
392 // Skip lambda expressions. We visit these as FunctionDecls when we're
393 // generating them and aren't interested in the body when generating a
395 void VisitLambdaExpr(const LambdaExpr *LE) {}
397 void VisitCapturedDecl(const CapturedDecl *D) {
398 // Counter tracks entry to the capture body.
399 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
400 CountMap[D->getBody()] = BodyCount;
404 void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
405 // Counter tracks entry to the method body.
406 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
407 CountMap[D->getBody()] = BodyCount;
411 void VisitBlockDecl(const BlockDecl *D) {
412 // Counter tracks entry to the block body.
413 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
414 CountMap[D->getBody()] = BodyCount;
418 void VisitReturnStmt(const ReturnStmt *S) {
420 if (S->getRetValue())
421 Visit(S->getRetValue());
423 RecordNextStmtCount = true;
426 void VisitCXXThrowExpr(const CXXThrowExpr *E) {
429 Visit(E->getSubExpr());
431 RecordNextStmtCount = true;
434 void VisitGotoStmt(const GotoStmt *S) {
437 RecordNextStmtCount = true;
440 void VisitLabelStmt(const LabelStmt *S) {
441 RecordNextStmtCount = false;
442 // Counter tracks the block following the label.
443 uint64_t BlockCount = setCount(PGO.getRegionCount(S));
444 CountMap[S] = BlockCount;
445 Visit(S->getSubStmt());
448 void VisitBreakStmt(const BreakStmt *S) {
450 assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
451 BreakContinueStack.back().BreakCount += CurrentCount;
453 RecordNextStmtCount = true;
456 void VisitContinueStmt(const ContinueStmt *S) {
458 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
459 BreakContinueStack.back().ContinueCount += CurrentCount;
461 RecordNextStmtCount = true;
464 void VisitWhileStmt(const WhileStmt *S) {
466 uint64_t ParentCount = CurrentCount;
468 BreakContinueStack.push_back(BreakContinue());
469 // Visit the body region first so the break/continue adjustments can be
470 // included when visiting the condition.
471 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
472 CountMap[S->getBody()] = CurrentCount;
474 uint64_t BackedgeCount = CurrentCount;
476 // ...then go back and propagate counts through the condition. The count
477 // at the start of the condition is the sum of the incoming edges,
478 // the backedge from the end of the loop body, and the edges from
479 // continue statements.
480 BreakContinue BC = BreakContinueStack.pop_back_val();
482 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
483 CountMap[S->getCond()] = CondCount;
485 setCount(BC.BreakCount + CondCount - BodyCount);
486 RecordNextStmtCount = true;
489 void VisitDoStmt(const DoStmt *S) {
491 uint64_t LoopCount = PGO.getRegionCount(S);
493 BreakContinueStack.push_back(BreakContinue());
494 // The count doesn't include the fallthrough from the parent scope. Add it.
495 uint64_t BodyCount = setCount(LoopCount + CurrentCount);
496 CountMap[S->getBody()] = BodyCount;
498 uint64_t BackedgeCount = CurrentCount;
500 BreakContinue BC = BreakContinueStack.pop_back_val();
501 // The count at the start of the condition is equal to the count at the
502 // end of the body, plus any continues.
503 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
504 CountMap[S->getCond()] = CondCount;
506 setCount(BC.BreakCount + CondCount - LoopCount);
507 RecordNextStmtCount = true;
510 void VisitForStmt(const ForStmt *S) {
515 uint64_t ParentCount = CurrentCount;
517 BreakContinueStack.push_back(BreakContinue());
518 // Visit the body region first. (This is basically the same as a while
519 // loop; see further comments in VisitWhileStmt.)
520 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
521 CountMap[S->getBody()] = BodyCount;
523 uint64_t BackedgeCount = CurrentCount;
524 BreakContinue BC = BreakContinueStack.pop_back_val();
526 // The increment is essentially part of the body but it needs to include
527 // the count for all the continue statements.
529 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
530 CountMap[S->getInc()] = IncCount;
534 // ...then go back and propagate counts through the condition.
536 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
538 CountMap[S->getCond()] = CondCount;
541 setCount(BC.BreakCount + CondCount - BodyCount);
542 RecordNextStmtCount = true;
545 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
547 Visit(S->getLoopVarStmt());
548 Visit(S->getRangeStmt());
549 Visit(S->getBeginStmt());
550 Visit(S->getEndStmt());
552 uint64_t ParentCount = CurrentCount;
553 BreakContinueStack.push_back(BreakContinue());
554 // Visit the body region first. (This is basically the same as a while
555 // loop; see further comments in VisitWhileStmt.)
556 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
557 CountMap[S->getBody()] = BodyCount;
559 uint64_t BackedgeCount = CurrentCount;
560 BreakContinue BC = BreakContinueStack.pop_back_val();
562 // The increment is essentially part of the body but it needs to include
563 // the count for all the continue statements.
564 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
565 CountMap[S->getInc()] = IncCount;
568 // ...then go back and propagate counts through the condition.
570 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
571 CountMap[S->getCond()] = CondCount;
573 setCount(BC.BreakCount + CondCount - BodyCount);
574 RecordNextStmtCount = true;
577 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
579 Visit(S->getElement());
580 uint64_t ParentCount = CurrentCount;
581 BreakContinueStack.push_back(BreakContinue());
582 // Counter tracks the body of the loop.
583 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
584 CountMap[S->getBody()] = BodyCount;
586 uint64_t BackedgeCount = CurrentCount;
587 BreakContinue BC = BreakContinueStack.pop_back_val();
589 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
591 RecordNextStmtCount = true;
594 void VisitSwitchStmt(const SwitchStmt *S) {
600 BreakContinueStack.push_back(BreakContinue());
602 // If the switch is inside a loop, add the continue counts.
603 BreakContinue BC = BreakContinueStack.pop_back_val();
604 if (!BreakContinueStack.empty())
605 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
606 // Counter tracks the exit block of the switch.
607 setCount(PGO.getRegionCount(S));
608 RecordNextStmtCount = true;
611 void VisitSwitchCase(const SwitchCase *S) {
612 RecordNextStmtCount = false;
613 // Counter for this particular case. This counts only jumps from the
614 // switch header and does not include fallthrough from the case before
616 uint64_t CaseCount = PGO.getRegionCount(S);
617 setCount(CurrentCount + CaseCount);
618 // We need the count without fallthrough in the mapping, so it's more useful
619 // for branch probabilities.
620 CountMap[S] = CaseCount;
621 RecordNextStmtCount = true;
622 Visit(S->getSubStmt());
625 void VisitIfStmt(const IfStmt *S) {
627 uint64_t ParentCount = CurrentCount;
632 // Counter tracks the "then" part of an if statement. The count for
633 // the "else" part, if it exists, will be calculated from this counter.
634 uint64_t ThenCount = setCount(PGO.getRegionCount(S));
635 CountMap[S->getThen()] = ThenCount;
637 uint64_t OutCount = CurrentCount;
639 uint64_t ElseCount = ParentCount - ThenCount;
642 CountMap[S->getElse()] = ElseCount;
644 OutCount += CurrentCount;
646 OutCount += ElseCount;
648 RecordNextStmtCount = true;
651 void VisitCXXTryStmt(const CXXTryStmt *S) {
653 Visit(S->getTryBlock());
654 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
655 Visit(S->getHandler(I));
656 // Counter tracks the continuation block of the try statement.
657 setCount(PGO.getRegionCount(S));
658 RecordNextStmtCount = true;
661 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
662 RecordNextStmtCount = false;
663 // Counter tracks the catch statement's handler block.
664 uint64_t CatchCount = setCount(PGO.getRegionCount(S));
665 CountMap[S] = CatchCount;
666 Visit(S->getHandlerBlock());
669 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
671 uint64_t ParentCount = CurrentCount;
674 // Counter tracks the "true" part of a conditional operator. The
675 // count in the "false" part will be calculated from this counter.
676 uint64_t TrueCount = setCount(PGO.getRegionCount(E));
677 CountMap[E->getTrueExpr()] = TrueCount;
678 Visit(E->getTrueExpr());
679 uint64_t OutCount = CurrentCount;
681 uint64_t FalseCount = setCount(ParentCount - TrueCount);
682 CountMap[E->getFalseExpr()] = FalseCount;
683 Visit(E->getFalseExpr());
684 OutCount += CurrentCount;
687 RecordNextStmtCount = true;
690 void VisitBinLAnd(const BinaryOperator *E) {
692 uint64_t ParentCount = CurrentCount;
694 // Counter tracks the right hand side of a logical and operator.
695 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
696 CountMap[E->getRHS()] = RHSCount;
698 setCount(ParentCount + RHSCount - CurrentCount);
699 RecordNextStmtCount = true;
702 void VisitBinLOr(const BinaryOperator *E) {
704 uint64_t ParentCount = CurrentCount;
706 // Counter tracks the right hand side of a logical or operator.
707 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
708 CountMap[E->getRHS()] = RHSCount;
710 setCount(ParentCount + RHSCount - CurrentCount);
711 RecordNextStmtCount = true;
714 } // end anonymous namespace
716 void PGOHash::combine(HashType Type) {
717 // Check that we never combine 0 and only have six bits.
718 assert(Type && "Hash is invalid: unexpected type 0");
719 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
721 // Pass through MD5 if enough work has built up.
722 if (Count && Count % NumTypesPerWord == 0) {
723 using namespace llvm::support;
724 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
725 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
729 // Accumulate the current type.
731 Working = Working << NumBitsPerType | Type;
734 uint64_t PGOHash::finalize() {
735 // Use Working as the hash directly if we never used MD5.
736 if (Count <= NumTypesPerWord)
737 // No need to byte swap here, since none of the math was endian-dependent.
738 // This number will be byte-swapped as required on endianness transitions,
739 // so we will see the same value on the other side.
742 // Check for remaining work in Working.
746 // Finalize the MD5 and return the hash.
747 llvm::MD5::MD5Result Result;
749 using namespace llvm::support;
753 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
754 const Decl *D = GD.getDecl();
758 bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
759 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
760 if (!InstrumentRegions && !PGOReader)
764 // Constructors and destructors may be represented by several functions in IR.
765 // If so, instrument only base variant, others are implemented by delegation
766 // to the base one, it would be counted twice otherwise.
767 if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
768 if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
771 if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
772 if (GD.getCtorType() != Ctor_Base &&
773 CodeGenFunction::IsConstructorDelegationValid(CCD))
776 CGM.ClearUnusedCoverageMapping(D);
779 mapRegionCounters(D);
780 if (CGM.getCodeGenOpts().CoverageMapping)
781 emitCounterRegionMapping(D);
783 SourceManager &SM = CGM.getContext().getSourceManager();
784 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
785 computeRegionCounts(D);
786 applyFunctionAttributes(PGOReader, Fn);
790 void CodeGenPGO::mapRegionCounters(const Decl *D) {
791 // Use the latest hash version when inserting instrumentation, but use the
792 // version in the indexed profile if we're reading PGO data.
793 PGOHashVersion HashVersion = PGO_HASH_LATEST;
794 if (auto *PGOReader = CGM.getPGOReader())
795 HashVersion = getPGOHashVersion(PGOReader, CGM);
797 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
798 MapRegionCounters Walker(HashVersion, *RegionCounterMap);
799 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
800 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
801 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
802 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
803 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
804 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
805 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
806 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
807 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
808 NumRegionCounters = Walker.NextCounter;
809 FunctionHash = Walker.Hash.finalize();
812 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
816 // Don't map the functions in system headers.
817 const auto &SM = CGM.getContext().getSourceManager();
818 auto Loc = D->getBody()->getLocStart();
819 return SM.isInSystemHeader(Loc);
822 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
823 if (skipRegionMappingForDecl(D))
826 std::string CoverageMapping;
827 llvm::raw_string_ostream OS(CoverageMapping);
828 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
829 CGM.getContext().getSourceManager(),
830 CGM.getLangOpts(), RegionCounterMap.get());
831 MappingGen.emitCounterMapping(D, OS);
834 if (CoverageMapping.empty())
837 CGM.getCoverageMapping()->addFunctionMappingRecord(
838 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
842 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
843 llvm::GlobalValue::LinkageTypes Linkage) {
844 if (skipRegionMappingForDecl(D))
847 std::string CoverageMapping;
848 llvm::raw_string_ostream OS(CoverageMapping);
849 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
850 CGM.getContext().getSourceManager(),
852 MappingGen.emitEmptyMapping(D, OS);
855 if (CoverageMapping.empty())
858 setFuncName(Name, Linkage);
859 CGM.getCoverageMapping()->addFunctionMappingRecord(
860 FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
863 void CodeGenPGO::computeRegionCounts(const Decl *D) {
864 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
865 ComputeRegionCounts Walker(*StmtCountMap, *this);
866 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
867 Walker.VisitFunctionDecl(FD);
868 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
869 Walker.VisitObjCMethodDecl(MD);
870 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
871 Walker.VisitBlockDecl(BD);
872 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
873 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
877 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
878 llvm::Function *Fn) {
879 if (!haveRegionCounts())
882 uint64_t FunctionCount = getRegionCount(nullptr);
883 Fn->setEntryCount(FunctionCount);
886 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
887 llvm::Value *StepV) {
888 if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
890 if (!Builder.GetInsertBlock())
893 unsigned Counter = (*RegionCounterMap)[S];
894 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
896 llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
897 Builder.getInt64(FunctionHash),
898 Builder.getInt32(NumRegionCounters),
899 Builder.getInt32(Counter), StepV};
901 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
902 makeArrayRef(Args, 4));
905 CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
909 // This method either inserts a call to the profile run-time during
910 // instrumentation or puts profile data into metadata for PGO use.
911 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
912 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
914 if (!EnableValueProfiling)
917 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
920 if (isa<llvm::Constant>(ValuePtr))
923 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
924 if (InstrumentValueSites && RegionCounterMap) {
925 auto BuilderInsertPoint = Builder.saveIP();
926 Builder.SetInsertPoint(ValueSite);
927 llvm::Value *Args[5] = {
928 llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
929 Builder.getInt64(FunctionHash),
930 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
931 Builder.getInt32(ValueKind),
932 Builder.getInt32(NumValueSites[ValueKind]++)
935 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
936 Builder.restoreIP(BuilderInsertPoint);
940 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
941 if (PGOReader && haveRegionCounts()) {
942 // We record the top most called three functions at each call site.
943 // Profile metadata contains "VP" string identifying this metadata
944 // as value profiling data, then a uint32_t value for the value profiling
945 // kind, a uint64_t value for the total number of times the call is
946 // executed, followed by the function hash and execution count (uint64_t)
947 // pairs for each function.
948 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
951 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
952 (llvm::InstrProfValueKind)ValueKind,
953 NumValueSites[ValueKind]);
955 NumValueSites[ValueKind]++;
959 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
961 CGM.getPGOStats().addVisited(IsInMainFile);
962 RegionCounts.clear();
963 llvm::Expected<llvm::InstrProfRecord> RecordExpected =
964 PGOReader->getInstrProfRecord(FuncName, FunctionHash);
965 if (auto E = RecordExpected.takeError()) {
966 auto IPE = llvm::InstrProfError::take(std::move(E));
967 if (IPE == llvm::instrprof_error::unknown_function)
968 CGM.getPGOStats().addMissing(IsInMainFile);
969 else if (IPE == llvm::instrprof_error::hash_mismatch)
970 CGM.getPGOStats().addMismatched(IsInMainFile);
971 else if (IPE == llvm::instrprof_error::malformed)
972 // TODO: Consider a more specific warning for this case.
973 CGM.getPGOStats().addMismatched(IsInMainFile);
977 llvm::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
978 RegionCounts = ProfRecord->Counts;
/// \brief Pick the divisor used to squeeze 64-bit weights into 32 bits.
///
/// Returns 1 when \p MaxWeight is already below UINT32_MAX. Otherwise returns
/// the smallest divisor of the form MaxWeight / UINT32_MAX + 1, which brings
/// every weight no larger than \p MaxWeight strictly below UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
/// \brief Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight. (A scale derived from a smaller maximum may leave the
/// result too large for 32 bits, which the assertion below rejects.)
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // The narrowing is intentional and was range-checked just above.
  return static_cast<uint32_t>(Scaled);
}
1005 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1006 uint64_t FalseCount) {
1007 // Check for empty weights.
1008 if (!TrueCount && !FalseCount)
1011 // Calculate how to scale down to 32-bits.
1012 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1014 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1015 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1016 scaleBranchWeight(FalseCount, Scale));
1020 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
1021 // We need at least two elements to create meaningful weights.
1022 if (Weights.size() < 2)
1025 // Check for empty weights.
1026 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1030 // Calculate how to scale down to 32-bits.
1031 uint64_t Scale = calculateWeightScale(MaxWeight);
1033 SmallVector<uint32_t, 16> ScaledWeights;
1034 ScaledWeights.reserve(Weights.size());
1035 for (uint64_t W : Weights)
1036 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1038 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1039 return MDHelper.createBranchWeights(ScaledWeights);
1042 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1043 uint64_t LoopCount) {
1044 if (!PGO.haveRegionCounts())
1046 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1047 assert(CondCount.hasValue() && "missing expected loop condition count");
1048 if (*CondCount == 0)
1050 return createProfileWeights(LoopCount,
1051 std::max(*CondCount, LoopCount) - LoopCount);