//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// Instrumentation-based profile-guided optimization
//
//===----------------------------------------------------------------------===//
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/ProfileData/InstrProfReader.h"
22 #include "llvm/Support/Endian.h"
23 #include "llvm/Support/FileSystem.h"
24 #include "llvm/Support/MD5.h"
26 using namespace clang;
27 using namespace CodeGen;
29 void CodeGenPGO::setFuncName(StringRef Name,
30 llvm::GlobalValue::LinkageTypes Linkage) {
31 StringRef RawFuncName = Name;
33 // Function names may be prefixed with a binary '1' to indicate
34 // that the backend should not modify the symbols due to any platform
35 // naming convention. Do not include that '1' in the PGO profile name.
36 if (RawFuncName[0] == '\1')
37 RawFuncName = RawFuncName.substr(1);
39 FuncName = RawFuncName;
40 if (llvm::GlobalValue::isLocalLinkage(Linkage)) {
41 // For local symbols, prepend the main file name to distinguish them.
42 // Do not include the full path in the file name since there's no guarantee
43 // that it will stay the same, e.g., if the files are checked out from
44 // version control in different locations.
45 if (CGM.getCodeGenOpts().MainFileName.empty())
46 FuncName = FuncName.insert(0, "<unknown>:");
48 FuncName = FuncName.insert(0, CGM.getCodeGenOpts().MainFileName + ":");
51 // If we're generating a profile, create a variable for the name.
52 if (CGM.getCodeGenOpts().ProfileInstrGenerate)
53 createFuncNameVar(Linkage);
56 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
57 setFuncName(Fn->getName(), Fn->getLinkage());
60 void CodeGenPGO::createFuncNameVar(llvm::GlobalValue::LinkageTypes Linkage) {
61 // Usually, we want to match the function's linkage, but
62 // available_externally and extern_weak both have the wrong semantics.
63 if (Linkage == llvm::GlobalValue::ExternalWeakLinkage)
64 Linkage = llvm::GlobalValue::LinkOnceAnyLinkage;
65 else if (Linkage == llvm::GlobalValue::AvailableExternallyLinkage)
66 Linkage = llvm::GlobalValue::LinkOnceODRLinkage;
69 llvm::ConstantDataArray::getString(CGM.getLLVMContext(), FuncName, false);
71 new llvm::GlobalVariable(CGM.getModule(), Value->getType(), true, Linkage,
72 Value, "__llvm_profile_name_" + FuncName);
74 // Hide the symbol so that we correctly get a copy for each executable.
75 if (!llvm::GlobalValue::isLocalLinkage(FuncNameVar->getLinkage()))
76 FuncNameVar->setVisibility(llvm::GlobalValue::HiddenVisibility);
80 /// \brief Stable hasher for PGO region counters.
82 /// PGOHash produces a stable hash of a given function's control flow.
84 /// Changing the output of this hash will invalidate all previously generated
85 /// profiles -- i.e., don't do it.
87 /// \note When this hash does eventually change (years?), we still need to
88 /// support old hashes. We'll need to pull in the version number from the
89 /// profile data format and use the matching hash function.
95 static const int NumBitsPerType = 6;
96 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
97 static const unsigned TooBig = 1u << NumBitsPerType;
100 /// \brief Hash values for AST nodes.
102 /// Distinct values for AST nodes that have region counters attached.
104 /// These values must be stable. All new members must be added at the end,
105 /// and no members should be removed. Changing the enumeration value for an
106 /// AST node will affect the hash of every function that contains that node.
107 enum HashType : unsigned char {
114 ObjCForCollectionStmt,
124 BinaryConditionalOperator,
126 // Keep this last. It's for the static assert that follows.
129 static_assert(LastHashType <= TooBig, "Too many types in HashType");
131 // TODO: When this format changes, take in a version number here, and use the
132 // old hash calculation for file formats that used the old hash.
133 PGOHash() : Working(0), Count(0) {}
134 void combine(HashType Type);
137 const int PGOHash::NumBitsPerType;
138 const unsigned PGOHash::NumTypesPerWord;
139 const unsigned PGOHash::TooBig;
141 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
142 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
143 /// The next counter value to assign.
144 unsigned NextCounter;
145 /// The function hash.
147 /// The map of statements to counters.
148 llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
150 MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
151 : NextCounter(0), CounterMap(CounterMap) {}
153 // Blocks and lambdas are handled as separate functions, so we need not
154 // traverse them in the parent context.
155 bool TraverseBlockExpr(BlockExpr *BE) { return true; }
156 bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
157 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
159 bool VisitDecl(const Decl *D) {
160 switch (D->getKind()) {
164 case Decl::CXXMethod:
165 case Decl::CXXConstructor:
166 case Decl::CXXDestructor:
167 case Decl::CXXConversion:
168 case Decl::ObjCMethod:
171 CounterMap[D->getBody()] = NextCounter++;
177 bool VisitStmt(const Stmt *S) {
178 auto Type = getHashType(S);
179 if (Type == PGOHash::None)
182 CounterMap[S] = NextCounter++;
186 PGOHash::HashType getHashType(const Stmt *S) {
187 switch (S->getStmtClass()) {
190 case Stmt::LabelStmtClass:
191 return PGOHash::LabelStmt;
192 case Stmt::WhileStmtClass:
193 return PGOHash::WhileStmt;
194 case Stmt::DoStmtClass:
195 return PGOHash::DoStmt;
196 case Stmt::ForStmtClass:
197 return PGOHash::ForStmt;
198 case Stmt::CXXForRangeStmtClass:
199 return PGOHash::CXXForRangeStmt;
200 case Stmt::ObjCForCollectionStmtClass:
201 return PGOHash::ObjCForCollectionStmt;
202 case Stmt::SwitchStmtClass:
203 return PGOHash::SwitchStmt;
204 case Stmt::CaseStmtClass:
205 return PGOHash::CaseStmt;
206 case Stmt::DefaultStmtClass:
207 return PGOHash::DefaultStmt;
208 case Stmt::IfStmtClass:
209 return PGOHash::IfStmt;
210 case Stmt::CXXTryStmtClass:
211 return PGOHash::CXXTryStmt;
212 case Stmt::CXXCatchStmtClass:
213 return PGOHash::CXXCatchStmt;
214 case Stmt::ConditionalOperatorClass:
215 return PGOHash::ConditionalOperator;
216 case Stmt::BinaryConditionalOperatorClass:
217 return PGOHash::BinaryConditionalOperator;
218 case Stmt::BinaryOperatorClass: {
219 const BinaryOperator *BO = cast<BinaryOperator>(S);
220 if (BO->getOpcode() == BO_LAnd)
221 return PGOHash::BinaryOperatorLAnd;
222 if (BO->getOpcode() == BO_LOr)
223 return PGOHash::BinaryOperatorLOr;
227 return PGOHash::None;
231 /// A StmtVisitor that propagates the raw counts through the AST and
232 /// records the count at statements where the value may change.
233 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
237 /// A flag that is set when the current count should be recorded on the
238 /// next statement, such as at the exit of a loop.
239 bool RecordNextStmtCount;
241 /// The map of statements to count values.
242 llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
244 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
245 struct BreakContinue {
247 uint64_t ContinueCount;
248 BreakContinue() : BreakCount(0), ContinueCount(0) {}
250 SmallVector<BreakContinue, 8> BreakContinueStack;
252 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
254 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
256 void RecordStmtCount(const Stmt *S) {
257 if (RecordNextStmtCount) {
258 CountMap[S] = PGO.getCurrentRegionCount();
259 RecordNextStmtCount = false;
263 void VisitStmt(const Stmt *S) {
265 for (Stmt::const_child_range I = S->children(); I; ++I) {
271 void VisitFunctionDecl(const FunctionDecl *D) {
272 // Counter tracks entry to the function body.
273 RegionCounter Cnt(PGO, D->getBody());
275 CountMap[D->getBody()] = PGO.getCurrentRegionCount();
279 // Skip lambda expressions. We visit these as FunctionDecls when we're
280 // generating them and aren't interested in the body when generating a
282 void VisitLambdaExpr(const LambdaExpr *LE) {}
284 void VisitCapturedDecl(const CapturedDecl *D) {
285 // Counter tracks entry to the capture body.
286 RegionCounter Cnt(PGO, D->getBody());
288 CountMap[D->getBody()] = PGO.getCurrentRegionCount();
292 void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
293 // Counter tracks entry to the method body.
294 RegionCounter Cnt(PGO, D->getBody());
296 CountMap[D->getBody()] = PGO.getCurrentRegionCount();
300 void VisitBlockDecl(const BlockDecl *D) {
301 // Counter tracks entry to the block body.
302 RegionCounter Cnt(PGO, D->getBody());
304 CountMap[D->getBody()] = PGO.getCurrentRegionCount();
308 void VisitReturnStmt(const ReturnStmt *S) {
310 if (S->getRetValue())
311 Visit(S->getRetValue());
312 PGO.setCurrentRegionUnreachable();
313 RecordNextStmtCount = true;
316 void VisitGotoStmt(const GotoStmt *S) {
318 PGO.setCurrentRegionUnreachable();
319 RecordNextStmtCount = true;
322 void VisitLabelStmt(const LabelStmt *S) {
323 RecordNextStmtCount = false;
324 // Counter tracks the block following the label.
325 RegionCounter Cnt(PGO, S);
327 CountMap[S] = PGO.getCurrentRegionCount();
328 Visit(S->getSubStmt());
331 void VisitBreakStmt(const BreakStmt *S) {
333 assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
334 BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
335 PGO.setCurrentRegionUnreachable();
336 RecordNextStmtCount = true;
339 void VisitContinueStmt(const ContinueStmt *S) {
341 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
342 BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
343 PGO.setCurrentRegionUnreachable();
344 RecordNextStmtCount = true;
347 void VisitWhileStmt(const WhileStmt *S) {
349 // Counter tracks the body of the loop.
350 RegionCounter Cnt(PGO, S);
351 BreakContinueStack.push_back(BreakContinue());
352 // Visit the body region first so the break/continue adjustments can be
353 // included when visiting the condition.
355 CountMap[S->getBody()] = PGO.getCurrentRegionCount();
357 Cnt.adjustForControlFlow();
359 // ...then go back and propagate counts through the condition. The count
360 // at the start of the condition is the sum of the incoming edges,
361 // the backedge from the end of the loop body, and the edges from
362 // continue statements.
363 BreakContinue BC = BreakContinueStack.pop_back_val();
364 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
365 Cnt.getAdjustedCount() + BC.ContinueCount);
366 CountMap[S->getCond()] = PGO.getCurrentRegionCount();
368 Cnt.adjustForControlFlow();
369 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
370 RecordNextStmtCount = true;
373 void VisitDoStmt(const DoStmt *S) {
375 // Counter tracks the body of the loop.
376 RegionCounter Cnt(PGO, S);
377 BreakContinueStack.push_back(BreakContinue());
378 Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
379 CountMap[S->getBody()] = PGO.getCurrentRegionCount();
381 Cnt.adjustForControlFlow();
383 BreakContinue BC = BreakContinueStack.pop_back_val();
384 // The count at the start of the condition is equal to the count at the
385 // end of the body. The adjusted count does not include either the
386 // fall-through count coming into the loop or the continue count, so add
387 // both of those separately. This is coincidentally the same equation as
388 // with while loops but for different reasons.
389 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
390 Cnt.getAdjustedCount() + BC.ContinueCount);
391 CountMap[S->getCond()] = PGO.getCurrentRegionCount();
393 Cnt.adjustForControlFlow();
394 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
395 RecordNextStmtCount = true;
398 void VisitForStmt(const ForStmt *S) {
402 // Counter tracks the body of the loop.
403 RegionCounter Cnt(PGO, S);
404 BreakContinueStack.push_back(BreakContinue());
405 // Visit the body region first. (This is basically the same as a while
406 // loop; see further comments in VisitWhileStmt.)
408 CountMap[S->getBody()] = PGO.getCurrentRegionCount();
410 Cnt.adjustForControlFlow();
412 // The increment is essentially part of the body but it needs to include
413 // the count for all the continue statements.
415 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
416 BreakContinueStack.back().ContinueCount);
417 CountMap[S->getInc()] = PGO.getCurrentRegionCount();
419 Cnt.adjustForControlFlow();
422 BreakContinue BC = BreakContinueStack.pop_back_val();
424 // ...then go back and propagate counts through the condition.
426 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
427 Cnt.getAdjustedCount() +
429 CountMap[S->getCond()] = PGO.getCurrentRegionCount();
431 Cnt.adjustForControlFlow();
433 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
434 RecordNextStmtCount = true;
437 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
439 Visit(S->getRangeStmt());
440 Visit(S->getBeginEndStmt());
441 // Counter tracks the body of the loop.
442 RegionCounter Cnt(PGO, S);
443 BreakContinueStack.push_back(BreakContinue());
444 // Visit the body region first. (This is basically the same as a while
445 // loop; see further comments in VisitWhileStmt.)
447 CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
448 Visit(S->getLoopVarStmt());
450 Cnt.adjustForControlFlow();
452 // The increment is essentially part of the body but it needs to include
453 // the count for all the continue statements.
454 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
455 BreakContinueStack.back().ContinueCount);
456 CountMap[S->getInc()] = PGO.getCurrentRegionCount();
458 Cnt.adjustForControlFlow();
460 BreakContinue BC = BreakContinueStack.pop_back_val();
462 // ...then go back and propagate counts through the condition.
463 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
464 Cnt.getAdjustedCount() +
466 CountMap[S->getCond()] = PGO.getCurrentRegionCount();
468 Cnt.adjustForControlFlow();
469 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
470 RecordNextStmtCount = true;
473 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
475 Visit(S->getElement());
476 // Counter tracks the body of the loop.
477 RegionCounter Cnt(PGO, S);
478 BreakContinueStack.push_back(BreakContinue());
480 CountMap[S->getBody()] = PGO.getCurrentRegionCount();
482 BreakContinue BC = BreakContinueStack.pop_back_val();
483 Cnt.adjustForControlFlow();
484 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
485 RecordNextStmtCount = true;
488 void VisitSwitchStmt(const SwitchStmt *S) {
491 PGO.setCurrentRegionUnreachable();
492 BreakContinueStack.push_back(BreakContinue());
494 // If the switch is inside a loop, add the continue counts.
495 BreakContinue BC = BreakContinueStack.pop_back_val();
496 if (!BreakContinueStack.empty())
497 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
498 // Counter tracks the exit block of the switch.
499 RegionCounter ExitCnt(PGO, S);
500 ExitCnt.beginRegion();
501 RecordNextStmtCount = true;
504 void VisitCaseStmt(const CaseStmt *S) {
505 RecordNextStmtCount = false;
506 // Counter for this particular case. This counts only jumps from the
507 // switch header and does not include fallthrough from the case before
509 RegionCounter Cnt(PGO, S);
510 Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
511 CountMap[S] = Cnt.getCount();
512 RecordNextStmtCount = true;
513 Visit(S->getSubStmt());
516 void VisitDefaultStmt(const DefaultStmt *S) {
517 RecordNextStmtCount = false;
518 // Counter for this default case. This does not include fallthrough from
519 // the previous case.
520 RegionCounter Cnt(PGO, S);
521 Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
522 CountMap[S] = Cnt.getCount();
523 RecordNextStmtCount = true;
524 Visit(S->getSubStmt());
527 void VisitIfStmt(const IfStmt *S) {
529 // Counter tracks the "then" part of an if statement. The count for
530 // the "else" part, if it exists, will be calculated from this counter.
531 RegionCounter Cnt(PGO, S);
535 CountMap[S->getThen()] = PGO.getCurrentRegionCount();
537 Cnt.adjustForControlFlow();
540 Cnt.beginElseRegion();
541 CountMap[S->getElse()] = PGO.getCurrentRegionCount();
543 Cnt.adjustForControlFlow();
545 Cnt.applyAdjustmentsToRegion(0);
546 RecordNextStmtCount = true;
549 void VisitCXXTryStmt(const CXXTryStmt *S) {
551 Visit(S->getTryBlock());
552 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
553 Visit(S->getHandler(I));
554 // Counter tracks the continuation block of the try statement.
555 RegionCounter Cnt(PGO, S);
557 RecordNextStmtCount = true;
560 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
561 RecordNextStmtCount = false;
562 // Counter tracks the catch statement's handler block.
563 RegionCounter Cnt(PGO, S);
565 CountMap[S] = PGO.getCurrentRegionCount();
566 Visit(S->getHandlerBlock());
569 void VisitAbstractConditionalOperator(
570 const AbstractConditionalOperator *E) {
572 // Counter tracks the "true" part of a conditional operator. The
573 // count in the "false" part will be calculated from this counter.
574 RegionCounter Cnt(PGO, E);
578 CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount();
579 Visit(E->getTrueExpr());
580 Cnt.adjustForControlFlow();
582 Cnt.beginElseRegion();
583 CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount();
584 Visit(E->getFalseExpr());
585 Cnt.adjustForControlFlow();
587 Cnt.applyAdjustmentsToRegion(0);
588 RecordNextStmtCount = true;
591 void VisitBinLAnd(const BinaryOperator *E) {
593 // Counter tracks the right hand side of a logical and operator.
594 RegionCounter Cnt(PGO, E);
597 CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
599 Cnt.adjustForControlFlow();
600 Cnt.applyAdjustmentsToRegion(0);
601 RecordNextStmtCount = true;
604 void VisitBinLOr(const BinaryOperator *E) {
606 // Counter tracks the right hand side of a logical or operator.
607 RegionCounter Cnt(PGO, E);
610 CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
612 Cnt.adjustForControlFlow();
613 Cnt.applyAdjustmentsToRegion(0);
614 RecordNextStmtCount = true;
619 void PGOHash::combine(HashType Type) {
620 // Check that we never combine 0 and only have six bits.
621 assert(Type && "Hash is invalid: unexpected type 0");
622 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
624 // Pass through MD5 if enough work has built up.
625 if (Count && Count % NumTypesPerWord == 0) {
626 using namespace llvm::support;
627 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
628 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
632 // Accumulate the current type.
634 Working = Working << NumBitsPerType | Type;
637 uint64_t PGOHash::finalize() {
638 // Use Working as the hash directly if we never used MD5.
639 if (Count <= NumTypesPerWord)
640 // No need to byte swap here, since none of the math was endian-dependent.
641 // This number will be byte-swapped as required on endianness transitions,
642 // so we will see the same value on the other side.
645 // Check for remaining work in Working.
649 // Finalize the MD5 and return the hash.
650 llvm::MD5::MD5Result Result;
652 using namespace llvm::support;
653 return endian::read<uint64_t, little, unaligned>(Result);
656 void CodeGenPGO::checkGlobalDecl(GlobalDecl GD) {
657 // Make sure we only emit coverage mapping for one constructor/destructor.
658 // Clang emits several functions for the constructor and the destructor of
659 // a class. Every function is instrumented, but we only want to provide
660 // coverage for one of them. Because of that we only emit the coverage mapping
661 // for the base constructor/destructor.
662 if ((isa<CXXConstructorDecl>(GD.getDecl()) &&
663 GD.getCtorType() != Ctor_Base) ||
664 (isa<CXXDestructorDecl>(GD.getDecl()) &&
665 GD.getDtorType() != Dtor_Base)) {
666 SkipCoverageMapping = true;
670 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
671 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
672 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
673 if (!InstrumentRegions && !PGOReader)
677 CGM.ClearUnusedCoverageMapping(D);
680 mapRegionCounters(D);
681 if (CGM.getCodeGenOpts().CoverageMapping)
682 emitCounterRegionMapping(D);
684 SourceManager &SM = CGM.getContext().getSourceManager();
685 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
686 computeRegionCounts(D);
687 applyFunctionAttributes(PGOReader, Fn);
691 void CodeGenPGO::mapRegionCounters(const Decl *D) {
692 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
693 MapRegionCounters Walker(*RegionCounterMap);
694 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
695 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
696 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
697 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
698 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
699 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
700 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
701 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
702 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
703 NumRegionCounters = Walker.NextCounter;
704 FunctionHash = Walker.Hash.finalize();
707 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
708 if (SkipCoverageMapping)
710 // Don't map the functions inside the system headers
711 auto Loc = D->getBody()->getLocStart();
712 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
715 std::string CoverageMapping;
716 llvm::raw_string_ostream OS(CoverageMapping);
717 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
718 CGM.getContext().getSourceManager(),
719 CGM.getLangOpts(), RegionCounterMap.get());
720 MappingGen.emitCounterMapping(D, OS);
723 if (CoverageMapping.empty())
726 CGM.getCoverageMapping()->addFunctionMappingRecord(
727 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
731 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef FuncName,
732 llvm::GlobalValue::LinkageTypes Linkage) {
733 if (SkipCoverageMapping)
735 setFuncName(FuncName, Linkage);
737 // Don't map the functions inside the system headers
738 auto Loc = D->getBody()->getLocStart();
739 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
742 std::string CoverageMapping;
743 llvm::raw_string_ostream OS(CoverageMapping);
744 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
745 CGM.getContext().getSourceManager(),
747 MappingGen.emitEmptyMapping(D, OS);
750 if (CoverageMapping.empty())
753 CGM.getCoverageMapping()->addFunctionMappingRecord(
754 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
757 void CodeGenPGO::computeRegionCounts(const Decl *D) {
758 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
759 ComputeRegionCounts Walker(*StmtCountMap, *this);
760 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
761 Walker.VisitFunctionDecl(FD);
762 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
763 Walker.VisitObjCMethodDecl(MD);
764 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
765 Walker.VisitBlockDecl(BD);
766 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
767 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
771 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
772 llvm::Function *Fn) {
773 if (!haveRegionCounts())
776 uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
777 uint64_t FunctionCount = getRegionCount(0);
778 if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
779 // Turn on InlineHint attribute for hot functions.
780 // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
781 Fn->addFnAttr(llvm::Attribute::InlineHint);
782 else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
783 // Turn on Cold attribute for cold functions.
784 // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
785 Fn->addFnAttr(llvm::Attribute::Cold);
788 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
789 if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap)
791 if (!Builder.GetInsertPoint())
793 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
794 Builder.CreateCall4(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
795 llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
796 Builder.getInt64(FunctionHash),
797 Builder.getInt32(NumRegionCounters),
798 Builder.getInt32(Counter));
801 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
803 CGM.getPGOStats().addVisited(IsInMainFile);
804 RegionCounts.clear();
805 if (std::error_code EC =
806 PGOReader->getFunctionCounts(FuncName, FunctionHash, RegionCounts)) {
807 if (EC == llvm::instrprof_error::unknown_function)
808 CGM.getPGOStats().addMissing(IsInMainFile);
809 else if (EC == llvm::instrprof_error::hash_mismatch)
810 CGM.getPGOStats().addMismatched(IsInMainFile);
811 else if (EC == llvm::instrprof_error::malformed)
812 // TODO: Consider a more specific warning for this case.
813 CGM.getPGOStats().addMismatched(IsInMainFile);
814 RegionCounts.clear();
/// \brief Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
///
/// \param MaxWeight The largest weight that will be scaled.
/// \returns 1 when \p MaxWeight is already below UINT32_MAX, otherwise
///          MaxWeight / UINT32_MAX + 1.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
}
/// \brief Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight.
///
/// \returns Weight / Scale + 1, narrowed to 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Scaled);
}
842 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
843 uint64_t FalseCount) {
844 // Check for empty weights.
845 if (!TrueCount && !FalseCount)
848 // Calculate how to scale down to 32-bits.
849 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
851 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
852 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
853 scaleBranchWeight(FalseCount, Scale));
856 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
857 // We need at least two elements to create meaningful weights.
858 if (Weights.size() < 2)
861 // Check for empty weights.
862 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
866 // Calculate how to scale down to 32-bits.
867 uint64_t Scale = calculateWeightScale(MaxWeight);
869 SmallVector<uint32_t, 16> ScaledWeights;
870 ScaledWeights.reserve(Weights.size());
871 for (uint64_t W : Weights)
872 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
874 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
875 return MDHelper.createBranchWeights(ScaledWeights);
878 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
879 RegionCounter &Cnt) {
880 if (!haveRegionCounts())
882 uint64_t LoopCount = Cnt.getCount();
883 uint64_t CondCount = 0;
884 bool Found = getStmtCount(Cond, CondCount);
885 assert(Found && "missing expected loop condition count");
889 return createBranchWeights(LoopCount,
890 std::max(CondCount, LoopCount) - LoopCount);