1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // Instrumentation-based profile-guided optimization
12 //===----------------------------------------------------------------------===//
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
25 static llvm::cl::opt<bool> EnableValueProfiling(
26 "enable-value-profiling", llvm::cl::ZeroOrMore,
27 llvm::cl::desc("Enable value profiling"), llvm::cl::init(false));
29 using namespace clang;
30 using namespace CodeGen;
32 void CodeGenPGO::setFuncName(StringRef Name,
33 llvm::GlobalValue::LinkageTypes Linkage) {
34 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
35 FuncName = llvm::getPGOFuncName(
36 Name, Linkage, CGM.getCodeGenOpts().MainFileName,
37 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
39 // If we're generating a profile, create a variable for the name.
40 if (CGM.getCodeGenOpts().hasProfileClangInstr())
41 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
44 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
45 setFuncName(Fn->getName(), Fn->getLinkage());
46 // Create PGOFuncName meta data.
47 llvm::createPGOFuncNameMetadata(*Fn, FuncName);
51 /// \brief Stable hasher for PGO region counters.
53 /// PGOHash produces a stable hash of a given function's control flow.
55 /// Changing the output of this hash will invalidate all previously generated
56 /// profiles -- i.e., don't do it.
58 /// \note When this hash does eventually change (years?), we still need to
59 /// support old hashes. We'll need to pull in the version number from the
60 /// profile data format and use the matching hash function.
66 static const int NumBitsPerType = 6;
67 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
68 static const unsigned TooBig = 1u << NumBitsPerType;
71 /// \brief Hash values for AST nodes.
73 /// Distinct values for AST nodes that have region counters attached.
75 /// These values must be stable. All new members must be added at the end,
76 /// and no members should be removed. Changing the enumeration value for an
77 /// AST node will affect the hash of every function that contains that node.
78 enum HashType : unsigned char {
85 ObjCForCollectionStmt,
95 BinaryConditionalOperator,
97 // Keep this last. It's for the static assert that follows.
100 static_assert(LastHashType <= TooBig, "Too many types in HashType");
102 // TODO: When this format changes, take in a version number here, and use the
103 // old hash calculation for file formats that used the old hash.
104 PGOHash() : Working(0), Count(0) {}
105 void combine(HashType Type);
108 const int PGOHash::NumBitsPerType;
109 const unsigned PGOHash::NumTypesPerWord;
110 const unsigned PGOHash::TooBig;
112 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
113 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
114 /// The next counter value to assign.
115 unsigned NextCounter;
116 /// The function hash.
118 /// The map of statements to counters.
119 llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
121 MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
122 : NextCounter(0), CounterMap(CounterMap) {}
124 // Blocks and lambdas are handled as separate functions, so we need not
125 // traverse them in the parent context.
126 bool TraverseBlockExpr(BlockExpr *BE) { return true; }
127 bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
128 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
130 bool VisitDecl(const Decl *D) {
131 switch (D->getKind()) {
135 case Decl::CXXMethod:
136 case Decl::CXXConstructor:
137 case Decl::CXXDestructor:
138 case Decl::CXXConversion:
139 case Decl::ObjCMethod:
142 CounterMap[D->getBody()] = NextCounter++;
148 bool VisitStmt(const Stmt *S) {
149 auto Type = getHashType(S);
150 if (Type == PGOHash::None)
153 CounterMap[S] = NextCounter++;
157 PGOHash::HashType getHashType(const Stmt *S) {
158 switch (S->getStmtClass()) {
161 case Stmt::LabelStmtClass:
162 return PGOHash::LabelStmt;
163 case Stmt::WhileStmtClass:
164 return PGOHash::WhileStmt;
165 case Stmt::DoStmtClass:
166 return PGOHash::DoStmt;
167 case Stmt::ForStmtClass:
168 return PGOHash::ForStmt;
169 case Stmt::CXXForRangeStmtClass:
170 return PGOHash::CXXForRangeStmt;
171 case Stmt::ObjCForCollectionStmtClass:
172 return PGOHash::ObjCForCollectionStmt;
173 case Stmt::SwitchStmtClass:
174 return PGOHash::SwitchStmt;
175 case Stmt::CaseStmtClass:
176 return PGOHash::CaseStmt;
177 case Stmt::DefaultStmtClass:
178 return PGOHash::DefaultStmt;
179 case Stmt::IfStmtClass:
180 return PGOHash::IfStmt;
181 case Stmt::CXXTryStmtClass:
182 return PGOHash::CXXTryStmt;
183 case Stmt::CXXCatchStmtClass:
184 return PGOHash::CXXCatchStmt;
185 case Stmt::ConditionalOperatorClass:
186 return PGOHash::ConditionalOperator;
187 case Stmt::BinaryConditionalOperatorClass:
188 return PGOHash::BinaryConditionalOperator;
189 case Stmt::BinaryOperatorClass: {
190 const BinaryOperator *BO = cast<BinaryOperator>(S);
191 if (BO->getOpcode() == BO_LAnd)
192 return PGOHash::BinaryOperatorLAnd;
193 if (BO->getOpcode() == BO_LOr)
194 return PGOHash::BinaryOperatorLOr;
198 return PGOHash::None;
202 /// A StmtVisitor that propagates the raw counts through the AST and
203 /// records the count at statements where the value may change.
204 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
208 /// A flag that is set when the current count should be recorded on the
209 /// next statement, such as at the exit of a loop.
210 bool RecordNextStmtCount;
212 /// The count at the current location in the traversal.
213 uint64_t CurrentCount;
215 /// The map of statements to count values.
216 llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
218 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
219 struct BreakContinue {
221 uint64_t ContinueCount;
222 BreakContinue() : BreakCount(0), ContinueCount(0) {}
224 SmallVector<BreakContinue, 8> BreakContinueStack;
226 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
228 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
230 void RecordStmtCount(const Stmt *S) {
231 if (RecordNextStmtCount) {
232 CountMap[S] = CurrentCount;
233 RecordNextStmtCount = false;
237 /// Set and return the current count.
238 uint64_t setCount(uint64_t Count) {
239 CurrentCount = Count;
243 void VisitStmt(const Stmt *S) {
245 for (const Stmt *Child : S->children())
250 void VisitFunctionDecl(const FunctionDecl *D) {
251 // Counter tracks entry to the function body.
252 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
253 CountMap[D->getBody()] = BodyCount;
257 // Skip lambda expressions. We visit these as FunctionDecls when we're
258 // generating them and aren't interested in the body when generating a
260 void VisitLambdaExpr(const LambdaExpr *LE) {}
262 void VisitCapturedDecl(const CapturedDecl *D) {
263 // Counter tracks entry to the capture body.
264 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
265 CountMap[D->getBody()] = BodyCount;
269 void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
270 // Counter tracks entry to the method body.
271 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
272 CountMap[D->getBody()] = BodyCount;
276 void VisitBlockDecl(const BlockDecl *D) {
277 // Counter tracks entry to the block body.
278 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
279 CountMap[D->getBody()] = BodyCount;
283 void VisitReturnStmt(const ReturnStmt *S) {
285 if (S->getRetValue())
286 Visit(S->getRetValue());
288 RecordNextStmtCount = true;
291 void VisitCXXThrowExpr(const CXXThrowExpr *E) {
294 Visit(E->getSubExpr());
296 RecordNextStmtCount = true;
299 void VisitGotoStmt(const GotoStmt *S) {
302 RecordNextStmtCount = true;
305 void VisitLabelStmt(const LabelStmt *S) {
306 RecordNextStmtCount = false;
307 // Counter tracks the block following the label.
308 uint64_t BlockCount = setCount(PGO.getRegionCount(S));
309 CountMap[S] = BlockCount;
310 Visit(S->getSubStmt());
313 void VisitBreakStmt(const BreakStmt *S) {
315 assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
316 BreakContinueStack.back().BreakCount += CurrentCount;
318 RecordNextStmtCount = true;
321 void VisitContinueStmt(const ContinueStmt *S) {
323 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
324 BreakContinueStack.back().ContinueCount += CurrentCount;
326 RecordNextStmtCount = true;
329 void VisitWhileStmt(const WhileStmt *S) {
331 uint64_t ParentCount = CurrentCount;
333 BreakContinueStack.push_back(BreakContinue());
334 // Visit the body region first so the break/continue adjustments can be
335 // included when visiting the condition.
336 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
337 CountMap[S->getBody()] = CurrentCount;
339 uint64_t BackedgeCount = CurrentCount;
341 // ...then go back and propagate counts through the condition. The count
342 // at the start of the condition is the sum of the incoming edges,
343 // the backedge from the end of the loop body, and the edges from
344 // continue statements.
345 BreakContinue BC = BreakContinueStack.pop_back_val();
347 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
348 CountMap[S->getCond()] = CondCount;
350 setCount(BC.BreakCount + CondCount - BodyCount);
351 RecordNextStmtCount = true;
354 void VisitDoStmt(const DoStmt *S) {
356 uint64_t LoopCount = PGO.getRegionCount(S);
358 BreakContinueStack.push_back(BreakContinue());
359 // The count doesn't include the fallthrough from the parent scope. Add it.
360 uint64_t BodyCount = setCount(LoopCount + CurrentCount);
361 CountMap[S->getBody()] = BodyCount;
363 uint64_t BackedgeCount = CurrentCount;
365 BreakContinue BC = BreakContinueStack.pop_back_val();
366 // The count at the start of the condition is equal to the count at the
367 // end of the body, plus any continues.
368 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
369 CountMap[S->getCond()] = CondCount;
371 setCount(BC.BreakCount + CondCount - LoopCount);
372 RecordNextStmtCount = true;
375 void VisitForStmt(const ForStmt *S) {
380 uint64_t ParentCount = CurrentCount;
382 BreakContinueStack.push_back(BreakContinue());
383 // Visit the body region first. (This is basically the same as a while
384 // loop; see further comments in VisitWhileStmt.)
385 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
386 CountMap[S->getBody()] = BodyCount;
388 uint64_t BackedgeCount = CurrentCount;
389 BreakContinue BC = BreakContinueStack.pop_back_val();
391 // The increment is essentially part of the body but it needs to include
392 // the count for all the continue statements.
394 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
395 CountMap[S->getInc()] = IncCount;
399 // ...then go back and propagate counts through the condition.
401 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
403 CountMap[S->getCond()] = CondCount;
406 setCount(BC.BreakCount + CondCount - BodyCount);
407 RecordNextStmtCount = true;
410 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
412 Visit(S->getLoopVarStmt());
413 Visit(S->getRangeStmt());
414 Visit(S->getBeginStmt());
415 Visit(S->getEndStmt());
417 uint64_t ParentCount = CurrentCount;
418 BreakContinueStack.push_back(BreakContinue());
419 // Visit the body region first. (This is basically the same as a while
420 // loop; see further comments in VisitWhileStmt.)
421 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
422 CountMap[S->getBody()] = BodyCount;
424 uint64_t BackedgeCount = CurrentCount;
425 BreakContinue BC = BreakContinueStack.pop_back_val();
427 // The increment is essentially part of the body but it needs to include
428 // the count for all the continue statements.
429 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
430 CountMap[S->getInc()] = IncCount;
433 // ...then go back and propagate counts through the condition.
435 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
436 CountMap[S->getCond()] = CondCount;
438 setCount(BC.BreakCount + CondCount - BodyCount);
439 RecordNextStmtCount = true;
442 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
444 Visit(S->getElement());
445 uint64_t ParentCount = CurrentCount;
446 BreakContinueStack.push_back(BreakContinue());
447 // Counter tracks the body of the loop.
448 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
449 CountMap[S->getBody()] = BodyCount;
451 uint64_t BackedgeCount = CurrentCount;
452 BreakContinue BC = BreakContinueStack.pop_back_val();
454 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
456 RecordNextStmtCount = true;
459 void VisitSwitchStmt(const SwitchStmt *S) {
465 BreakContinueStack.push_back(BreakContinue());
467 // If the switch is inside a loop, add the continue counts.
468 BreakContinue BC = BreakContinueStack.pop_back_val();
469 if (!BreakContinueStack.empty())
470 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
471 // Counter tracks the exit block of the switch.
472 setCount(PGO.getRegionCount(S));
473 RecordNextStmtCount = true;
476 void VisitSwitchCase(const SwitchCase *S) {
477 RecordNextStmtCount = false;
478 // Counter for this particular case. This counts only jumps from the
479 // switch header and does not include fallthrough from the case before
481 uint64_t CaseCount = PGO.getRegionCount(S);
482 setCount(CurrentCount + CaseCount);
483 // We need the count without fallthrough in the mapping, so it's more useful
484 // for branch probabilities.
485 CountMap[S] = CaseCount;
486 RecordNextStmtCount = true;
487 Visit(S->getSubStmt());
490 void VisitIfStmt(const IfStmt *S) {
492 uint64_t ParentCount = CurrentCount;
497 // Counter tracks the "then" part of an if statement. The count for
498 // the "else" part, if it exists, will be calculated from this counter.
499 uint64_t ThenCount = setCount(PGO.getRegionCount(S));
500 CountMap[S->getThen()] = ThenCount;
502 uint64_t OutCount = CurrentCount;
504 uint64_t ElseCount = ParentCount - ThenCount;
507 CountMap[S->getElse()] = ElseCount;
509 OutCount += CurrentCount;
511 OutCount += ElseCount;
513 RecordNextStmtCount = true;
516 void VisitCXXTryStmt(const CXXTryStmt *S) {
518 Visit(S->getTryBlock());
519 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
520 Visit(S->getHandler(I));
521 // Counter tracks the continuation block of the try statement.
522 setCount(PGO.getRegionCount(S));
523 RecordNextStmtCount = true;
526 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
527 RecordNextStmtCount = false;
528 // Counter tracks the catch statement's handler block.
529 uint64_t CatchCount = setCount(PGO.getRegionCount(S));
530 CountMap[S] = CatchCount;
531 Visit(S->getHandlerBlock());
534 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
536 uint64_t ParentCount = CurrentCount;
539 // Counter tracks the "true" part of a conditional operator. The
540 // count in the "false" part will be calculated from this counter.
541 uint64_t TrueCount = setCount(PGO.getRegionCount(E));
542 CountMap[E->getTrueExpr()] = TrueCount;
543 Visit(E->getTrueExpr());
544 uint64_t OutCount = CurrentCount;
546 uint64_t FalseCount = setCount(ParentCount - TrueCount);
547 CountMap[E->getFalseExpr()] = FalseCount;
548 Visit(E->getFalseExpr());
549 OutCount += CurrentCount;
552 RecordNextStmtCount = true;
555 void VisitBinLAnd(const BinaryOperator *E) {
557 uint64_t ParentCount = CurrentCount;
559 // Counter tracks the right hand side of a logical and operator.
560 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
561 CountMap[E->getRHS()] = RHSCount;
563 setCount(ParentCount + RHSCount - CurrentCount);
564 RecordNextStmtCount = true;
567 void VisitBinLOr(const BinaryOperator *E) {
569 uint64_t ParentCount = CurrentCount;
571 // Counter tracks the right hand side of a logical or operator.
572 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
573 CountMap[E->getRHS()] = RHSCount;
575 setCount(ParentCount + RHSCount - CurrentCount);
576 RecordNextStmtCount = true;
579 } // end anonymous namespace
581 void PGOHash::combine(HashType Type) {
582 // Check that we never combine 0 and only have six bits.
583 assert(Type && "Hash is invalid: unexpected type 0");
584 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
586 // Pass through MD5 if enough work has built up.
587 if (Count && Count % NumTypesPerWord == 0) {
588 using namespace llvm::support;
589 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
590 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
594 // Accumulate the current type.
596 Working = Working << NumBitsPerType | Type;
599 uint64_t PGOHash::finalize() {
600 // Use Working as the hash directly if we never used MD5.
601 if (Count <= NumTypesPerWord)
602 // No need to byte swap here, since none of the math was endian-dependent.
603 // This number will be byte-swapped as required on endianness transitions,
604 // so we will see the same value on the other side.
607 // Check for remaining work in Working.
611 // Finalize the MD5 and return the hash.
612 llvm::MD5::MD5Result Result;
614 using namespace llvm::support;
618 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
619 const Decl *D = GD.getDecl();
620 bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
621 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
622 if (!InstrumentRegions && !PGOReader)
626 // Constructors and destructors may be represented by several functions in IR.
627 // If so, instrument only the base variant: the others are implemented by
628 // delegating to the base one, so they would be counted twice otherwise.
629 if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
630 if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
633 if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
634 if (GD.getCtorType() != Ctor_Base &&
635 CodeGenFunction::IsConstructorDelegationValid(CCD))
638 CGM.ClearUnusedCoverageMapping(D);
641 mapRegionCounters(D);
642 if (CGM.getCodeGenOpts().CoverageMapping)
643 emitCounterRegionMapping(D);
645 SourceManager &SM = CGM.getContext().getSourceManager();
646 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
647 computeRegionCounts(D);
648 applyFunctionAttributes(PGOReader, Fn);
652 void CodeGenPGO::mapRegionCounters(const Decl *D) {
653 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
654 MapRegionCounters Walker(*RegionCounterMap);
655 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
656 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
657 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
658 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
659 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
660 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
661 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
662 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
663 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
664 NumRegionCounters = Walker.NextCounter;
665 FunctionHash = Walker.Hash.finalize();
668 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
672 // Don't map the functions in system headers.
673 const auto &SM = CGM.getContext().getSourceManager();
674 auto Loc = D->getBody()->getLocStart();
675 return SM.isInSystemHeader(Loc);
678 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
679 if (skipRegionMappingForDecl(D))
682 std::string CoverageMapping;
683 llvm::raw_string_ostream OS(CoverageMapping);
684 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
685 CGM.getContext().getSourceManager(),
686 CGM.getLangOpts(), RegionCounterMap.get());
687 MappingGen.emitCounterMapping(D, OS);
690 if (CoverageMapping.empty())
693 CGM.getCoverageMapping()->addFunctionMappingRecord(
694 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
698 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
699 llvm::GlobalValue::LinkageTypes Linkage) {
700 if (skipRegionMappingForDecl(D))
703 std::string CoverageMapping;
704 llvm::raw_string_ostream OS(CoverageMapping);
705 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
706 CGM.getContext().getSourceManager(),
708 MappingGen.emitEmptyMapping(D, OS);
711 if (CoverageMapping.empty())
714 setFuncName(Name, Linkage);
715 CGM.getCoverageMapping()->addFunctionMappingRecord(
716 FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
719 void CodeGenPGO::computeRegionCounts(const Decl *D) {
720 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
721 ComputeRegionCounts Walker(*StmtCountMap, *this);
722 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
723 Walker.VisitFunctionDecl(FD);
724 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
725 Walker.VisitObjCMethodDecl(MD);
726 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
727 Walker.VisitBlockDecl(BD);
728 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
729 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
733 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
734 llvm::Function *Fn) {
735 if (!haveRegionCounts())
738 uint64_t FunctionCount = getRegionCount(nullptr);
739 Fn->setEntryCount(FunctionCount);
742 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
743 llvm::Value *StepV) {
744 if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
746 if (!Builder.GetInsertBlock())
749 unsigned Counter = (*RegionCounterMap)[S];
750 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
752 llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
753 Builder.getInt64(FunctionHash),
754 Builder.getInt32(NumRegionCounters),
755 Builder.getInt32(Counter), StepV};
757 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
758 makeArrayRef(Args, 4));
761 CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
765 // This method either inserts a call to the profile run-time during
766 // instrumentation or puts profile data into metadata for PGO use.
767 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
768 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
770 if (!EnableValueProfiling)
773 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
776 if (isa<llvm::Constant>(ValuePtr))
779 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
780 if (InstrumentValueSites && RegionCounterMap) {
781 auto BuilderInsertPoint = Builder.saveIP();
782 Builder.SetInsertPoint(ValueSite);
783 llvm::Value *Args[5] = {
784 llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
785 Builder.getInt64(FunctionHash),
786 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
787 Builder.getInt32(ValueKind),
788 Builder.getInt32(NumValueSites[ValueKind]++)
791 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
792 Builder.restoreIP(BuilderInsertPoint);
796 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
797 if (PGOReader && haveRegionCounts()) {
798 // We record the three most frequently called functions at each call site.
799 // Profile metadata contains "VP" string identifying this metadata
800 // as value profiling data, then a uint32_t value for the value profiling
801 // kind, a uint64_t value for the total number of times the call is
802 // executed, followed by the function hash and execution count (uint64_t)
803 // pairs for each function.
804 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
807 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
808 (llvm::InstrProfValueKind)ValueKind,
809 NumValueSites[ValueKind]);
811 NumValueSites[ValueKind]++;
815 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
817 CGM.getPGOStats().addVisited(IsInMainFile);
818 RegionCounts.clear();
819 llvm::Expected<llvm::InstrProfRecord> RecordExpected =
820 PGOReader->getInstrProfRecord(FuncName, FunctionHash);
821 if (auto E = RecordExpected.takeError()) {
822 auto IPE = llvm::InstrProfError::take(std::move(E));
823 if (IPE == llvm::instrprof_error::unknown_function)
824 CGM.getPGOStats().addMissing(IsInMainFile);
825 else if (IPE == llvm::instrprof_error::hash_mismatch)
826 CGM.getPGOStats().addMismatched(IsInMainFile);
827 else if (IPE == llvm::instrprof_error::malformed)
828 // TODO: Consider a more specific warning for this case.
829 CGM.getPGOStats().addMismatched(IsInMainFile);
833 llvm::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
834 RegionCounts = ProfRecord->Counts;
/// \brief Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
///
/// \param MaxWeight the largest weight that will be scaled.
/// \returns 1 when \p MaxWeight is already below UINT32_MAX, otherwise
/// \c MaxWeight / UINT32_MAX + 1.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
}
/// \brief Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale is non-zero and was calculated by \a calculateWeightScale()
/// with a weight no less than \c Weight; otherwise the scaled value may not
/// fit in 32 bits and the overflow assertion fires.
///
/// \returns \c Weight / \c Scale + 1, narrowed to 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // The assertion above guarantees the narrowing below is lossless.
  return static_cast<uint32_t>(Scaled);
}
861 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
862 uint64_t FalseCount) {
863 // Check for empty weights.
864 if (!TrueCount && !FalseCount)
867 // Calculate how to scale down to 32-bits.
868 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
870 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
871 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
872 scaleBranchWeight(FalseCount, Scale));
876 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
877 // We need at least two elements to create meaningful weights.
878 if (Weights.size() < 2)
881 // Check for empty weights.
882 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
886 // Calculate how to scale down to 32-bits.
887 uint64_t Scale = calculateWeightScale(MaxWeight);
889 SmallVector<uint32_t, 16> ScaledWeights;
890 ScaledWeights.reserve(Weights.size());
891 for (uint64_t W : Weights)
892 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
894 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
895 return MDHelper.createBranchWeights(ScaledWeights);
898 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
899 uint64_t LoopCount) {
900 if (!PGO.haveRegionCounts())
902 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
903 assert(CondCount.hasValue() && "missing expected loop condition count");
906 return createProfileWeights(LoopCount,
907 std::max(*CondCount, LoopCount) - LoopCount);