1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Instrumentation-based profile-guided optimization
11 //===----------------------------------------------------------------------===//
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
// Hidden command-line flag gating value profiling (indirect-call / value-site
// instrumentation and metadata annotation); off by default. Consulted in
// CodeGenPGO::valueProfile below.
25 static llvm::cl::opt<bool>
26 EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
27 llvm::cl::desc("Enable value profiling"),
28 llvm::cl::Hidden, llvm::cl::init(false));
30 using namespace clang;
31 using namespace CodeGen;
// Compute and cache the PGO name for this function. llvm::getPGOFuncName
// encodes linkage-dependent information (e.g. the main file name for local
// linkage) using the scheme matching the profile reader's version when one is
// available, so lookups against existing profile data succeed.
// NOTE(review): the function's closing line(s) are elided in this excerpt
// (embedded numbering jumps from 42 to 45).
33 void CodeGenPGO::setFuncName(StringRef Name,
34 llvm::GlobalValue::LinkageTypes Linkage) {
35 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36 FuncName = llvm::getPGOFuncName(
37 Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
40 // If we're generating a profile, create a variable for the name.
41 if (CGM.getCodeGenOpts().hasProfileClangInstr())
42 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
// Convenience overload: derive name and linkage from the IR function itself,
// then attach PGOFuncName metadata so later passes can recover the PGO name.
// NOTE(review): closing brace (original line 49) elided in this excerpt.
45 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46 setFuncName(Fn->getName(), Fn->getLinkage());
47 // Create PGOFuncName metadata.
48 llvm::createPGOFuncNameMetadata(*Fn, FuncName);
51 /// The version of the PGO hash algorithm.
52 enum PGOHashVersion : unsigned {
// NOTE(review): the individual enumerators (presumably PGO_HASH_V1..V3,
// original lines 53-56) are elided in this excerpt; only the trailing
// "latest" alias survives.
57 // Keep this set to the latest hash version.
58 PGO_HASH_LATEST = PGO_HASH_V3
62 /// Stable hasher for PGO region counters.
64 /// PGOHash produces a stable hash of a given function's control flow.
66 /// Changing the output of this hash will invalidate all previously generated
67 /// profiles -- i.e., don't do it.
69 /// \note When this hash does eventually change (years?), we still need to
70 /// support old hashes. We'll need to pull in the version number from the
71 /// profile data format and use the matching hash function.
// NOTE(review): the `class PGOHash {` line and the `Working`/`Count`/`MD5`
// members (original lines ~72-77) are elided in this excerpt, as is the
// closing `};` and the `finalize()` declaration. Read as an outline.
75 PGOHashVersion HashVersion;
// Each combined HashType is packed into 6 bits, so a 64-bit "Working" word
// holds 10 types before being flushed into the MD5 stream (see combine()).
78 static const int NumBitsPerType = 6;
79 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
80 static const unsigned TooBig = 1u << NumBitsPerType;
83 /// Hash values for AST nodes.
85 /// Distinct values for AST nodes that have region counters attached.
87 /// These values must be stable. All new members must be added at the end,
88 /// and no members should be removed. Changing the enumeration value for an
89 /// AST node will affect the hash of every function that contains that node.
90 enum HashType : unsigned char {
// NOTE(review): most enumerators (original lines 91-125, including None,
// LabelStmt, the loop kinds, the V2 additions, and LastHashType) are elided
// in this excerpt; only three representative lines survive below.
97 ObjCForCollectionStmt,
107 BinaryConditionalOperator,
108 // The preceding values are available with PGO_HASH_V1.
126 // The preceding values are available since PGO_HASH_V2.
128 // Keep this last. It's for the static assert that follows.
// Compile-time guard: every HashType must fit in the 6-bit slot.
131 static_assert(LastHashType <= TooBig, "Too many types in HashType");
133 PGOHash(PGOHashVersion HashVersion)
134 : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
135 void combine(HashType Type);
137 PGOHashVersion getHashVersion() const { return HashVersion; }
// Out-of-class definitions for the in-class-initialized static constants
// (required pre-C++17 when the constants are odr-used).
139 const int PGOHash::NumBitsPerType;
140 const unsigned PGOHash::NumTypesPerWord;
141 const unsigned PGOHash::TooBig;
143 /// Get the PGO hash version used in the given indexed profile.
// Maps indexed-profile format versions onto hash versions by threshold.
// NOTE(review): the return statements for each branch (original lines 147,
// 149-151) are elided in this excerpt; only the version comparisons remain,
// so the exact mapping cannot be confirmed from what is visible here.
144 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
145 CodeGenModule &CGM) {
146 if (PGOReader->getVersion() <= 4)
148 if (PGOReader->getVersion() <= 5)
// MapRegionCounters walks a function's AST exactly once, assigning a dense
// counter index (NextCounter) to each statement that needs its own profile
// counter (recorded in CounterMap), while simultaneously folding a stable
// description of the control flow into Hash so instrumented code and profile
// data can be matched up later.
// NOTE(review): this excerpt elides many original lines (see jumps in the
// embedded numbering): the `PGOHash Hash;` member, several closing braces and
// `return true;` statements, parts of the VisitDecl switch, and many case
// labels in getHashType. Treat it as an outline, not compilable code.
153 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
154 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
155 using Base = RecursiveASTVisitor<MapRegionCounters>;
157 /// The next counter value to assign.
158 unsigned NextCounter;
159 /// The function hash.
161 /// The map of statements to counters.
162 llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
164 MapRegionCounters(PGOHashVersion HashVersion,
165 llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
166 : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}
168 // Blocks and lambdas are handled as separate functions, so we need not
169 // traverse them in the parent context.
170 bool TraverseBlockExpr(BlockExpr *BE) { return true; }
171 bool TraverseLambdaExpr(LambdaExpr *LE) {
172 // Traverse the captures, but not the body.
173 for (auto C : zip(LE->captures(), LE->capture_inits()))
174 TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
177 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
// Assign the entry counter for function-like declarations: the counter for
// the body itself tracks entry into the function/method.
179 bool VisitDecl(const Decl *D) {
180 switch (D->getKind()) {
// NOTE(review): earlier case labels (original lines 181-183, likely the
// default and Decl::Function cases) are elided here.
184 case Decl::CXXMethod:
185 case Decl::CXXConstructor:
186 case Decl::CXXDestructor:
187 case Decl::CXXConversion:
188 case Decl::ObjCMethod:
191 CounterMap[D->getBody()] = NextCounter++;
197 /// If \p S gets a fresh counter, update the counter mappings. Return the
// Counter assignment always uses the V1 classification so the set of
// counters stays identical across hash versions; only the hash differs.
199 PGOHash::HashType updateCounterMappings(Stmt *S) {
200 auto Type = getHashType(PGO_HASH_V1, S);
201 if (Type != PGOHash::None)
202 CounterMap[S] = NextCounter++;
206 /// Include \p S in the function hash.
207 bool VisitStmt(Stmt *S) {
208 auto Type = updateCounterMappings(S);
// For V2+ hashes, re-classify with the newer (finer-grained) scheme before
// folding into the hash.
209 if (Hash.getHashVersion() != PGO_HASH_V1)
210 Type = getHashType(Hash.getHashVersion(), S);
211 if (Type != PGOHash::None)
// Custom if-traversal (V2+): tag which branch each child belongs to so that
// moving a statement between then/else changes the hash.
216 bool TraverseIfStmt(IfStmt *If) {
217 // If we used the V1 hash, use the default traversal.
218 if (Hash.getHashVersion() == PGO_HASH_V1)
219 return Base::TraverseIfStmt(If);
221 // Otherwise, keep track of which branch we're in while traversing.
223 for (Stmt *CS : If->children()) {
226 if (CS == If->getThen())
227 Hash.combine(PGOHash::IfThenBranch);
228 else if (CS == If->getElse())
229 Hash.combine(PGOHash::IfElseBranch);
232 Hash.combine(PGOHash::EndOfScope);
236 // If the statement type \p N is nestable, and its nesting impacts profile
237 // stability, define a custom traversal which tracks the end of the statement
238 // in the hash (provided we're not using the V1 hash).
239 #define DEFINE_NESTABLE_TRAVERSAL(N) \
240 bool Traverse##N(N *S) { \
241 Base::Traverse##N(S); \
242 if (Hash.getHashVersion() != PGO_HASH_V1) \
243 Hash.combine(PGOHash::EndOfScope); \
247 DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
248 DEFINE_NESTABLE_TRAVERSAL(DoStmt)
249 DEFINE_NESTABLE_TRAVERSAL(ForStmt)
250 DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
251 DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
252 DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
253 DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
255 /// Get version \p HashVersion of the PGO hash for \p S.
// Classifies a statement into a stable HashType. Statements that return
// PGOHash::None are invisible to both counter assignment and the hash.
256 PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
257 switch (S->getStmtClass()) {
260 case Stmt::LabelStmtClass:
261 return PGOHash::LabelStmt;
262 case Stmt::WhileStmtClass:
263 return PGOHash::WhileStmt;
264 case Stmt::DoStmtClass:
265 return PGOHash::DoStmt;
266 case Stmt::ForStmtClass:
267 return PGOHash::ForStmt;
268 case Stmt::CXXForRangeStmtClass:
269 return PGOHash::CXXForRangeStmt;
270 case Stmt::ObjCForCollectionStmtClass:
271 return PGOHash::ObjCForCollectionStmt;
272 case Stmt::SwitchStmtClass:
273 return PGOHash::SwitchStmt;
274 case Stmt::CaseStmtClass:
275 return PGOHash::CaseStmt;
276 case Stmt::DefaultStmtClass:
277 return PGOHash::DefaultStmt;
278 case Stmt::IfStmtClass:
279 return PGOHash::IfStmt;
280 case Stmt::CXXTryStmtClass:
281 return PGOHash::CXXTryStmt;
282 case Stmt::CXXCatchStmtClass:
283 return PGOHash::CXXCatchStmt;
284 case Stmt::ConditionalOperatorClass:
285 return PGOHash::ConditionalOperator;
286 case Stmt::BinaryConditionalOperatorClass:
287 return PGOHash::BinaryConditionalOperator;
288 case Stmt::BinaryOperatorClass: {
289 const BinaryOperator *BO = cast<BinaryOperator>(S);
// Short-circuiting operators get counters in all hash versions; the
// relational operators below are hashed only from V2 on.
290 if (BO->getOpcode() == BO_LAnd)
291 return PGOHash::BinaryOperatorLAnd;
292 if (BO->getOpcode() == BO_LOr)
293 return PGOHash::BinaryOperatorLOr;
294 if (HashVersion >= PGO_HASH_V2) {
295 switch (BO->getOpcode()) {
// NOTE(review): the `case BO_LT:` etc. labels (original even-numbered lines
// 296-308) are elided here; only the return statements survive.
299 return PGOHash::BinaryOperatorLT;
301 return PGOHash::BinaryOperatorGT;
303 return PGOHash::BinaryOperatorLE;
305 return PGOHash::BinaryOperatorGE;
307 return PGOHash::BinaryOperatorEQ;
309 return PGOHash::BinaryOperatorNE;
// Second classification pass: statement kinds added to the hash in V2.
316 if (HashVersion >= PGO_HASH_V2) {
317 switch (S->getStmtClass()) {
320 case Stmt::GotoStmtClass:
321 return PGOHash::GotoStmt;
322 case Stmt::IndirectGotoStmtClass:
323 return PGOHash::IndirectGotoStmt;
324 case Stmt::BreakStmtClass:
325 return PGOHash::BreakStmt;
326 case Stmt::ContinueStmtClass:
327 return PGOHash::ContinueStmt;
328 case Stmt::ReturnStmtClass:
329 return PGOHash::ReturnStmt;
330 case Stmt::CXXThrowExprClass:
331 return PGOHash::ThrowExpr;
332 case Stmt::UnaryOperatorClass: {
333 const UnaryOperator *UO = cast<UnaryOperator>(S);
334 if (UO->getOpcode() == UO_LNot)
335 return PGOHash::UnaryOperatorLNot;
341 return PGOHash::None;
// ComputeRegionCounts replays loaded profile counters over the AST, deriving
// an execution count for every interesting statement by flow arithmetic:
// region counters (PGO.getRegionCount) give counts at counter-attached
// nodes, and counts elsewhere are computed as sums/differences of incoming
// edges (parent fallthrough, loop backedges, break/continue edges). Results
// land in CountMap and later drive branch-weight metadata.
// NOTE(review): many original lines are elided in this excerpt (member
// declarations such as `PGOCounter &PGO` and `BreakCount`, closing braces,
// `RecordStmtCount(S)` calls at the head of most visitors, and several
// intermediate statements). The surviving lines are kept byte-identical.
345 /// A StmtVisitor that propagates the raw counts through the AST and
346 /// records the count at statements where the value may change.
347 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
351 /// A flag that is set when the current count should be recorded on the
352 /// next statement, such as at the exit of a loop.
353 bool RecordNextStmtCount;
355 /// The count at the current location in the traversal.
356 uint64_t CurrentCount;
358 /// The map of statements to count values.
359 llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
361 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
362 struct BreakContinue {
364 uint64_t ContinueCount;
365 BreakContinue() : BreakCount(0), ContinueCount(0) {}
367 SmallVector<BreakContinue, 8> BreakContinueStack;
369 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
371 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
// Record CurrentCount on \p S if a preceding visitor requested it (e.g. the
// statement after a loop or a jump), then clear the request.
373 void RecordStmtCount(const Stmt *S) {
374 if (RecordNextStmtCount) {
375 CountMap[S] = CurrentCount;
376 RecordNextStmtCount = false;
380 /// Set and return the current count.
381 uint64_t setCount(uint64_t Count) {
382 CurrentCount = Count;
// Default visitor: record a pending count, then recurse into children.
386 void VisitStmt(const Stmt *S) {
388 for (const Stmt *Child : S->children())
393 void VisitFunctionDecl(const FunctionDecl *D) {
394 // Counter tracks entry to the function body.
395 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
396 CountMap[D->getBody()] = BodyCount;
400 // Skip lambda expressions. We visit these as FunctionDecls when we're
401 // generating them and aren't interested in the body when generating a
403 void VisitLambdaExpr(const LambdaExpr *LE) {}
405 void VisitCapturedDecl(const CapturedDecl *D) {
406 // Counter tracks entry to the capture body.
407 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
408 CountMap[D->getBody()] = BodyCount;
412 void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
413 // Counter tracks entry to the method body.
414 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
415 CountMap[D->getBody()] = BodyCount;
419 void VisitBlockDecl(const BlockDecl *D) {
420 // Counter tracks entry to the block body.
421 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
422 CountMap[D->getBody()] = BodyCount;
// Control-transfer statements: flow does not continue past them, so ask for
// the count to be re-recorded at whatever statement executes next.
426 void VisitReturnStmt(const ReturnStmt *S) {
428 if (S->getRetValue())
429 Visit(S->getRetValue());
431 RecordNextStmtCount = true;
434 void VisitCXXThrowExpr(const CXXThrowExpr *E) {
437 Visit(E->getSubExpr());
439 RecordNextStmtCount = true;
442 void VisitGotoStmt(const GotoStmt *S) {
445 RecordNextStmtCount = true;
448 void VisitLabelStmt(const LabelStmt *S) {
449 RecordNextStmtCount = false;
450 // Counter tracks the block following the label.
451 uint64_t BlockCount = setCount(PGO.getRegionCount(S));
452 CountMap[S] = BlockCount;
453 Visit(S->getSubStmt());
// break/continue: credit the current count to the innermost enclosing
// loop/switch so it can be added back when that construct's exit or
// condition count is computed.
456 void VisitBreakStmt(const BreakStmt *S) {
458 assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
459 BreakContinueStack.back().BreakCount += CurrentCount;
461 RecordNextStmtCount = true;
464 void VisitContinueStmt(const ContinueStmt *S) {
466 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
467 BreakContinueStack.back().ContinueCount += CurrentCount;
469 RecordNextStmtCount = true;
472 void VisitWhileStmt(const WhileStmt *S) {
474 uint64_t ParentCount = CurrentCount;
476 BreakContinueStack.push_back(BreakContinue());
477 // Visit the body region first so the break/continue adjustments can be
478 // included when visiting the condition.
479 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
480 CountMap[S->getBody()] = CurrentCount;
482 uint64_t BackedgeCount = CurrentCount;
484 // ...then go back and propagate counts through the condition. The count
485 // at the start of the condition is the sum of the incoming edges,
486 // the backedge from the end of the loop body, and the edges from
487 // continue statements.
488 BreakContinue BC = BreakContinueStack.pop_back_val();
490 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
491 CountMap[S->getCond()] = CondCount;
// Loop exit = breaks plus the cond-false edge (CondCount - BodyCount).
493 setCount(BC.BreakCount + CondCount - BodyCount);
494 RecordNextStmtCount = true;
497 void VisitDoStmt(const DoStmt *S) {
499 uint64_t LoopCount = PGO.getRegionCount(S);
501 BreakContinueStack.push_back(BreakContinue());
502 // The count doesn't include the fallthrough from the parent scope. Add it.
503 uint64_t BodyCount = setCount(LoopCount + CurrentCount);
504 CountMap[S->getBody()] = BodyCount;
506 uint64_t BackedgeCount = CurrentCount;
508 BreakContinue BC = BreakContinueStack.pop_back_val();
509 // The count at the start of the condition is equal to the count at the
510 // end of the body, plus any continues.
511 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
512 CountMap[S->getCond()] = CondCount;
514 setCount(BC.BreakCount + CondCount - LoopCount);
515 RecordNextStmtCount = true;
518 void VisitForStmt(const ForStmt *S) {
523 uint64_t ParentCount = CurrentCount;
525 BreakContinueStack.push_back(BreakContinue());
526 // Visit the body region first. (This is basically the same as a while
527 // loop; see further comments in VisitWhileStmt.)
528 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
529 CountMap[S->getBody()] = BodyCount;
531 uint64_t BackedgeCount = CurrentCount;
532 BreakContinue BC = BreakContinueStack.pop_back_val();
534 // The increment is essentially part of the body but it needs to include
535 // the count for all the continue statements.
537 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
538 CountMap[S->getInc()] = IncCount;
542 // ...then go back and propagate counts through the condition.
544 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
546 CountMap[S->getCond()] = CondCount;
549 setCount(BC.BreakCount + CondCount - BodyCount);
550 RecordNextStmtCount = true;
553 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
// Range/begin/end/loop-var statements execute once, before the loop proper.
557 Visit(S->getLoopVarStmt());
558 Visit(S->getRangeStmt());
559 Visit(S->getBeginStmt());
560 Visit(S->getEndStmt());
562 uint64_t ParentCount = CurrentCount;
563 BreakContinueStack.push_back(BreakContinue());
564 // Visit the body region first. (This is basically the same as a while
565 // loop; see further comments in VisitWhileStmt.)
566 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
567 CountMap[S->getBody()] = BodyCount;
569 uint64_t BackedgeCount = CurrentCount;
570 BreakContinue BC = BreakContinueStack.pop_back_val();
572 // The increment is essentially part of the body but it needs to include
573 // the count for all the continue statements.
574 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
575 CountMap[S->getInc()] = IncCount;
578 // ...then go back and propagate counts through the condition.
580 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
581 CountMap[S->getCond()] = CondCount;
583 setCount(BC.BreakCount + CondCount - BodyCount);
584 RecordNextStmtCount = true;
587 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
589 Visit(S->getElement());
590 uint64_t ParentCount = CurrentCount;
591 BreakContinueStack.push_back(BreakContinue());
592 // Counter tracks the body of the loop.
593 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
594 CountMap[S->getBody()] = BodyCount;
596 uint64_t BackedgeCount = CurrentCount;
597 BreakContinue BC = BreakContinueStack.pop_back_val();
599 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
601 RecordNextStmtCount = true;
604 void VisitSwitchStmt(const SwitchStmt *S) {
610 BreakContinueStack.push_back(BreakContinue());
612 // If the switch is inside a loop, add the continue counts.
613 BreakContinue BC = BreakContinueStack.pop_back_val();
614 if (!BreakContinueStack.empty())
615 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
616 // Counter tracks the exit block of the switch.
617 setCount(PGO.getRegionCount(S));
618 RecordNextStmtCount = true;
621 void VisitSwitchCase(const SwitchCase *S) {
622 RecordNextStmtCount = false;
623 // Counter for this particular case. This counts only jumps from the
624 // switch header and does not include fallthrough from the case before
626 uint64_t CaseCount = PGO.getRegionCount(S);
627 setCount(CurrentCount + CaseCount);
628 // We need the count without fallthrough in the mapping, so it's more useful
629 // for branch probabilities.
630 CountMap[S] = CaseCount;
631 RecordNextStmtCount = true;
632 Visit(S->getSubStmt());
635 void VisitIfStmt(const IfStmt *S) {
637 uint64_t ParentCount = CurrentCount;
642 // Counter tracks the "then" part of an if statement. The count for
643 // the "else" part, if it exists, will be calculated from this counter.
644 uint64_t ThenCount = setCount(PGO.getRegionCount(S));
645 CountMap[S->getThen()] = ThenCount;
647 uint64_t OutCount = CurrentCount;
// Else count is the parent count minus the taken (then) count.
649 uint64_t ElseCount = ParentCount - ThenCount;
652 CountMap[S->getElse()] = ElseCount;
654 OutCount += CurrentCount;
656 OutCount += ElseCount;
658 RecordNextStmtCount = true;
661 void VisitCXXTryStmt(const CXXTryStmt *S) {
663 Visit(S->getTryBlock());
664 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
665 Visit(S->getHandler(I));
666 // Counter tracks the continuation block of the try statement.
667 setCount(PGO.getRegionCount(S));
668 RecordNextStmtCount = true;
671 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
672 RecordNextStmtCount = false;
673 // Counter tracks the catch statement's handler block.
674 uint64_t CatchCount = setCount(PGO.getRegionCount(S));
675 CountMap[S] = CatchCount;
676 Visit(S->getHandlerBlock());
679 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
681 uint64_t ParentCount = CurrentCount;
684 // Counter tracks the "true" part of a conditional operator. The
685 // count in the "false" part will be calculated from this counter.
686 uint64_t TrueCount = setCount(PGO.getRegionCount(E));
687 CountMap[E->getTrueExpr()] = TrueCount;
688 Visit(E->getTrueExpr());
689 uint64_t OutCount = CurrentCount;
691 uint64_t FalseCount = setCount(ParentCount - TrueCount);
692 CountMap[E->getFalseExpr()] = FalseCount;
693 Visit(E->getFalseExpr());
694 OutCount += CurrentCount;
697 RecordNextStmtCount = true;
700 void VisitBinLAnd(const BinaryOperator *E) {
702 uint64_t ParentCount = CurrentCount;
704 // Counter tracks the right hand side of a logical and operator.
705 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
706 CountMap[E->getRHS()] = RHSCount;
// Out count = short-circuit edge (Parent - RHS) + RHS exit = the sum below.
708 setCount(ParentCount + RHSCount - CurrentCount);
709 RecordNextStmtCount = true;
712 void VisitBinLOr(const BinaryOperator *E) {
714 uint64_t ParentCount = CurrentCount;
716 // Counter tracks the right hand side of a logical or operator.
717 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
718 CountMap[E->getRHS()] = RHSCount;
720 setCount(ParentCount + RHSCount - CurrentCount);
721 RecordNextStmtCount = true;
724 } // end anonymous namespace
// Fold one 6-bit HashType into the running hash. Types are packed into the
// 64-bit Working accumulator; once a full word's worth has accumulated, the
// word is byte-swapped to little-endian and streamed into MD5 so the hash is
// endian-stable.
// NOTE(review): the lines resetting Working and incrementing Count (original
// lines ~736-740, 742-743) are elided in this excerpt.
726 void PGOHash::combine(HashType Type) {
727 // Check that we never combine 0 and only have six bits.
728 assert(Type && "Hash is invalid: unexpected type 0");
729 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
731 // Pass through MD5 if enough work has built up.
732 if (Count && Count % NumTypesPerWord == 0) {
733 using namespace llvm::support;
734 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
735 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
739 // Accumulate the current type.
741 Working = Working << NumBitsPerType | Type;
// Produce the final 64-bit function hash. Small inputs (at most one packed
// word) use Working directly; otherwise the remaining partial word is flushed
// into MD5 and the digest is truncated to 64 bits.
// NOTE(review): the `return Working;` for the small case, the flush of the
// V3 branch, and the final truncating return (original lines ~750, 758,
// 762-764, 767-768) are elided in this excerpt.
744 uint64_t PGOHash::finalize() {
745 // Use Working as the hash directly if we never used MD5.
746 if (Count <= NumTypesPerWord)
747 // No need to byte swap here, since none of the math was endian-dependent.
748 // This number will be byte-swapped as required on endianness transitions,
749 // so we will see the same value on the other side.
752 // Check for remaining work in Working.
754 // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
755 // is buggy because it converts a uint64_t into an array of uint8_t.
756 if (HashVersion < PGO_HASH_V3) {
// Intentionally preserved bug: initializer-list narrows Working to ONE byte,
// so only the low 8 bits of the trailing word are hashed in V1/V2.
757 MD5.update({(uint8_t)Working});
759 using namespace llvm::support;
760 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
761 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
765 // Finalize the MD5 and return the hash.
766 llvm::MD5::MD5Result Result;
// Entry point run once per emitted function: decides whether this function
// participates in PGO (instrumentation and/or profile use), then builds the
// counter map, emits coverage mapping if requested, and when reading a
// profile, loads counts and applies them to the IR function.
// NOTE(review): several lines are elided in this excerpt, including the
// early `return` bodies after the guard conditions (original lines 773-775,
// 779-781, 789-793, 795-796, 800, 805-806).
771 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
772 const Decl *D = GD.getDecl();
776 bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
777 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
778 if (!InstrumentRegions && !PGOReader)
782 // Constructors and destructors may be represented by several functions in IR.
783 // If so, instrument only base variant, others are implemented by delegation
784 // to the base one, it would be counted twice otherwise.
785 if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
786 if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
787 if (GD.getCtorType() != Ctor_Base &&
788 CodeGenFunction::IsConstructorDelegationValid(CCD))
791 if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
794 CGM.ClearUnusedCoverageMapping(D);
797 mapRegionCounters(D);
798 if (CGM.getCodeGenOpts().CoverageMapping)
799 emitCounterRegionMapping(D);
801 SourceManager &SM = CGM.getContext().getSourceManager();
802 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
803 computeRegionCounts(D);
804 applyFunctionAttributes(PGOReader, Fn);
// Build RegionCounterMap for \p D by running MapRegionCounters over the
// declaration's AST, and compute the function hash. The hash version matches
// the profile reader when one exists (so hashes compare equal against
// on-disk data) and is otherwise the latest version.
808 void CodeGenPGO::mapRegionCounters(const Decl *D) {
809 // Use the latest hash version when inserting instrumentation, but use the
810 // version in the indexed profile if we're reading PGO data.
811 PGOHashVersion HashVersion = PGO_HASH_LATEST;
812 if (auto *PGOReader = CGM.getPGOReader())
813 HashVersion = getPGOHashVersion(PGOReader, CGM);
815 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
816 MapRegionCounters Walker(HashVersion, *RegionCounterMap);
// Dispatch on the concrete declaration kind; const_cast is needed because
// RecursiveASTVisitor traverses non-const nodes.
817 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
818 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
819 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
820 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
821 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
822 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
823 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
824 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
825 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
826 NumRegionCounters = Walker.NextCounter;
827 FunctionHash = Walker.Hash.finalize();
// Return true if coverage mapping should be skipped for \p D.
// NOTE(review): the leading checks (original lines 831-833, presumably a
// null-body guard) and the closing brace are elided in this excerpt; what
// remains skips declarations whose bodies live in system headers.
830 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
834 // Don't map the functions in system headers.
835 const auto &SM = CGM.getContext().getSourceManager();
836 auto Loc = D->getBody()->getBeginLoc();
837 return SM.isInSystemHeader(Loc);
// Generate the source-based coverage mapping for \p D (serialized into a
// string by CoverageMappingGen) and register it, keyed by the function name
// variable and hash, with the module's coverage mapping.
840 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
841 if (skipRegionMappingForDecl(D))
844 std::string CoverageMapping;
845 llvm::raw_string_ostream OS(CoverageMapping);
846 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
847 CGM.getContext().getSourceManager(),
848 CGM.getLangOpts(), RegionCounterMap.get());
849 MappingGen.emitCounterMapping(D, OS);
// Nothing to register if the generator produced no mapping regions.
852 if (CoverageMapping.empty())
855 CGM.getCoverageMapping()->addFunctionMappingRecord(
856 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
// Emit an "empty" (zero-count) coverage mapping record for an unused
// declaration, so unexercised functions still appear in coverage reports.
// Passes `false` as the final argument (unlike emitCounterRegionMapping) —
// presumably marking the record as not backed by a real function body.
// NOTE(review): the return type line (original line 859) is elided.
860 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
861 llvm::GlobalValue::LinkageTypes Linkage) {
862 if (skipRegionMappingForDecl(D))
865 std::string CoverageMapping;
866 llvm::raw_string_ostream OS(CoverageMapping);
867 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
868 CGM.getContext().getSourceManager(),
870 MappingGen.emitEmptyMapping(D, OS);
873 if (CoverageMapping.empty())
// Set the function name here since unused declarations never went through
// assignRegionCounters.
876 setFuncName(Name, Linkage);
877 CGM.getCoverageMapping()->addFunctionMappingRecord(
878 FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
// Derive per-statement execution counts from the loaded region counters by
// running ComputeRegionCounts over \p D; results populate StmtCountMap.
881 void CodeGenPGO::computeRegionCounts(const Decl *D) {
882 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
883 ComputeRegionCounts Walker(*StmtCountMap, *this);
884 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
885 Walker.VisitFunctionDecl(FD);
886 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
887 Walker.VisitObjCMethodDecl(MD);
888 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
889 Walker.VisitBlockDecl(BD);
890 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
891 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
// When profile counts are available, set the IR function's entry count from
// the region count of the function body (getRegionCount(nullptr) is the
// entry counter).
// NOTE(review): the return type line (original line 893) is elided.
894 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
895 llvm::Function *Fn) {
896 if (!haveRegionCounts())
900 uint64_t FunctionCount = getRegionCount(nullptr);
901 Fn->setEntryCount(FunctionCount);
// Emit an llvm.instrprof.increment (or, with a non-null \p StepV, an
// instrprof.increment.step) call for the counter assigned to \p S. No-op
// unless Clang instrumentation is enabled, a counter map exists, and the
// builder has a valid insertion point.
// NOTE(review): the if/else selecting between the two intrinsics (original
// lines 918, 921-922, 924-925) is partially elided: the 4-argument call on
// line 919-920 is the no-step form, and the instrprof_increment_step
// fragment on line 923 presumably passes all five Args.
904 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
905 llvm::Value *StepV) {
906 if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
908 if (!Builder.GetInsertBlock())
911 unsigned Counter = (*RegionCounterMap)[S];
912 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
914 llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
915 Builder.getInt64(FunctionHash),
916 Builder.getInt32(NumRegionCounters),
917 Builder.getInt32(Counter), StepV};
919 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
920 makeArrayRef(Args, 4));
923 CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
927 // This method either inserts a call to the profile run-time during
928 // instrumentation or puts profile data into metadata for PGO use.
// Gated on the -enable-value-profiling flag. Skips constants (their value is
// statically known) and sites without a valid insertion point. The per-kind
// NumValueSites counter is advanced on both the instrumentation and the
// profile-use paths so value sites stay aligned with the profile record.
929 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
930 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
932 if (!EnableValueProfiling)
935 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
938 if (isa<llvm::Constant>(ValuePtr))
941 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
942 if (InstrumentValueSites && RegionCounterMap) {
// Instrumentation path: emit llvm.instrprof.value.profile at the value
// site, preserving the builder's insertion point around the detour.
943 auto BuilderInsertPoint = Builder.saveIP();
944 Builder.SetInsertPoint(ValueSite);
945 llvm::Value *Args[5] = {
946 llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
947 Builder.getInt64(FunctionHash),
948 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
949 Builder.getInt32(ValueKind),
950 Builder.getInt32(NumValueSites[ValueKind]++)
953 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
954 Builder.restoreIP(BuilderInsertPoint);
// Profile-use path: annotate the value site with metadata from the loaded
// profile record.
958 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
959 if (PGOReader && haveRegionCounts()) {
960 // We record the top most called three functions at each call site.
961 // Profile metadata contains "VP" string identifying this metadata
962 // as value profiling data, then a uint32_t value for the value profiling
963 // kind, a uint64_t value for the total number of times the call is
964 // executed, followed by the function hash and execution count (uint64_t)
965 // pairs for each function.
966 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
969 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
970 (llvm::InstrProfValueKind)ValueKind,
971 NumValueSites[ValueKind]);
973 NumValueSites[ValueKind]++;
// Look up this function's record in the indexed profile by (name, hash) and
// load its counters into RegionCounts. Lookup failures are not fatal: they
// are tallied into PGO statistics (missing / mismatched) and the function
// simply proceeds without counts.
// NOTE(review): the early `return` after error handling and the ProfRecord
// assignment line (original lines ~992-994) are elided in this excerpt.
977 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
979 CGM.getPGOStats().addVisited(IsInMainFile);
980 RegionCounts.clear();
981 llvm::Expected<llvm::InstrProfRecord> RecordExpected =
982 PGOReader->getInstrProfRecord(FuncName, FunctionHash);
983 if (auto E = RecordExpected.takeError()) {
984 auto IPE = llvm::InstrProfError::take(std::move(E));
985 if (IPE == llvm::instrprof_error::unknown_function)
986 CGM.getPGOStats().addMissing(IsInMainFile);
987 else if (IPE == llvm::instrprof_error::hash_mismatch)
988 CGM.getPGOStats().addMismatched(IsInMainFile);
989 else if (IPE == llvm::instrprof_error::malformed)
990 // TODO: Consider a more specific warning for this case.
991 CGM.getPGOStats().addMismatched(IsInMainFile);
995 std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
996 RegionCounts = ProfRecord->Counts;
999 /// Calculate what to divide by to scale weights.
1001 /// Given the maximum weight, calculate a divisor that will scale all the
1002 /// weights to strictly less than UINT32_MAX.
// Returns 1 (no scaling) when everything already fits in 32 bits; the +1 in
// the divisor guarantees MaxWeight / Scale < UINT32_MAX.
1003 static uint64_t calculateWeightScale(uint64_t MaxWeight) {
1004 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
1007 /// Scale an individual branch weight (and add 1).
1009 /// Scale a 64-bit weight down to 32-bits using \c Scale.
1011 /// According to Laplace's Rule of Succession, it is better to compute the
1012 /// weight based on the count plus 1, so universally add 1 to the value.
1014 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
1015 /// greater than \c Weight.
// NOTE(review): the `return Scaled;` and closing brace (original lines
// 1020-1021) are elided in this excerpt.
1016 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
1017 assert(Scale && "scale by 0?");
1018 uint64_t Scaled = Weight / Scale + 1;
1019 assert(Scaled <= UINT32_MAX && "overflow 32-bits");
// Build !prof branch_weights metadata for a two-way branch from raw true /
// false execution counts, scaling both into 32-bit range with a common
// divisor so their ratio is preserved.
// NOTE(review): the `return nullptr;` for the all-zero case (original lines
// 1027-1028) is elided in this excerpt.
1023 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1024 uint64_t FalseCount) {
1025 // Check for empty weights.
1026 if (!TrueCount && !FalseCount)
1029 // Calculate how to scale down to 32-bits.
1030 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1032 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1033 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1034 scaleBranchWeight(FalseCount, Scale));
// N-way variant (e.g. switch terminators): scale every weight by a common
// divisor derived from the maximum, then build branch_weights metadata.
// NOTE(review): the return type line (original line 1037) and the early
// `return nullptr;` bodies (lines 1041-1042, 1045-1047) are elided in this
// excerpt.
1038 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
1039 // We need at least two elements to create meaningful weights.
1040 if (Weights.size() < 2)
1043 // Check for empty weights.
1044 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1048 // Calculate how to scale down to 32-bits.
1049 uint64_t Scale = calculateWeightScale(MaxWeight);
1051 SmallVector<uint32_t, 16> ScaledWeights;
1052 ScaledWeights.reserve(Weights.size());
1053 for (uint64_t W : Weights)
1054 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1056 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1057 return MDHelper.createBranchWeights(ScaledWeights);
// Build branch weights for a loop condition: the taken weight is LoopCount
// (backedge executions) and the exit weight is the condition count minus the
// loop count, clamped via std::max so the subtraction cannot underflow when
// counts are slightly inconsistent.
// NOTE(review): the early `return nullptr;` bodies after the two guards
// (original lines 1063, 1066) are elided in this excerpt.
1060 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1061 uint64_t LoopCount) {
1062 if (!PGO.haveRegionCounts())
1064 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1065 if (!CondCount || *CondCount == 0)
1067 return createProfileWeights(LoopCount,
1068 std::max(*CondCount, LoopCount) - LoopCount);