1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Instrumentation-based profile-guided optimization
11 //===----------------------------------------------------------------------===//
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/Endian.h"
21 #include "llvm/Support/FileSystem.h"
22 #include "llvm/Support/MD5.h"
// Hidden command-line flag gating the value-profiling paths in
// CodeGenPGO::valueProfile below; off by default.
24 static llvm::cl::opt<bool>
25 EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
26 llvm::cl::desc("Enable value profiling"),
27 llvm::cl::Hidden, llvm::cl::init(false));
29 using namespace clang;
30 using namespace CodeGen;
// Compute and cache the PGO name for this function. The name scheme depends
// on the profile format version, so prefer the version recorded in the
// indexed profile reader when one is available; otherwise fall back to the
// current writer version.
32 void CodeGenPGO::setFuncName(StringRef Name,
33 llvm::GlobalValue::LinkageTypes Linkage) {
34 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
35 FuncName = llvm::getPGOFuncName(
36 Name, Linkage, CGM.getCodeGenOpts().MainFileName,
37 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
39 // If we're generating a profile, create a variable for the name.
40 if (CGM.getCodeGenOpts().hasProfileClangInstr())
41 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
// Convenience overload: derive the PGO name from an existing IR function and
// attach it to the function as metadata so later passes can recover it.
44 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
45 setFuncName(Fn->getName(), Fn->getLinkage());
46 // Create PGOFuncName meta data.
47 llvm::createPGOFuncNameMetadata(*Fn, FuncName);
50 /// The version of the PGO hash algorithm.
51 enum PGOHashVersion : unsigned {
// NOTE(review): the PGO_HASH_V1 / PGO_HASH_V2 enumerators are elided in this
// excerpt — confirm their declarations against the full file.
55 // Keep this set to the latest hash version.
56 PGO_HASH_LATEST = PGO_HASH_V2
60 /// Stable hasher for PGO region counters.
62 /// PGOHash produces a stable hash of a given function's control flow.
64 /// Changing the output of this hash will invalidate all previously generated
65 /// profiles -- i.e., don't do it.
67 /// \note When this hash does eventually change (years?), we still need to
68 /// support old hashes. We'll need to pull in the version number from the
69 /// profile data format and use the matching hash function.
// Which hash algorithm variant this hasher is computing.
73 PGOHashVersion HashVersion;
// Each combined HashType is packed into 6 bits of the 64-bit Working word;
// ten types fit per word before the word is flushed into the MD5 stream.
76 static const int NumBitsPerType = 6;
77 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
78 static const unsigned TooBig = 1u << NumBitsPerType;
81 /// Hash values for AST nodes.
83 /// Distinct values for AST nodes that have region counters attached.
85 /// These values must be stable. All new members must be added at the end,
86 /// and no members should be removed. Changing the enumeration value for an
87 /// AST node will affect the hash of every function that contains that node.
88 enum HashType : unsigned char {
95 ObjCForCollectionStmt,
105 BinaryConditionalOperator,
106 // The preceding values are available with PGO_HASH_V1.
124 // The preceding values are available with PGO_HASH_V2.
126 // Keep this last. It's for the static assert that follows.
129 static_assert(LastHashType <= TooBig, "Too many types in HashType");
// Start with an empty accumulator; hashing happens via combine()/finalize().
131 PGOHash(PGOHashVersion HashVersion)
132 : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
133 void combine(HashType Type);
135 PGOHashVersion getHashVersion() const { return HashVersion; }
// Out-of-line definitions for the ODR-used in-class static const members
// (required before C++17 inline variables).
137 const int PGOHash::NumBitsPerType;
138 const unsigned PGOHash::NumTypesPerWord;
139 const unsigned PGOHash::TooBig;
141 /// Get the PGO hash version used in the given indexed profile.
142 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
143 CodeGenModule &CGM) {
// Profile format versions up to 4 predate the V2 hash; presumably this branch
// returns PGO_HASH_V1 and the fall-through returns PGO_HASH_V2 — the return
// statements are elided in this excerpt, confirm against the full file.
144 if (PGOReader->getVersion() <= 4)
149 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
150 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
151 using Base = RecursiveASTVisitor<MapRegionCounters>;
153 /// The next counter value to assign.
154 unsigned NextCounter;
155 /// The function hash.
157 /// The map of statements to counters.
158 llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
160 MapRegionCounters(PGOHashVersion HashVersion,
161 llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
162 : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}
164 // Blocks and lambdas are handled as separate functions, so we need not
165 // traverse them in the parent context.
166 bool TraverseBlockExpr(BlockExpr *BE) { return true; }
167 bool TraverseLambdaExpr(LambdaExpr *LE) {
168 // Traverse the captures, but not the body.
169 for (const auto &C : zip(LE->captures(), LE->capture_inits()))
170 TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
173 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
// Assign the entry counter (counter 0) to the body of function-like decls.
175 bool VisitDecl(const Decl *D) {
176 switch (D->getKind()) {
// NOTE(review): additional case labels (e.g. Decl::Function, Decl::Block)
// appear to be elided from this excerpt — confirm against the full file.
180 case Decl::CXXMethod:
181 case Decl::CXXConstructor:
182 case Decl::CXXDestructor:
183 case Decl::CXXConversion:
184 case Decl::ObjCMethod:
187 CounterMap[D->getBody()] = NextCounter++;
193 /// If \p S gets a fresh counter, update the counter mappings. Return the
195 PGOHash::HashType updateCounterMappings(Stmt *S) {
// Counter assignment always follows the V1 classification so that counter
// numbering stays identical across hash versions.
196 auto Type = getHashType(PGO_HASH_V1, S);
197 if (Type != PGOHash::None)
198 CounterMap[S] = NextCounter++;
202 /// Include \p S in the function hash.
203 bool VisitStmt(Stmt *S) {
204 auto Type = updateCounterMappings(S);
// For newer hash versions, re-classify with the active version, which
// recognizes more statement kinds than V1.
205 if (Hash.getHashVersion() != PGO_HASH_V1)
206 Type = getHashType(Hash.getHashVersion(), S);
207 if (Type != PGOHash::None)
// For V2+ hashes, record which branch of the if each child belongs to so
// that moving code between then/else changes the hash.
212 bool TraverseIfStmt(IfStmt *If) {
213 // If we used the V1 hash, use the default traversal.
214 if (Hash.getHashVersion() == PGO_HASH_V1)
215 return Base::TraverseIfStmt(If);
217 // Otherwise, keep track of which branch we're in while traversing.
219 for (Stmt *CS : If->children()) {
222 if (CS == If->getThen())
223 Hash.combine(PGOHash::IfThenBranch);
224 else if (CS == If->getElse())
225 Hash.combine(PGOHash::IfElseBranch);
228 Hash.combine(PGOHash::EndOfScope);
232 // If the statement type \p N is nestable, and its nesting impacts profile
233 // stability, define a custom traversal which tracks the end of the statement
234 // in the hash (provided we're not using the V1 hash).
235 #define DEFINE_NESTABLE_TRAVERSAL(N) \
236 bool Traverse##N(N *S) { \
237 Base::Traverse##N(S); \
238 if (Hash.getHashVersion() != PGO_HASH_V1) \
239 Hash.combine(PGOHash::EndOfScope); \
243 DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
244 DEFINE_NESTABLE_TRAVERSAL(DoStmt)
245 DEFINE_NESTABLE_TRAVERSAL(ForStmt)
246 DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
247 DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
248 DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
249 DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
251 /// Get version \p HashVersion of the PGO hash for \p S.
252 PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
253 switch (S->getStmtClass()) {
256 case Stmt::LabelStmtClass:
257 return PGOHash::LabelStmt;
258 case Stmt::WhileStmtClass:
259 return PGOHash::WhileStmt;
260 case Stmt::DoStmtClass:
261 return PGOHash::DoStmt;
262 case Stmt::ForStmtClass:
263 return PGOHash::ForStmt;
264 case Stmt::CXXForRangeStmtClass:
265 return PGOHash::CXXForRangeStmt;
266 case Stmt::ObjCForCollectionStmtClass:
267 return PGOHash::ObjCForCollectionStmt;
268 case Stmt::SwitchStmtClass:
269 return PGOHash::SwitchStmt;
270 case Stmt::CaseStmtClass:
271 return PGOHash::CaseStmt;
272 case Stmt::DefaultStmtClass:
273 return PGOHash::DefaultStmt;
274 case Stmt::IfStmtClass:
275 return PGOHash::IfStmt;
276 case Stmt::CXXTryStmtClass:
277 return PGOHash::CXXTryStmt;
278 case Stmt::CXXCatchStmtClass:
279 return PGOHash::CXXCatchStmt;
280 case Stmt::ConditionalOperatorClass:
281 return PGOHash::ConditionalOperator;
282 case Stmt::BinaryConditionalOperatorClass:
283 return PGOHash::BinaryConditionalOperator;
284 case Stmt::BinaryOperatorClass: {
285 const BinaryOperator *BO = cast<BinaryOperator>(S);
// Short-circuiting operators get counters under every hash version.
286 if (BO->getOpcode() == BO_LAnd)
287 return PGOHash::BinaryOperatorLAnd;
288 if (BO->getOpcode() == BO_LOr)
289 return PGOHash::BinaryOperatorLOr;
// V2 additionally hashes comparison operators. NOTE(review): the `case`
// labels for the opcodes are elided in this excerpt — confirm pairing of
// each opcode with its return against the full file.
290 if (HashVersion == PGO_HASH_V2) {
291 switch (BO->getOpcode()) {
295 return PGOHash::BinaryOperatorLT;
297 return PGOHash::BinaryOperatorGT;
299 return PGOHash::BinaryOperatorLE;
301 return PGOHash::BinaryOperatorGE;
303 return PGOHash::BinaryOperatorEQ;
305 return PGOHash::BinaryOperatorNE;
// V2 also hashes statements that affect control flow without owning a
// region counter of their own.
312 if (HashVersion == PGO_HASH_V2) {
313 switch (S->getStmtClass()) {
316 case Stmt::GotoStmtClass:
317 return PGOHash::GotoStmt;
318 case Stmt::IndirectGotoStmtClass:
319 return PGOHash::IndirectGotoStmt;
320 case Stmt::BreakStmtClass:
321 return PGOHash::BreakStmt;
322 case Stmt::ContinueStmtClass:
323 return PGOHash::ContinueStmt;
324 case Stmt::ReturnStmtClass:
325 return PGOHash::ReturnStmt;
326 case Stmt::CXXThrowExprClass:
327 return PGOHash::ThrowExpr;
328 case Stmt::UnaryOperatorClass: {
329 const UnaryOperator *UO = cast<UnaryOperator>(S);
330 if (UO->getOpcode() == UO_LNot)
331 return PGOHash::UnaryOperatorLNot;
337 return PGOHash::None;
341 /// A StmtVisitor that propagates the raw counts through the AST and
342 /// records the count at statements where the value may change.
343 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
347 /// A flag that is set when the current count should be recorded on the
348 /// next statement, such as at the exit of a loop.
349 bool RecordNextStmtCount;
351 /// The count at the current location in the traversal.
352 uint64_t CurrentCount;
354 /// The map of statements to count values.
355 llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
357 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
358 struct BreakContinue {
360 uint64_t ContinueCount;
361 BreakContinue() : BreakCount(0), ContinueCount(0) {}
363 SmallVector<BreakContinue, 8> BreakContinueStack;
365 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
367 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
// Record CurrentCount on \p S if a predecessor (break/goto/return/...)
// requested it, then clear the request.
369 void RecordStmtCount(const Stmt *S) {
370 if (RecordNextStmtCount) {
371 CountMap[S] = CurrentCount;
372 RecordNextStmtCount = false;
376 /// Set and return the current count.
377 uint64_t setCount(uint64_t Count) {
378 CurrentCount = Count;
// Default: propagate the count through all children of \p S.
382 void VisitStmt(const Stmt *S) {
384 for (const Stmt *Child : S->children())
389 void VisitFunctionDecl(const FunctionDecl *D) {
390 // Counter tracks entry to the function body.
391 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
392 CountMap[D->getBody()] = BodyCount;
396 // Skip lambda expressions. We visit these as FunctionDecls when we're
397 // generating them and aren't interested in the body when generating a
399 void VisitLambdaExpr(const LambdaExpr *LE) {}
401 void VisitCapturedDecl(const CapturedDecl *D) {
402 // Counter tracks entry to the capture body.
403 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
404 CountMap[D->getBody()] = BodyCount;
408 void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
409 // Counter tracks entry to the method body.
410 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
411 CountMap[D->getBody()] = BodyCount;
415 void VisitBlockDecl(const BlockDecl *D) {
416 // Counter tracks entry to the block body.
417 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
418 CountMap[D->getBody()] = BodyCount;
// Control transfers out of the region: after visiting the operand, defer
// recording until the next reachable statement.
422 void VisitReturnStmt(const ReturnStmt *S) {
424 if (S->getRetValue())
425 Visit(S->getRetValue());
427 RecordNextStmtCount = true;
430 void VisitCXXThrowExpr(const CXXThrowExpr *E) {
433 Visit(E->getSubExpr());
435 RecordNextStmtCount = true;
438 void VisitGotoStmt(const GotoStmt *S) {
441 RecordNextStmtCount = true;
444 void VisitLabelStmt(const LabelStmt *S) {
445 RecordNextStmtCount = false;
446 // Counter tracks the block following the label.
447 uint64_t BlockCount = setCount(PGO.getRegionCount(S));
448 CountMap[S] = BlockCount;
449 Visit(S->getSubStmt());
// break/continue bank the current count on the innermost enclosing
// loop/switch entry so loop exit counts can be reconstructed.
452 void VisitBreakStmt(const BreakStmt *S) {
454 assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
455 BreakContinueStack.back().BreakCount += CurrentCount;
457 RecordNextStmtCount = true;
460 void VisitContinueStmt(const ContinueStmt *S) {
462 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
463 BreakContinueStack.back().ContinueCount += CurrentCount;
465 RecordNextStmtCount = true;
468 void VisitWhileStmt(const WhileStmt *S) {
470 uint64_t ParentCount = CurrentCount;
472 BreakContinueStack.push_back(BreakContinue());
473 // Visit the body region first so the break/continue adjustments can be
474 // included when visiting the condition.
475 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
476 CountMap[S->getBody()] = CurrentCount;
478 uint64_t BackedgeCount = CurrentCount;
480 // ...then go back and propagate counts through the condition. The count
481 // at the start of the condition is the sum of the incoming edges,
482 // the backedge from the end of the loop body, and the edges from
483 // continue statements.
484 BreakContinue BC = BreakContinueStack.pop_back_val();
486 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
487 CountMap[S->getCond()] = CondCount;
489 setCount(BC.BreakCount + CondCount - BodyCount);
490 RecordNextStmtCount = true;
493 void VisitDoStmt(const DoStmt *S) {
495 uint64_t LoopCount = PGO.getRegionCount(S);
497 BreakContinueStack.push_back(BreakContinue());
498 // The count doesn't include the fallthrough from the parent scope. Add it.
499 uint64_t BodyCount = setCount(LoopCount + CurrentCount);
500 CountMap[S->getBody()] = BodyCount;
502 uint64_t BackedgeCount = CurrentCount;
504 BreakContinue BC = BreakContinueStack.pop_back_val();
505 // The count at the start of the condition is equal to the count at the
506 // end of the body, plus any continues.
507 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
508 CountMap[S->getCond()] = CondCount;
510 setCount(BC.BreakCount + CondCount - LoopCount);
511 RecordNextStmtCount = true;
514 void VisitForStmt(const ForStmt *S) {
519 uint64_t ParentCount = CurrentCount;
521 BreakContinueStack.push_back(BreakContinue());
522 // Visit the body region first. (This is basically the same as a while
523 // loop; see further comments in VisitWhileStmt.)
524 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
525 CountMap[S->getBody()] = BodyCount;
527 uint64_t BackedgeCount = CurrentCount;
528 BreakContinue BC = BreakContinueStack.pop_back_val();
530 // The increment is essentially part of the body but it needs to include
531 // the count for all the continue statements.
533 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
534 CountMap[S->getInc()] = IncCount;
538 // ...then go back and propagate counts through the condition.
540 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
542 CountMap[S->getCond()] = CondCount;
545 setCount(BC.BreakCount + CondCount - BodyCount);
546 RecordNextStmtCount = true;
549 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
// The implicit range/begin/end statements execute once, before the loop.
553 Visit(S->getLoopVarStmt());
554 Visit(S->getRangeStmt());
555 Visit(S->getBeginStmt());
556 Visit(S->getEndStmt());
558 uint64_t ParentCount = CurrentCount;
559 BreakContinueStack.push_back(BreakContinue());
560 // Visit the body region first. (This is basically the same as a while
561 // loop; see further comments in VisitWhileStmt.)
562 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
563 CountMap[S->getBody()] = BodyCount;
565 uint64_t BackedgeCount = CurrentCount;
566 BreakContinue BC = BreakContinueStack.pop_back_val();
568 // The increment is essentially part of the body but it needs to include
569 // the count for all the continue statements.
570 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
571 CountMap[S->getInc()] = IncCount;
574 // ...then go back and propagate counts through the condition.
576 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
577 CountMap[S->getCond()] = CondCount;
579 setCount(BC.BreakCount + CondCount - BodyCount);
580 RecordNextStmtCount = true;
583 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
585 Visit(S->getElement());
586 uint64_t ParentCount = CurrentCount;
587 BreakContinueStack.push_back(BreakContinue());
588 // Counter tracks the body of the loop.
589 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
590 CountMap[S->getBody()] = BodyCount;
592 uint64_t BackedgeCount = CurrentCount;
593 BreakContinue BC = BreakContinueStack.pop_back_val();
595 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
597 RecordNextStmtCount = true;
600 void VisitSwitchStmt(const SwitchStmt *S) {
606 BreakContinueStack.push_back(BreakContinue());
608 // If the switch is inside a loop, add the continue counts.
609 BreakContinue BC = BreakContinueStack.pop_back_val();
610 if (!BreakContinueStack.empty())
611 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
612 // Counter tracks the exit block of the switch.
613 setCount(PGO.getRegionCount(S));
614 RecordNextStmtCount = true;
617 void VisitSwitchCase(const SwitchCase *S) {
618 RecordNextStmtCount = false;
619 // Counter for this particular case. This counts only jumps from the
620 // switch header and does not include fallthrough from the case before
622 uint64_t CaseCount = PGO.getRegionCount(S);
623 setCount(CurrentCount + CaseCount);
624 // We need the count without fallthrough in the mapping, so it's more useful
625 // for branch probabilities.
626 CountMap[S] = CaseCount;
627 RecordNextStmtCount = true;
628 Visit(S->getSubStmt());
631 void VisitIfStmt(const IfStmt *S) {
633 uint64_t ParentCount = CurrentCount;
638 // Counter tracks the "then" part of an if statement. The count for
639 // the "else" part, if it exists, will be calculated from this counter.
640 uint64_t ThenCount = setCount(PGO.getRegionCount(S));
641 CountMap[S->getThen()] = ThenCount;
643 uint64_t OutCount = CurrentCount;
// The else edge receives whatever did not take the then edge.
645 uint64_t ElseCount = ParentCount - ThenCount;
648 CountMap[S->getElse()] = ElseCount;
650 OutCount += CurrentCount;
652 OutCount += ElseCount;
654 RecordNextStmtCount = true;
657 void VisitCXXTryStmt(const CXXTryStmt *S) {
659 Visit(S->getTryBlock());
660 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
661 Visit(S->getHandler(I));
662 // Counter tracks the continuation block of the try statement.
663 setCount(PGO.getRegionCount(S));
664 RecordNextStmtCount = true;
667 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
668 RecordNextStmtCount = false;
669 // Counter tracks the catch statement's handler block.
670 uint64_t CatchCount = setCount(PGO.getRegionCount(S));
671 CountMap[S] = CatchCount;
672 Visit(S->getHandlerBlock());
675 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
677 uint64_t ParentCount = CurrentCount;
680 // Counter tracks the "true" part of a conditional operator. The
681 // count in the "false" part will be calculated from this counter.
682 uint64_t TrueCount = setCount(PGO.getRegionCount(E));
683 CountMap[E->getTrueExpr()] = TrueCount;
684 Visit(E->getTrueExpr());
685 uint64_t OutCount = CurrentCount;
687 uint64_t FalseCount = setCount(ParentCount - TrueCount);
688 CountMap[E->getFalseExpr()] = FalseCount;
689 Visit(E->getFalseExpr());
690 OutCount += CurrentCount;
693 RecordNextStmtCount = true;
696 void VisitBinLAnd(const BinaryOperator *E) {
698 uint64_t ParentCount = CurrentCount;
700 // Counter tracks the right hand side of a logical and operator.
701 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
702 CountMap[E->getRHS()] = RHSCount;
// After the operator: paths that skipped the RHS plus paths that took it.
704 setCount(ParentCount + RHSCount - CurrentCount);
705 RecordNextStmtCount = true;
708 void VisitBinLOr(const BinaryOperator *E) {
710 uint64_t ParentCount = CurrentCount;
712 // Counter tracks the right hand side of a logical or operator.
713 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
714 CountMap[E->getRHS()] = RHSCount;
716 setCount(ParentCount + RHSCount - CurrentCount);
717 RecordNextStmtCount = true;
720 } // end anonymous namespace
// Fold one 6-bit HashType into the running hash. Types accumulate in the
// 64-bit Working word; once NumTypesPerWord types have been packed, the word
// is fed (in little-endian byte order, for cross-endian stability) into MD5.
722 void PGOHash::combine(HashType Type) {
723 // Check that we never combine 0 and only have six bits.
724 assert(Type && "Hash is invalid: unexpected type 0");
725 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
727 // Pass through MD5 if enough work has built up.
728 if (Count && Count % NumTypesPerWord == 0) {
729 using namespace llvm::support;
730 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
731 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
735 // Accumulate the current type.
737 Working = Working << NumBitsPerType | Type;
// Produce the final 64-bit function hash. If everything fit in one packed
// word, the word itself is the hash; otherwise flush the remainder into MD5
// and take part of the digest. NOTE(review): the return statements and the
// remaining-work flush are elided in this excerpt — confirm against the
// full file.
740 uint64_t PGOHash::finalize() {
741 // Use Working as the hash directly if we never used MD5.
742 if (Count <= NumTypesPerWord)
743 // No need to byte swap here, since none of the math was endian-dependent.
744 // This number will be byte-swapped as required on endianness transitions,
745 // so we will see the same value on the other side.
748 // Check for remaining work in Working.
752 // Finalize the MD5 and return the hash.
753 llvm::MD5::MD5Result Result;
755 using namespace llvm::support;
// Entry point for per-function PGO setup: assign region counters (and the
// function hash), emit a coverage mapping if requested, and, when reading a
// profile, load the counts and apply entry-count attributes to \p Fn.
// Does nothing when neither instrumentation nor profile use is enabled.
759 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
760 const Decl *D = GD.getDecl();
764 bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
765 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
766 if (!InstrumentRegions && !PGOReader)
770 // Constructors and destructors may be represented by several functions in IR.
771 // If so, instrument only base variant, others are implemented by delegation
772 // to the base one, it would be counted twice otherwise.
773 if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
774 if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
775 if (GD.getCtorType() != Ctor_Base &&
776 CodeGenFunction::IsConstructorDelegationValid(CCD))
779 if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
782 CGM.ClearUnusedCoverageMapping(D);
785 mapRegionCounters(D);
786 if (CGM.getCodeGenOpts().CoverageMapping)
787 emitCounterRegionMapping(D);
// Profile-use path: load recorded counts and propagate them over the AST.
789 SourceManager &SM = CGM.getContext().getSourceManager();
790 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
791 computeRegionCounts(D);
792 applyFunctionAttributes(PGOReader, Fn);
// Walk the declaration's AST with MapRegionCounters to number every region
// counter and compute the function's control-flow hash.
796 void CodeGenPGO::mapRegionCounters(const Decl *D) {
797 // Use the latest hash version when inserting instrumentation, but use the
798 // version in the indexed profile if we're reading PGO data.
799 PGOHashVersion HashVersion = PGO_HASH_LATEST;
800 if (auto *PGOReader = CGM.getPGOReader())
801 HashVersion = getPGOHashVersion(PGOReader, CGM);
803 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
804 MapRegionCounters Walker(HashVersion, *RegionCounterMap);
// Dispatch on the four function-like declaration kinds that carry bodies.
805 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
806 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
807 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
808 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
809 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
810 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
811 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
812 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
813 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
814 NumRegionCounters = Walker.NextCounter;
815 FunctionHash = Walker.Hash.finalize();
// Return true if \p D should get no coverage mapping record; system-header
// code is excluded.
818 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
822 // Don't map the functions in system headers.
823 const auto &SM = CGM.getContext().getSourceManager();
824 auto Loc = D->getBody()->getBeginLoc();
825 return SM.isInSystemHeader(Loc);
// Generate the serialized coverage-mapping data for \p D (using the region
// counter map built earlier) and register it with the module's coverage
// mapping, keyed by the function name and hash.
828 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
829 if (skipRegionMappingForDecl(D))
832 std::string CoverageMapping;
833 llvm::raw_string_ostream OS(CoverageMapping);
834 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
835 CGM.getContext().getSourceManager(),
836 CGM.getLangOpts(), RegionCounterMap.get());
837 MappingGen.emitCounterMapping(D, OS);
// Nothing to record if no mapping regions were produced.
840 if (CoverageMapping.empty())
843 CGM.getCoverageMapping()->addFunctionMappingRecord(
844 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
// Emit a coverage record with no counters for \p D (used for unused/deferred
// functions so they still appear in coverage reports). The final `false`
// argument marks the record as not referencing live counter data.
848 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
849 llvm::GlobalValue::LinkageTypes Linkage) {
850 if (skipRegionMappingForDecl(D))
853 std::string CoverageMapping;
854 llvm::raw_string_ostream OS(CoverageMapping);
855 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
856 CGM.getContext().getSourceManager(),
858 MappingGen.emitEmptyMapping(D, OS);
861 if (CoverageMapping.empty())
// Set the name late: only needed when a record is actually emitted.
864 setFuncName(Name, Linkage);
865 CGM.getCoverageMapping()->addFunctionMappingRecord(
866 FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
// Propagate the loaded profile counts over the AST with ComputeRegionCounts,
// filling StmtCountMap with a count for each interesting statement.
869 void CodeGenPGO::computeRegionCounts(const Decl *D) {
870 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
871 ComputeRegionCounts Walker(*StmtCountMap, *this);
872 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
873 Walker.VisitFunctionDecl(FD);
874 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
875 Walker.VisitObjCMethodDecl(MD);
876 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
877 Walker.VisitBlockDecl(BD);
878 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
879 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
// When profile counts are available, annotate \p Fn with its entry count
// (the count of the entry region, looked up with a null statement key).
883 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
884 llvm::Function *Fn) {
885 if (!haveRegionCounts())
888 uint64_t FunctionCount = getRegionCount(nullptr);
889 Fn->setEntryCount(FunctionCount);
// Emit a call to the llvm.instrprof.increment (or, with a step value,
// llvm.instrprof.increment.step) intrinsic for the counter assigned to \p S.
// No-op unless Clang-level instrumentation is on and counters were mapped.
892 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
893 llvm::Value *StepV) {
894 if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
// Can't insert an instruction without a valid insertion point.
896 if (!Builder.GetInsertBlock())
899 unsigned Counter = (*RegionCounterMap)[S];
900 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
// Intrinsic operands: name var, hash, total counters, counter index, [step].
902 llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
903 Builder.getInt64(FunctionHash),
904 Builder.getInt32(NumRegionCounters),
905 Builder.getInt32(Counter), StepV};
// Plain increment takes only the first four args; the step variant takes
// all five. NOTE(review): the StepV null-check branch appears elided in
// this excerpt — confirm against the full file.
907 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
908 makeArrayRef(Args, 4));
911 CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
915 // This method either inserts a call to the profile run-time during
916 // instrumentation or puts profile data into metadata for PGO use.
917 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
918 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
// Guarded by the hidden -enable-value-profiling flag defined above.
920 if (!EnableValueProfiling)
923 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
// Constants carry no useful dynamic value distribution.
926 if (isa<llvm::Constant>(ValuePtr))
929 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
930 if (InstrumentValueSites && RegionCounterMap) {
// Instrumentation path: emit llvm.instrprof.value.profile right at the
// value site, preserving the builder's current insertion point.
931 auto BuilderInsertPoint = Builder.saveIP();
932 Builder.SetInsertPoint(ValueSite);
933 llvm::Value *Args[5] = {
934 llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
935 Builder.getInt64(FunctionHash),
936 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
937 Builder.getInt32(ValueKind),
938 Builder.getInt32(NumValueSites[ValueKind]++)
941 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
942 Builder.restoreIP(BuilderInsertPoint);
// Profile-use path: attach "VP" metadata for this value site instead.
946 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
947 if (PGOReader && haveRegionCounts()) {
948 // We record the top most called three functions at each call site.
949 // Profile metadata contains "VP" string identifying this metadata
950 // as value profiling data, then a uint32_t value for the value profiling
951 // kind, a uint64_t value for the total number of times the call is
952 // executed, followed by the function hash and execution count (uint64_t)
953 // pairs for each function.
954 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
957 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
958 (llvm::InstrProfValueKind)ValueKind,
959 NumValueSites[ValueKind]);
961 NumValueSites[ValueKind]++;
// Look up this function's record in the indexed profile (by name and hash)
// and cache its counts in RegionCounts. Lookup failures are folded into the
// PGO statistics rather than reported as errors.
965 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
967 CGM.getPGOStats().addVisited(IsInMainFile);
968 RegionCounts.clear();
969 llvm::Expected<llvm::InstrProfRecord> RecordExpected =
970 PGOReader->getInstrProfRecord(FuncName, FunctionHash);
971 if (auto E = RecordExpected.takeError()) {
972 auto IPE = llvm::InstrProfError::take(std::move(E));
973 if (IPE == llvm::instrprof_error::unknown_function)
974 CGM.getPGOStats().addMissing(IsInMainFile);
975 else if (IPE == llvm::instrprof_error::hash_mismatch)
976 CGM.getPGOStats().addMismatched(IsInMainFile);
977 else if (IPE == llvm::instrprof_error::malformed)
978 // TODO: Consider a more specific warning for this case.
979 CGM.getPGOStats().addMismatched(IsInMainFile);
// Success: keep the whole record (needed later for value profiling) and
// expose its counter array.
983 std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
984 RegionCounts = ProfRecord->Counts;
987 /// Calculate what to divide by to scale weights.
989 /// Given the maximum weight, calculate a divisor that will scale all the
990 /// weights to strictly less than UINT32_MAX.
991 static uint64_t calculateWeightScale(uint64_t MaxWeight) {
992 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
995 /// Scale an individual branch weight (and add 1).
997 /// Scale a 64-bit weight down to 32-bits using \c Scale.
999 /// According to Laplace's Rule of Succession, it is better to compute the
1000 /// weight based on the count plus 1, so universally add 1 to the value.
1002 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
1003 /// greater than \c Weight.
1004 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
1005 assert(Scale && "scale by 0?");
1006 uint64_t Scaled = Weight / Scale + 1;
1007 assert(Scaled <= UINT32_MAX && "overflow 32-bits");
// Build branch_weights metadata for a two-way branch from true/false counts,
// scaled to fit 32 bits. Returns null metadata when both counts are zero
// (no profile information to attach).
1011 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1012 uint64_t FalseCount) {
1013 // Check for empty weights.
1014 if (!TrueCount && !FalseCount)
1017 // Calculate how to scale down to 32-bits.
1018 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1020 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1021 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1022 scaleBranchWeight(FalseCount, Scale));
// N-way variant (e.g. for switch terminators): scale every weight by the
// divisor derived from the maximum weight and build branch_weights metadata.
1026 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
1027 // We need at least two elements to create meaningful weights.
1028 if (Weights.size() < 2)
1031 // Check for empty weights.
1032 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1036 // Calculate how to scale down to 32-bits.
1037 uint64_t Scale = calculateWeightScale(MaxWeight);
1039 SmallVector<uint32_t, 16> ScaledWeights;
1040 ScaledWeights.reserve(Weights.size());
1041 for (uint64_t W : Weights)
1042 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1044 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1045 return MDHelper.createBranchWeights(ScaledWeights);
// Build branch weights for a loop's backedge: taken = LoopCount, not-taken =
// condition executions beyond the loop count (the loop exits). Returns no
// metadata when profile counts are unavailable.
1048 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1049 uint64_t LoopCount) {
1050 if (!PGO.haveRegionCounts())
1052 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1053 assert(CondCount.hasValue() && "missing expected loop condition count");
1054 if (*CondCount == 0)
// max() guards against CondCount < LoopCount, which would underflow.
1056 return createProfileWeights(LoopCount,
1057 std::max(*CondCount, LoopCount) - LoopCount);