1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // Instrumentation-based profile-guided optimization
12 //===----------------------------------------------------------------------===//
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/ProfileData/InstrProfReader.h"
22 #include "llvm/Support/Endian.h"
23 #include "llvm/Support/FileSystem.h"
24 #include "llvm/Support/MD5.h"
26 using namespace clang;
27 using namespace CodeGen;
29 void CodeGenPGO::setFuncName(StringRef Name,
30 llvm::GlobalValue::LinkageTypes Linkage) {
31 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
32 FuncName = llvm::getPGOFuncName(
33 Name, Linkage, CGM.getCodeGenOpts().MainFileName,
34 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
36 // If we're generating a profile, create a variable for the name.
37 if (CGM.getCodeGenOpts().ProfileInstrGenerate)
38 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
41 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
42 setFuncName(Fn->getName(), Fn->getLinkage());
// NOTE(review): this listing is incomplete. The embedded line numbers jump
// (55 -> 61, 73 -> 80, 80 -> 90, 92 -> 95), so the class header, the data
// members initialized by the constructor below (Working, Count) and most
// enumerators of HashType are not visible in this view.
46 /// \brief Stable hasher for PGO region counters.
48 /// PGOHash produces a stable hash of a given function's control flow.
50 /// Changing the output of this hash will invalidate all previously generated
51 /// profiles -- i.e., don't do it.
53 /// \note When this hash does eventually change (years?), we still need to
54 /// support old hashes. We'll need to pull in the version number from the
55 /// profile data format and use the matching hash function.
// Width, in bits, of one HashType value when packed into the working word.
61 static const int NumBitsPerType = 6;
// How many NumBitsPerType-wide values fit into a 64-bit word (64 / 6 == 10).
62 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
// First value that no longer fits into NumBitsPerType bits (1 << 6 == 64).
63 static const unsigned TooBig = 1u << NumBitsPerType;
66 /// \brief Hash values for AST nodes.
68 /// Distinct values for AST nodes that have region counters attached.
70 /// These values must be stable. All new members must be added at the end,
71 /// and no members should be removed. Changing the enumeration value for an
72 /// AST node will affect the hash of every function that contains that node.
73 enum HashType : unsigned char {
80 ObjCForCollectionStmt,
90 BinaryConditionalOperator,
92 // Keep this last. It's for the static assert that follows.
// Guards the packing scheme: every enumerator must be representable within
// the TooBig bound checked again at run time in combine().
95 static_assert(LastHashType <= TooBig, "Too many types in HashType");
97 // TODO: When this format changes, take in a version number here, and use the
98 // old hash calculation for file formats that used the old hash.
// Starts with an empty working word and no combined types.
99 PGOHash() : Working(0), Count(0) {}
// Defined out of line; folds one HashType into the hash.
100 void combine(HashType Type);
// Out-of-line definitions for the PGOHash static constants that are declared
// with in-class initializers.
// NOTE(review): presumably needed because the constants are odr-used
// somewhere (required before C++17 inline variables) -- confirm.
103 const int PGOHash::NumBitsPerType;
104 const unsigned PGOHash::NumTypesPerWord;
105 const unsigned PGOHash::TooBig;
// NOTE(review): this listing is incomplete. Gaps in the embedded line numbers
// (e.g. 111 -> 113, 126 -> 130, 137 -> 143, 145 -> 148, 148 -> 152,
// 189 -> 193) mean that a member declaration, some switch labels, the return
// statements and all closing braces are not visible in this view.
107 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
108 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
109 /// The next counter value to assign.
110 unsigned NextCounter;
111 /// The function hash.
113 /// The map of statements to counters.
114 llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
// Counters are numbered from 0; the map is owned by the caller and only
// referenced here.
116 MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
117 : NextCounter(0), CounterMap(CounterMap) {}
119 // Blocks and lambdas are handled as separate functions, so we need not
120 // traverse them in the parent context.
// Each override returns true without visiting anything, so the corresponding
// subtree contributes no counters to the enclosing function.
121 bool TraverseBlockExpr(BlockExpr *BE) { return true; }
122 bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
123 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
// For the declaration kinds listed in the switch, assign the next counter to
// the declaration's body.
// NOTE(review): listing lines 127-129 and 135-136 are not visible, so the
// full set of handled kinds cannot be confirmed from here.
125 bool VisitDecl(const Decl *D) {
126 switch (D->getKind()) {
130 case Decl::CXXMethod:
131 case Decl::CXXConstructor:
132 case Decl::CXXDestructor:
133 case Decl::CXXConversion:
134 case Decl::ObjCMethod:
137 CounterMap[D->getBody()] = NextCounter++;
// Assign the next counter to a statement whose hash type is not None.
// NOTE(review): the statement guarded by the `if` and whatever follows the
// counter assignment are on listing lines that are not visible (146-147,
// 149-151); presumably the None case returns early -- confirm.
143 bool VisitStmt(const Stmt *S) {
144 auto Type = getHashType(S);
145 if (Type == PGOHash::None)
148 CounterMap[S] = NextCounter++;
// Map a statement class to the stable PGOHash::HashType value used for it.
// Binary operators are only distinguished for logical-and and logical-or.
152 PGOHash::HashType getHashType(const Stmt *S) {
153 switch (S->getStmtClass()) {
156 case Stmt::LabelStmtClass:
157 return PGOHash::LabelStmt;
158 case Stmt::WhileStmtClass:
159 return PGOHash::WhileStmt;
160 case Stmt::DoStmtClass:
161 return PGOHash::DoStmt;
162 case Stmt::ForStmtClass:
163 return PGOHash::ForStmt;
164 case Stmt::CXXForRangeStmtClass:
165 return PGOHash::CXXForRangeStmt;
166 case Stmt::ObjCForCollectionStmtClass:
167 return PGOHash::ObjCForCollectionStmt;
168 case Stmt::SwitchStmtClass:
169 return PGOHash::SwitchStmt;
170 case Stmt::CaseStmtClass:
171 return PGOHash::CaseStmt;
172 case Stmt::DefaultStmtClass:
173 return PGOHash::DefaultStmt;
174 case Stmt::IfStmtClass:
175 return PGOHash::IfStmt;
176 case Stmt::CXXTryStmtClass:
177 return PGOHash::CXXTryStmt;
178 case Stmt::CXXCatchStmtClass:
179 return PGOHash::CXXCatchStmt;
180 case Stmt::ConditionalOperatorClass:
181 return PGOHash::ConditionalOperator;
182 case Stmt::BinaryConditionalOperatorClass:
183 return PGOHash::BinaryConditionalOperator;
184 case Stmt::BinaryOperatorClass: {
185 const BinaryOperator *BO = cast<BinaryOperator>(S);
186 if (BO->getOpcode() == BO_LAnd)
187 return PGOHash::BinaryOperatorLAnd;
188 if (BO->getOpcode() == BO_LOr)
189 return PGOHash::BinaryOperatorLOr;
// Fallback for statements that carry no region counter.
193 return PGOHash::None;
// NOTE(review): this listing is incomplete. The embedded line numbers have
// many gaps, so closing braces, several member declarations (e.g. the PGO
// member used below), the second constructor parameter, and a number of
// statements inside the visit methods (calls that visit sub-statements, the
// declarations of some CondCount variables, return statements) are not
// visible. Comments added here describe only the visible statements.
197 /// A StmtVisitor that propagates the raw counts through the AST and
198 /// records the count at statements where the value may change.
199 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
203 /// A flag that is set when the current count should be recorded on the
204 /// next statement, such as at the exit of a loop.
205 bool RecordNextStmtCount;
207 /// The count at the current location in the traversal.
208 uint64_t CurrentCount;
210 /// The map of statements to count values.
211 llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
213 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
214 struct BreakContinue {
216 uint64_t ContinueCount;
217 BreakContinue() : BreakCount(0), ContinueCount(0) {}
219 SmallVector<BreakContinue, 8> BreakContinueStack;
// CurrentCount is not part of the initializer list below, so it holds no
// defined value until setCount() is called; the Visit*Decl entry points
// visible in this struct call setCount() as their first statement.
221 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
223 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
// If a preceding statement requested it, record the current count for S and
// clear the request so that it applies to one statement only.
225 void RecordStmtCount(const Stmt *S) {
226 if (RecordNextStmtCount) {
227 CountMap[S] = CurrentCount;
228 RecordNextStmtCount = false;
232 /// Set and return the current count.
233 uint64_t setCount(uint64_t Count) {
234 CurrentCount = Count;
// Generic fallback: walks the children of S.
// NOTE(review): the loop body is on a listing line that is not visible.
238 void VisitStmt(const Stmt *S) {
240 for (const Stmt *Child : S->children())
245 void VisitFunctionDecl(const FunctionDecl *D) {
246 // Counter tracks entry to the function body.
247 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
248 CountMap[D->getBody()] = BodyCount;
252 // Skip lambda expressions. We visit these as FunctionDecls when we're
253 // generating them and aren't interested in the body when generating a
255 void VisitLambdaExpr(const LambdaExpr *LE) {}
257 void VisitCapturedDecl(const CapturedDecl *D) {
258 // Counter tracks entry to the capture body.
259 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
260 CountMap[D->getBody()] = BodyCount;
264 void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
265 // Counter tracks entry to the method body.
266 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
267 CountMap[D->getBody()] = BodyCount;
271 void VisitBlockDecl(const BlockDecl *D) {
272 // Counter tracks entry to the block body.
273 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
274 CountMap[D->getBody()] = BodyCount;
// Control does not continue past a return, so the statement that follows
// gets its count recorded explicitly.
278 void VisitReturnStmt(const ReturnStmt *S) {
280 if (S->getRetValue())
281 Visit(S->getRetValue());
283 RecordNextStmtCount = true;
// Same treatment as a return: visit the operand, then ask for the count of
// the following statement to be recorded.
// NOTE(review): listing line 288 (directly before the Visit call) is not
// visible; it may guard against a missing sub-expression -- confirm.
286 void VisitCXXThrowExpr(const CXXThrowExpr *E) {
289 Visit(E->getSubExpr());
291 RecordNextStmtCount = true;
294 void VisitGotoStmt(const GotoStmt *S) {
297 RecordNextStmtCount = true;
// A label starts a new region with its own counter, so any pending request
// to record the incoming count is dropped first.
300 void VisitLabelStmt(const LabelStmt *S) {
301 RecordNextStmtCount = false;
302 // Counter tracks the block following the label.
303 uint64_t BlockCount = setCount(PGO.getRegionCount(S));
304 CountMap[S] = BlockCount;
305 Visit(S->getSubStmt());
// Add the count flowing into the break to the innermost enclosing
// loop/switch entry on the stack.
308 void VisitBreakStmt(const BreakStmt *S) {
310 assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
311 BreakContinueStack.back().BreakCount += CurrentCount;
313 RecordNextStmtCount = true;
// Add the count flowing into the continue to the innermost stack entry.
316 void VisitContinueStmt(const ContinueStmt *S) {
318 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
319 BreakContinueStack.back().ContinueCount += CurrentCount;
321 RecordNextStmtCount = true;
324 void VisitWhileStmt(const WhileStmt *S) {
326 uint64_t ParentCount = CurrentCount;
328 BreakContinueStack.push_back(BreakContinue());
329 // Visit the body region first so the break/continue adjustments can be
330 // included when visiting the condition.
331 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
// CurrentCount was just set by setCount() above, so this stores the body
// count.
332 CountMap[S->getBody()] = CurrentCount;
334 uint64_t BackedgeCount = CurrentCount;
336 // ...then go back and propagate counts through the condition. The count
337 // at the start of the condition is the sum of the incoming edges,
338 // the backedge from the end of the loop body, and the edges from
339 // continue statements.
340 BreakContinue BC = BreakContinueStack.pop_back_val();
// NOTE(review): the declaration that receives this setCount() result
// (CondCount, used below) is on listing line 341, which is not visible.
342 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
343 CountMap[S->getCond()] = CondCount;
// Count after the loop: breaks plus the condition evaluations that did not
// enter the body.
345 setCount(BC.BreakCount + CondCount - BodyCount);
346 RecordNextStmtCount = true;
349 void VisitDoStmt(const DoStmt *S) {
351 uint64_t LoopCount = PGO.getRegionCount(S);
353 BreakContinueStack.push_back(BreakContinue());
354 // The count doesn't include the fallthrough from the parent scope. Add it.
355 uint64_t BodyCount = setCount(LoopCount + CurrentCount);
356 CountMap[S->getBody()] = BodyCount;
358 uint64_t BackedgeCount = CurrentCount;
360 BreakContinue BC = BreakContinueStack.pop_back_val();
361 // The count at the start of the condition is equal to the count at the
362 // end of the body, plus any continues.
363 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
364 CountMap[S->getCond()] = CondCount;
// Count after the loop: breaks plus the condition evaluations that did not
// take the back edge (LoopCount).
366 setCount(BC.BreakCount + CondCount - LoopCount);
367 RecordNextStmtCount = true;
370 void VisitForStmt(const ForStmt *S) {
375 uint64_t ParentCount = CurrentCount;
377 BreakContinueStack.push_back(BreakContinue());
378 // Visit the body region first. (This is basically the same as a while
379 // loop; see further comments in VisitWhileStmt.)
380 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
381 CountMap[S->getBody()] = BodyCount;
383 uint64_t BackedgeCount = CurrentCount;
384 BreakContinue BC = BreakContinueStack.pop_back_val();
386 // The increment is essentially part of the body but it needs to include
387 // the count for all the continue statements.
// NOTE(review): listing line 388 is not visible; it may make the next two
// statements conditional on the loop having an increment -- confirm.
389 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
390 CountMap[S->getInc()] = IncCount;
394 // ...then go back and propagate counts through the condition.
// NOTE(review): the CondCount declaration (listing line 395) and line 397
// are not visible.
396 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
398 CountMap[S->getCond()] = CondCount;
401 setCount(BC.BreakCount + CondCount - BodyCount);
402 RecordNextStmtCount = true;
405 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
407 Visit(S->getLoopVarStmt());
408 Visit(S->getRangeStmt());
409 Visit(S->getBeginEndStmt());
411 uint64_t ParentCount = CurrentCount;
412 BreakContinueStack.push_back(BreakContinue());
413 // Visit the body region first. (This is basically the same as a while
414 // loop; see further comments in VisitWhileStmt.)
415 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
416 CountMap[S->getBody()] = BodyCount;
418 uint64_t BackedgeCount = CurrentCount;
419 BreakContinue BC = BreakContinueStack.pop_back_val();
421 // The increment is essentially part of the body but it needs to include
422 // the count for all the continue statements.
423 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
424 CountMap[S->getInc()] = IncCount;
427 // ...then go back and propagate counts through the condition.
// NOTE(review): the CondCount declaration is on listing line 428, which is
// not visible.
429 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
430 CountMap[S->getCond()] = CondCount;
432 setCount(BC.BreakCount + CondCount - BodyCount);
433 RecordNextStmtCount = true;
436 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
438 Visit(S->getElement());
439 uint64_t ParentCount = CurrentCount;
440 BreakContinueStack.push_back(BreakContinue());
441 // Counter tracks the body of the loop.
442 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
443 CountMap[S->getBody()] = BodyCount;
445 uint64_t BackedgeCount = CurrentCount;
446 BreakContinue BC = BreakContinueStack.pop_back_val();
// NOTE(review): the subtrahend of this expression is on listing line 449,
// which is not visible.
448 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
450 RecordNextStmtCount = true;
453 void VisitSwitchStmt(const SwitchStmt *S) {
457 BreakContinueStack.push_back(BreakContinue());
459 // If the switch is inside a loop, add the continue counts.
// The switch's own entry is popped first; continues seen inside the switch
// belong to the enclosing loop, if there is one.
460 BreakContinue BC = BreakContinueStack.pop_back_val();
461 if (!BreakContinueStack.empty())
462 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
463 // Counter tracks the exit block of the switch.
464 setCount(PGO.getRegionCount(S));
465 RecordNextStmtCount = true;
468 void VisitSwitchCase(const SwitchCase *S) {
469 RecordNextStmtCount = false;
470 // Counter for this particular case. This counts only jumps from the
471 // switch header and does not include fallthrough from the case before
473 uint64_t CaseCount = PGO.getRegionCount(S);
// Running count includes whatever fell through from the previous case.
474 setCount(CurrentCount + CaseCount);
475 // We need the count without fallthrough in the mapping, so it's more useful
476 // for branch probabilities.
477 CountMap[S] = CaseCount;
478 RecordNextStmtCount = true;
479 Visit(S->getSubStmt());
482 void VisitIfStmt(const IfStmt *S) {
484 uint64_t ParentCount = CurrentCount;
487 // Counter tracks the "then" part of an if statement. The count for
488 // the "else" part, if it exists, will be calculated from this counter.
489 uint64_t ThenCount = setCount(PGO.getRegionCount(S));
490 CountMap[S->getThen()] = ThenCount;
492 uint64_t OutCount = CurrentCount;
// Whatever reached the `if` and did not take the "then" branch.
494 uint64_t ElseCount = ParentCount - ThenCount;
// NOTE(review): the conditions that select between the two `OutCount +=`
// statements below are on listing lines that are not visible; presumably the
// first applies when an else branch exists and the second otherwise.
497 CountMap[S->getElse()] = ElseCount;
499 OutCount += CurrentCount;
501 OutCount += ElseCount;
503 RecordNextStmtCount = true;
506 void VisitCXXTryStmt(const CXXTryStmt *S) {
508 Visit(S->getTryBlock());
509 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
510 Visit(S->getHandler(I));
511 // Counter tracks the continuation block of the try statement.
512 setCount(PGO.getRegionCount(S));
513 RecordNextStmtCount = true;
516 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
517 RecordNextStmtCount = false;
518 // Counter tracks the catch statement's handler block.
519 uint64_t CatchCount = setCount(PGO.getRegionCount(S));
520 CountMap[S] = CatchCount;
521 Visit(S->getHandlerBlock());
524 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
526 uint64_t ParentCount = CurrentCount;
529 // Counter tracks the "true" part of a conditional operator. The
530 // count in the "false" part will be calculated from this counter.
531 uint64_t TrueCount = setCount(PGO.getRegionCount(E));
532 CountMap[E->getTrueExpr()] = TrueCount;
533 Visit(E->getTrueExpr());
534 uint64_t OutCount = CurrentCount;
536 uint64_t FalseCount = setCount(ParentCount - TrueCount);
537 CountMap[E->getFalseExpr()] = FalseCount;
538 Visit(E->getFalseExpr());
// OutCount now sums the counts leaving both arms.
539 OutCount += CurrentCount;
542 RecordNextStmtCount = true;
545 void VisitBinLAnd(const BinaryOperator *E) {
547 uint64_t ParentCount = CurrentCount;
549 // Counter tracks the right hand side of a logical and operator.
550 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
551 CountMap[E->getRHS()] = RHSCount;
// NOTE(review): the RHS is presumably visited on listing line 552 (not
// visible), which would make CurrentCount here the count leaving the RHS.
553 setCount(ParentCount + RHSCount - CurrentCount);
554 RecordNextStmtCount = true;
557 void VisitBinLOr(const BinaryOperator *E) {
559 uint64_t ParentCount = CurrentCount;
561 // Counter tracks the right hand side of a logical or operator.
562 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
563 CountMap[E->getRHS()] = RHSCount;
565 setCount(ParentCount + RHSCount - CurrentCount);
566 RecordNextStmtCount = true;
569 } // end anonymous namespace
// Fold one HashType value into the running hash.
//
// Values are packed NumBitsPerType bits at a time into Working. Once a
// non-zero multiple of NumTypesPerWord values has been combined, the working
// word is converted with endian::byte_swap<uint64_t, little> and fed to MD5
// before the new value is added.
//
// NOTE(review): listing lines 581-583, 585 and 587 are not visible here, so
// the statements that follow the MD5 update inside the `if` (presumably a
// reset of Working) and the statement just before the shift (presumably the
// Count increment) cannot be confirmed from this view.
571 void PGOHash::combine(HashType Type) {
572 // Check that we never combine 0 and only have six bits.
573 assert(Type && "Hash is invalid: unexpected type 0");
574 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
576 // Pass through MD5 if enough work has built up.
577 if (Count && Count % NumTypesPerWord == 0) {
578 using namespace llvm::support;
579 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
580 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
584 // Accumulate the current type.
586 Working = Working << NumBitsPerType | Type;
// Produce the final 64-bit hash value.
//
// When no more than NumTypesPerWord values were combined, MD5 was never used
// and the working word itself serves as the hash; otherwise the MD5 result is
// read back as a little-endian 64-bit integer.
//
// NOTE(review): listing lines 595, 598-600 and 603 are not visible here, so
// the early return, the handling of leftover bits in Working and the call
// that fills Result cannot be confirmed from this view.
589 uint64_t PGOHash::finalize() {
590 // Use Working as the hash directly if we never used MD5.
591 if (Count <= NumTypesPerWord)
592 // No need to byte swap here, since none of the math was endian-dependent.
593 // This number will be byte-swapped as required on endianness transitions,
594 // so we will see the same value on the other side.
597 // Check for remaining work in Working.
601 // Finalize the MD5 and return the hash.
602 llvm::MD5::MD5Result Result;
604 using namespace llvm::support;
605 return endian::read<uint64_t, little, unaligned>(Result);
// Entry point for per-function PGO setup: maps region counters for the
// declaration behind GD, optionally emits its coverage mapping, and loads and
// applies profile counts.
//
// NOTE(review): several listing lines are not visible here (613-615, 624-625,
// 627-628, 632, 637-639), including the early returns and the conditions
// guarding ClearUnusedCoverageMapping and the profile-use calls at the end.
// Presumably the last four calls only run when PGOReader is non-null, since
// loadRegionCounts dereferences it -- confirm.
608 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
609 const Decl *D = GD.getDecl();
610 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
611 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
// Nothing to do when neither generating nor consuming a profile.
612 if (!InstrumentRegions && !PGOReader)
616 // Constructors and destructors may be represented by several functions in IR.
617 // If so, instrument only base variant, others are implemented by delegation
618 // to the base one, it would be counted twice otherwise.
619 if (CGM.getTarget().getCXXABI().hasConstructorVariants() &&
620 ((isa<CXXConstructorDecl>(GD.getDecl()) &&
621 GD.getCtorType() != Ctor_Base) ||
622 (isa<CXXDestructorDecl>(GD.getDecl()) &&
623 GD.getDtorType() != Dtor_Base))) {
626 CGM.ClearUnusedCoverageMapping(D);
629 mapRegionCounters(D);
630 if (CGM.getCodeGenOpts().CoverageMapping)
631 emitCounterRegionMapping(D);
633 SourceManager &SM = CGM.getContext().getSourceManager();
634 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
635 computeRegionCounts(D);
636 applyFunctionAttributes(PGOReader, Fn);
640 void CodeGenPGO::mapRegionCounters(const Decl *D) {
641 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
642 MapRegionCounters Walker(*RegionCounterMap);
643 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
644 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
645 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
646 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
647 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
648 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
649 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
650 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
651 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
652 NumRegionCounters = Walker.NextCounter;
653 FunctionHash = Walker.Hash.finalize();
// Generate the coverage mapping for D with CoverageMappingGen, using the
// current RegionCounterMap, and register it under FuncNameVar / FuncName /
// FunctionHash.
//
// NOTE(review): listing lines 658, 662-663, 670-671 and 673-674 are not
// visible here. They presumably hold the early returns for the three `if`s
// and whatever finalizes the string stream before CoverageMapping is
// inspected -- confirm.
// NOTE(review): D->getBody() is dereferenced without a null check;
// presumably callers only pass declarations that have a body -- confirm.
656 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
657 if (SkipCoverageMapping)
659 // Don't map the functions inside the system headers
660 auto Loc = D->getBody()->getLocStart();
661 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
664 std::string CoverageMapping;
665 llvm::raw_string_ostream OS(CoverageMapping);
666 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
667 CGM.getContext().getSourceManager(),
668 CGM.getLangOpts(), RegionCounterMap.get());
669 MappingGen.emitCounterMapping(D, OS);
// An empty mapping string means there is nothing to register.
672 if (CoverageMapping.empty())
675 CGM.getCoverageMapping()->addFunctionMappingRecord(
676 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
// Generate an empty coverage mapping for D via
// CoverageMappingGen::emitEmptyMapping, set the PGO function name from
// Name/Linkage, and register the record with a trailing `false` argument.
//
// NOTE(review): the return-type line (listing line 679), the last
// constructor argument of MappingGen (693) and lines 683, 687-688, 695-696
// and 698-699 are not visible here; the latter presumably hold the early
// returns for the three `if`s -- confirm.
// NOTE(review): the meaning of the trailing `false` passed to
// addFunctionMappingRecord is not visible in this file -- check its
// declaration.
// NOTE(review): D->getBody() is dereferenced without a null check;
// presumably callers only pass declarations that have a body -- confirm.
680 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
681 llvm::GlobalValue::LinkageTypes Linkage) {
682 if (SkipCoverageMapping)
684 // Don't map the functions inside the system headers
685 auto Loc = D->getBody()->getLocStart();
686 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
689 std::string CoverageMapping;
690 llvm::raw_string_ostream OS(CoverageMapping);
691 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
692 CGM.getContext().getSourceManager(),
694 MappingGen.emitEmptyMapping(D, OS);
697 if (CoverageMapping.empty())
// The name is only set once a non-empty mapping is known to exist.
700 setFuncName(Name, Linkage);
701 CGM.getCoverageMapping()->addFunctionMappingRecord(
702 FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
705 void CodeGenPGO::computeRegionCounts(const Decl *D) {
706 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
707 ComputeRegionCounts Walker(*StmtCountMap, *this);
708 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
709 Walker.VisitFunctionDecl(FD);
710 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
711 Walker.VisitObjCMethodDecl(MD);
712 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
713 Walker.VisitBlockDecl(BD);
714 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
715 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
// Record the function's entry count on Fn, using the region count looked up
// with a null statement.
//
// NOTE(review): the return-type line (listing line 718) and the statement
// guarded by the `if` (722, presumably an early return when no region counts
// were loaded) are not visible here. PGOReader is not used by the visible
// statements.
719 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
720 llvm::Function *Fn) {
721 if (!haveRegionCounts())
724 uint64_t FunctionCount = getRegionCount(nullptr);
725 Fn->setEntryCount(FunctionCount);
728 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S) {
729 if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap)
731 if (!Builder.GetInsertBlock())
734 unsigned Counter = (*RegionCounterMap)[S];
735 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
736 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
737 {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
738 Builder.getInt64(FunctionHash),
739 Builder.getInt32(NumRegionCounters),
740 Builder.getInt32(Counter)});
// Load this function's counts from PGOReader into RegionCounts, keyed by
// FuncName and FunctionHash, and update the PGO statistics (visited, missing,
// mismatched) according to the outcome.
//
// NOTE(review): the second parameter line (listing line 744, which must
// declare IsInMainFile) and the closing lines 757-758 are not visible here.
// The final RegionCounts.clear() presumably sits inside the error branch so
// that a failed lookup leaves no partial counts -- confirm.
743 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
745 CGM.getPGOStats().addVisited(IsInMainFile);
746 RegionCounts.clear();
747 if (std::error_code EC =
748 PGOReader->getFunctionCounts(FuncName, FunctionHash, RegionCounts)) {
749 if (EC == llvm::instrprof_error::unknown_function)
750 CGM.getPGOStats().addMissing(IsInMainFile);
751 else if (EC == llvm::instrprof_error::hash_mismatch)
752 CGM.getPGOStats().addMismatched(IsInMainFile);
753 else if (EC == llvm::instrprof_error::malformed)
754 // TODO: Consider a more specific warning for this case.
755 CGM.getPGOStats().addMismatched(IsInMainFile);
756 RegionCounts.clear();
/// \brief Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
///
/// \param MaxWeight The largest weight that will be scaled.
/// \returns 1 when \p MaxWeight is already below UINT32_MAX, otherwise
///          \c MaxWeight / UINT32_MAX + 1.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
}
/// \brief Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight.
///
/// \param Weight The raw 64-bit count.
/// \param Scale  Non-zero divisor applied before the +1 adjustment.
/// \returns \c Weight / Scale + 1, which must fit in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // The assertion above guarantees the narrowing is lossless.
  return static_cast<uint32_t>(Scaled);
}
784 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
785 uint64_t FalseCount) {
786 // Check for empty weights.
787 if (!TrueCount && !FalseCount)
790 // Calculate how to scale down to 32-bits.
791 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
793 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
794 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
795 scaleBranchWeight(FalseCount, Scale));
// Build branch-weight metadata for an arbitrary number of successors: every
// weight is scaled by a common divisor derived from the largest one (see
// calculateWeightScale / scaleBranchWeight) and handed to
// MDBuilder::createBranchWeights.
//
// NOTE(review): the return-type line (listing line 798), the statement
// guarded by the size check (802) and lines 806-808 are not visible here.
// Given the "Check for empty weights" comment, 806-807 presumably bail out
// when MaxWeight is zero -- confirm.
799 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
800 // We need at least two elements to create meaningful weights.
801 if (Weights.size() < 2)
804 // Check for empty weights.
805 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
809 // Calculate how to scale down to 32-bits.
810 uint64_t Scale = calculateWeightScale(MaxWeight);
812 SmallVector<uint32_t, 16> ScaledWeights;
813 ScaledWeights.reserve(Weights.size());
814 for (uint64_t W : Weights)
815 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
817 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
818 return MDHelper.createBranchWeights(ScaledWeights);
// Build branch weights for a loop condition: LoopCount is used as the "true"
// weight and the remaining condition evaluations (CondCount - LoopCount) as
// the "false" weight.
//
// NOTE(review): listing lines 824 and 827-828 are not visible here. Line 824
// presumably returns early when no region counts are available; the content
// of 827-828 cannot be determined from this view.
821 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
822 uint64_t LoopCount) {
823 if (!PGO.haveRegionCounts())
825 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
826 assert(CondCount.hasValue() && "missing expected loop condition count");
// std::max keeps the unsigned subtraction from wrapping when the recorded
// condition count is smaller than LoopCount.
829 return createProfileWeights(LoopCount,
830 std::max(*CondCount, LoopCount) - LoopCount);