]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - contrib/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
MFV r311899:
[FreeBSD/FreeBSD.git] / contrib / llvm / lib / Transforms / Instrumentation / PGOInstrumentation.cpp
1 //===-- PGOInstrumentation.cpp - MST-based PGO Instrumentation ------------===//
2 //
3 //                      The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements PGO instrumentation using a minimum spanning tree based
11 // on the following paper:
12 //   [1] Donald E. Knuth, Francis R. Stevenson. Optimal measurement of points
13 //   for program frequency counts. BIT Numerical Mathematics 1973, Volume 13,
14 //   Issue 3, pp 313-322
15 // The idea of the algorithm based on the fact that for each node (except for
16 // the entry and exit), the sum of incoming edge counts equals the sum of
17 // outgoing edge counts. The count of edge on spanning tree can be derived from
18 // those edges not on the spanning tree. Knuth proves this method instruments
19 // the minimum number of edges.
20 //
21 // The minimal spanning tree here is actually a maximum weight tree -- on-tree
22 // edges have higher frequencies (more likely to execute). The idea is to
23 // instrument those less frequently executed edges to reduce the runtime
24 // overhead of instrumented binaries.
25 //
26 // This file contains two passes:
27 // (1) Pass PGOInstrumentationGen which instruments the IR to generate edge
28 // count profile, and generates the instrumentation for indirect call
29 // profiling.
30 // (2) Pass PGOInstrumentationUse which reads the edge count profile and
31 // annotates the branch weights. It also reads the indirect call value
32 // profiling records and annotate the indirect call instructions.
33 //
34 // To get the precise counter information, These two passes need to invoke at
35 // the same compilation point (so they see the same IR). For pass
36 // PGOInstrumentationGen, the real work is done in instrumentOneFunc(). For
37 // pass PGOInstrumentationUse, the real work in done in class PGOUseFunc and
38 // the profile is opened in module level and passed to each PGOUseFunc instance.
39 // The shared code for PGOInstrumentationGen and PGOInstrumentationUse is put
40 // in class FuncPGOInstrumentation.
41 //
42 // Class PGOEdge represents a CFG edge and some auxiliary information. Class
43 // BBInfo contains auxiliary information for each BB. These two classes are used
44 // in pass PGOInstrumentationGen. Class PGOUseEdge and UseBBInfo are the derived
45 // class of PGOEdge and BBInfo, respectively. They contains extra data structure
46 // used in populating profile counters.
47 // The MST implementation is in Class CFGMST (CFGMST.h).
48 //
49 //===----------------------------------------------------------------------===//
50
51 #include "llvm/Transforms/PGOInstrumentation.h"
52 #include "CFGMST.h"
53 #include "llvm/ADT/STLExtras.h"
54 #include "llvm/ADT/Statistic.h"
55 #include "llvm/ADT/Triple.h"
56 #include "llvm/Analysis/BlockFrequencyInfo.h"
57 #include "llvm/Analysis/BranchProbabilityInfo.h"
58 #include "llvm/Analysis/CFG.h"
59 #include "llvm/Analysis/IndirectCallSiteVisitor.h"
60 #include "llvm/IR/CallSite.h"
61 #include "llvm/IR/DiagnosticInfo.h"
62 #include "llvm/IR/IRBuilder.h"
63 #include "llvm/IR/InstIterator.h"
64 #include "llvm/IR/Instructions.h"
65 #include "llvm/IR/IntrinsicInst.h"
66 #include "llvm/IR/MDBuilder.h"
67 #include "llvm/IR/Module.h"
68 #include "llvm/Pass.h"
69 #include "llvm/ProfileData/InstrProfReader.h"
70 #include "llvm/ProfileData/ProfileCommon.h"
71 #include "llvm/Support/BranchProbability.h"
72 #include "llvm/Support/Debug.h"
73 #include "llvm/Support/JamCRC.h"
74 #include "llvm/Transforms/Instrumentation.h"
75 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
76 #include <algorithm>
77 #include <string>
78 #include <utility>
79 #include <vector>
80
81 using namespace llvm;
82
83 #define DEBUG_TYPE "pgo-instrumentation"
84
85 STATISTIC(NumOfPGOInstrument, "Number of edges instrumented.");
86 STATISTIC(NumOfPGOEdge, "Number of edges.");
87 STATISTIC(NumOfPGOBB, "Number of basic-blocks.");
88 STATISTIC(NumOfPGOSplit, "Number of critical edge splits.");
89 STATISTIC(NumOfPGOFunc, "Number of functions having valid profile counts.");
90 STATISTIC(NumOfPGOMismatch, "Number of functions having mismatch profile.");
91 STATISTIC(NumOfPGOMissing, "Number of functions without profile.");
92 STATISTIC(NumOfPGOICall, "Number of indirect call value instrumentations.");
93
94 // Command line option to specify the file to read profile from. This is
95 // mainly used for testing.
96 static cl::opt<std::string>
97     PGOTestProfileFile("pgo-test-profile-file", cl::init(""), cl::Hidden,
98                        cl::value_desc("filename"),
99                        cl::desc("Specify the path of profile data file. This is"
100                                 "mainly for test purpose."));
101
102 // Command line option to disable value profiling. The default is false:
103 // i.e. value profiling is enabled by default. This is for debug purpose.
104 static cl::opt<bool> DisableValueProfiling("disable-vp", cl::init(false),
105                                            cl::Hidden,
106                                            cl::desc("Disable Value Profiling"));
107
108 // Command line option to set the maximum number of VP annotations to write to
109 // the metadata for a single indirect call callsite.
110 static cl::opt<unsigned> MaxNumAnnotations(
111     "icp-max-annotations", cl::init(3), cl::Hidden, cl::ZeroOrMore,
112     cl::desc("Max number of annotations for a single indirect "
113              "call callsite"));
114
115 // Command line option to enable/disable the warning about missing profile
116 // information.
117 static cl::opt<bool> NoPGOWarnMissing("no-pgo-warn-missing", cl::init(false),
118                                       cl::Hidden);
119
120 // Command line option to enable/disable the warning about a hash mismatch in
121 // the profile data.
122 static cl::opt<bool> NoPGOWarnMismatch("no-pgo-warn-mismatch", cl::init(false),
123                                        cl::Hidden);
124
125 namespace {
126 class PGOInstrumentationGenLegacyPass : public ModulePass {
127 public:
128   static char ID;
129
130   PGOInstrumentationGenLegacyPass() : ModulePass(ID) {
131     initializePGOInstrumentationGenLegacyPassPass(
132         *PassRegistry::getPassRegistry());
133   }
134
135   const char *getPassName() const override {
136     return "PGOInstrumentationGenPass";
137   }
138
139 private:
140   bool runOnModule(Module &M) override;
141
142   void getAnalysisUsage(AnalysisUsage &AU) const override {
143     AU.addRequired<BlockFrequencyInfoWrapperPass>();
144   }
145 };
146
147 class PGOInstrumentationUseLegacyPass : public ModulePass {
148 public:
149   static char ID;
150
151   // Provide the profile filename as the parameter.
152   PGOInstrumentationUseLegacyPass(std::string Filename = "")
153       : ModulePass(ID), ProfileFileName(std::move(Filename)) {
154     if (!PGOTestProfileFile.empty())
155       ProfileFileName = PGOTestProfileFile;
156     initializePGOInstrumentationUseLegacyPassPass(
157         *PassRegistry::getPassRegistry());
158   }
159
160   const char *getPassName() const override {
161     return "PGOInstrumentationUsePass";
162   }
163
164 private:
165   std::string ProfileFileName;
166
167   bool runOnModule(Module &M) override;
168   void getAnalysisUsage(AnalysisUsage &AU) const override {
169     AU.addRequired<BlockFrequencyInfoWrapperPass>();
170   }
171 };
172 } // end anonymous namespace
173
174 char PGOInstrumentationGenLegacyPass::ID = 0;
175 INITIALIZE_PASS_BEGIN(PGOInstrumentationGenLegacyPass, "pgo-instr-gen",
176                       "PGO instrumentation.", false, false)
177 INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
178 INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
179 INITIALIZE_PASS_END(PGOInstrumentationGenLegacyPass, "pgo-instr-gen",
180                     "PGO instrumentation.", false, false)
181
182 ModulePass *llvm::createPGOInstrumentationGenLegacyPass() {
183   return new PGOInstrumentationGenLegacyPass();
184 }
185
186 char PGOInstrumentationUseLegacyPass::ID = 0;
187 INITIALIZE_PASS_BEGIN(PGOInstrumentationUseLegacyPass, "pgo-instr-use",
188                       "Read PGO instrumentation profile.", false, false)
189 INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
190 INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
191 INITIALIZE_PASS_END(PGOInstrumentationUseLegacyPass, "pgo-instr-use",
192                     "Read PGO instrumentation profile.", false, false)
193
194 ModulePass *llvm::createPGOInstrumentationUseLegacyPass(StringRef Filename) {
195   return new PGOInstrumentationUseLegacyPass(Filename.str());
196 }
197
198 namespace {
199 /// \brief An MST based instrumentation for PGO
200 ///
201 /// Implements a Minimum Spanning Tree (MST) based instrumentation for PGO
202 /// in the function level.
203 struct PGOEdge {
204   // This class implements the CFG edges. Note the CFG can be a multi-graph.
205   // So there might be multiple edges with same SrcBB and DestBB.
206   const BasicBlock *SrcBB;
207   const BasicBlock *DestBB;
208   uint64_t Weight;
209   bool InMST;
210   bool Removed;
211   bool IsCritical;
212   PGOEdge(const BasicBlock *Src, const BasicBlock *Dest, unsigned W = 1)
213       : SrcBB(Src), DestBB(Dest), Weight(W), InMST(false), Removed(false),
214         IsCritical(false) {}
215   // Return the information string of an edge.
216   const std::string infoString() const {
217     return (Twine(Removed ? "-" : " ") + (InMST ? " " : "*") +
218             (IsCritical ? "c" : " ") + "  W=" + Twine(Weight)).str();
219   }
220 };
221
222 // This class stores the auxiliary information for each BB.
223 struct BBInfo {
224   BBInfo *Group;
225   uint32_t Index;
226   uint32_t Rank;
227
228   BBInfo(unsigned IX) : Group(this), Index(IX), Rank(0) {}
229
230   // Return the information string of this object.
231   const std::string infoString() const {
232     return (Twine("Index=") + Twine(Index)).str();
233   }
234 };
235
236 // This class implements the CFG edges. Note the CFG can be a multi-graph.
237 template <class Edge, class BBInfo> class FuncPGOInstrumentation {
238 private:
239   Function &F;
240   void computeCFGHash();
241
242 public:
243   std::string FuncName;
244   GlobalVariable *FuncNameVar;
245   // CFG hash value for this function.
246   uint64_t FunctionHash;
247
248   // The Minimum Spanning Tree of function CFG.
249   CFGMST<Edge, BBInfo> MST;
250
251   // Give an edge, find the BB that will be instrumented.
252   // Return nullptr if there is no BB to be instrumented.
253   BasicBlock *getInstrBB(Edge *E);
254
255   // Return the auxiliary BB information.
256   BBInfo &getBBInfo(const BasicBlock *BB) const { return MST.getBBInfo(BB); }
257
258   // Dump edges and BB information.
259   void dumpInfo(std::string Str = "") const {
260     MST.dumpEdges(dbgs(), Twine("Dump Function ") + FuncName + " Hash: " +
261                               Twine(FunctionHash) + "\t" + Str);
262   }
263
264   FuncPGOInstrumentation(Function &Func, bool CreateGlobalVar = false,
265                          BranchProbabilityInfo *BPI = nullptr,
266                          BlockFrequencyInfo *BFI = nullptr)
267       : F(Func), FunctionHash(0), MST(F, BPI, BFI) {
268     FuncName = getPGOFuncName(F);
269     computeCFGHash();
270     DEBUG(dumpInfo("after CFGMST"));
271
272     NumOfPGOBB += MST.BBInfos.size();
273     for (auto &E : MST.AllEdges) {
274       if (E->Removed)
275         continue;
276       NumOfPGOEdge++;
277       if (!E->InMST)
278         NumOfPGOInstrument++;
279     }
280
281     if (CreateGlobalVar)
282       FuncNameVar = createPGOFuncNameVar(F, FuncName);
283   }
284 };
285
286 // Compute Hash value for the CFG: the lower 32 bits are CRC32 of the index
287 // value of each BB in the CFG. The higher 32 bits record the number of edges.
288 template <class Edge, class BBInfo>
289 void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() {
290   std::vector<char> Indexes;
291   JamCRC JC;
292   for (auto &BB : F) {
293     const TerminatorInst *TI = BB.getTerminator();
294     for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I) {
295       BasicBlock *Succ = TI->getSuccessor(I);
296       uint32_t Index = getBBInfo(Succ).Index;
297       for (int J = 0; J < 4; J++)
298         Indexes.push_back((char)(Index >> (J * 8)));
299     }
300   }
301   JC.update(Indexes);
302   FunctionHash = (uint64_t)MST.AllEdges.size() << 32 | JC.getCRC();
303 }
304
305 // Given a CFG E to be instrumented, find which BB to place the instrumented
306 // code. The function will split the critical edge if necessary.
307 template <class Edge, class BBInfo>
308 BasicBlock *FuncPGOInstrumentation<Edge, BBInfo>::getInstrBB(Edge *E) {
309   if (E->InMST || E->Removed)
310     return nullptr;
311
312   BasicBlock *SrcBB = const_cast<BasicBlock *>(E->SrcBB);
313   BasicBlock *DestBB = const_cast<BasicBlock *>(E->DestBB);
314   // For a fake edge, instrument the real BB.
315   if (SrcBB == nullptr)
316     return DestBB;
317   if (DestBB == nullptr)
318     return SrcBB;
319
320   // Instrument the SrcBB if it has a single successor,
321   // otherwise, the DestBB if this is not a critical edge.
322   TerminatorInst *TI = SrcBB->getTerminator();
323   if (TI->getNumSuccessors() <= 1)
324     return SrcBB;
325   if (!E->IsCritical)
326     return DestBB;
327
328   // For a critical edge, we have to split. Instrument the newly
329   // created BB.
330   NumOfPGOSplit++;
331   DEBUG(dbgs() << "Split critical edge: " << getBBInfo(SrcBB).Index << " --> "
332                << getBBInfo(DestBB).Index << "\n");
333   unsigned SuccNum = GetSuccessorNumber(SrcBB, DestBB);
334   BasicBlock *InstrBB = SplitCriticalEdge(TI, SuccNum);
335   assert(InstrBB && "Critical edge is not split");
336
337   E->Removed = true;
338   return InstrBB;
339 }
340
341 // Visit all edge and instrument the edges not in MST, and do value profiling.
342 // Critical edges will be split.
343 static void instrumentOneFunc(Function &F, Module *M,
344                               BranchProbabilityInfo *BPI,
345                               BlockFrequencyInfo *BFI) {
346   unsigned NumCounters = 0;
347   FuncPGOInstrumentation<PGOEdge, BBInfo> FuncInfo(F, true, BPI, BFI);
348   for (auto &E : FuncInfo.MST.AllEdges) {
349     if (!E->InMST && !E->Removed)
350       NumCounters++;
351   }
352
353   uint32_t I = 0;
354   Type *I8PtrTy = Type::getInt8PtrTy(M->getContext());
355   for (auto &E : FuncInfo.MST.AllEdges) {
356     BasicBlock *InstrBB = FuncInfo.getInstrBB(E.get());
357     if (!InstrBB)
358       continue;
359
360     IRBuilder<> Builder(InstrBB, InstrBB->getFirstInsertionPt());
361     assert(Builder.GetInsertPoint() != InstrBB->end() &&
362            "Cannot get the Instrumentation point");
363     Builder.CreateCall(
364         Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment),
365         {llvm::ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy),
366          Builder.getInt64(FuncInfo.FunctionHash), Builder.getInt32(NumCounters),
367          Builder.getInt32(I++)});
368   }
369
370   if (DisableValueProfiling)
371     return;
372
373   unsigned NumIndirectCallSites = 0;
374   for (auto &I : findIndirectCallSites(F)) {
375     CallSite CS(I);
376     Value *Callee = CS.getCalledValue();
377     DEBUG(dbgs() << "Instrument one indirect call: CallSite Index = "
378                  << NumIndirectCallSites << "\n");
379     IRBuilder<> Builder(I);
380     assert(Builder.GetInsertPoint() != I->getParent()->end() &&
381            "Cannot get the Instrumentation point");
382     Builder.CreateCall(
383         Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile),
384         {llvm::ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy),
385          Builder.getInt64(FuncInfo.FunctionHash),
386          Builder.CreatePtrToInt(Callee, Builder.getInt64Ty()),
387          Builder.getInt32(llvm::InstrProfValueKind::IPVK_IndirectCallTarget),
388          Builder.getInt32(NumIndirectCallSites++)});
389   }
390   NumOfPGOICall += NumIndirectCallSites;
391 }
392
393 // This class represents a CFG edge in profile use compilation.
394 struct PGOUseEdge : public PGOEdge {
395   bool CountValid;
396   uint64_t CountValue;
397   PGOUseEdge(const BasicBlock *Src, const BasicBlock *Dest, unsigned W = 1)
398       : PGOEdge(Src, Dest, W), CountValid(false), CountValue(0) {}
399
400   // Set edge count value
401   void setEdgeCount(uint64_t Value) {
402     CountValue = Value;
403     CountValid = true;
404   }
405
406   // Return the information string for this object.
407   const std::string infoString() const {
408     if (!CountValid)
409       return PGOEdge::infoString();
410     return (Twine(PGOEdge::infoString()) + "  Count=" + Twine(CountValue))
411         .str();
412   }
413 };
414
415 typedef SmallVector<PGOUseEdge *, 2> DirectEdges;
416
417 // This class stores the auxiliary information for each BB.
418 struct UseBBInfo : public BBInfo {
419   uint64_t CountValue;
420   bool CountValid;
421   int32_t UnknownCountInEdge;
422   int32_t UnknownCountOutEdge;
423   DirectEdges InEdges;
424   DirectEdges OutEdges;
425   UseBBInfo(unsigned IX)
426       : BBInfo(IX), CountValue(0), CountValid(false), UnknownCountInEdge(0),
427         UnknownCountOutEdge(0) {}
428   UseBBInfo(unsigned IX, uint64_t C)
429       : BBInfo(IX), CountValue(C), CountValid(true), UnknownCountInEdge(0),
430         UnknownCountOutEdge(0) {}
431
432   // Set the profile count value for this BB.
433   void setBBInfoCount(uint64_t Value) {
434     CountValue = Value;
435     CountValid = true;
436   }
437
438   // Return the information string of this object.
439   const std::string infoString() const {
440     if (!CountValid)
441       return BBInfo::infoString();
442     return (Twine(BBInfo::infoString()) + "  Count=" + Twine(CountValue)).str();
443   }
444 };
445
446 // Sum up the count values for all the edges.
447 static uint64_t sumEdgeCount(const ArrayRef<PGOUseEdge *> Edges) {
448   uint64_t Total = 0;
449   for (auto &E : Edges) {
450     if (E->Removed)
451       continue;
452     Total += E->CountValue;
453   }
454   return Total;
455 }
456
457 class PGOUseFunc {
458 public:
459   PGOUseFunc(Function &Func, Module *Modu, BranchProbabilityInfo *BPI = nullptr,
460              BlockFrequencyInfo *BFI = nullptr)
461       : F(Func), M(Modu), FuncInfo(Func, false, BPI, BFI),
462         FreqAttr(FFA_Normal) {}
463
464   // Read counts for the instrumented BB from profile.
465   bool readCounters(IndexedInstrProfReader *PGOReader);
466
467   // Populate the counts for all BBs.
468   void populateCounters();
469
470   // Set the branch weights based on the count values.
471   void setBranchWeights();
472
473   // Annotate the indirect call sites.
474   void annotateIndirectCallSites();
475
476   // The hotness of the function from the profile count.
477   enum FuncFreqAttr { FFA_Normal, FFA_Cold, FFA_Hot };
478
479   // Return the function hotness from the profile.
480   FuncFreqAttr getFuncFreqAttr() const { return FreqAttr; }
481
482   // Return the profile record for this function;
483   InstrProfRecord &getProfileRecord() { return ProfileRecord; }
484
485 private:
486   Function &F;
487   Module *M;
488   // This member stores the shared information with class PGOGenFunc.
489   FuncPGOInstrumentation<PGOUseEdge, UseBBInfo> FuncInfo;
490
491   // Return the auxiliary BB information.
492   UseBBInfo &getBBInfo(const BasicBlock *BB) const {
493     return FuncInfo.getBBInfo(BB);
494   }
495
496   // The maximum count value in the profile. This is only used in PGO use
497   // compilation.
498   uint64_t ProgramMaxCount;
499
500   // ProfileRecord for this function.
501   InstrProfRecord ProfileRecord;
502
503   // Function hotness info derived from profile.
504   FuncFreqAttr FreqAttr;
505
506   // Find the Instrumented BB and set the value.
507   void setInstrumentedCounts(const std::vector<uint64_t> &CountFromProfile);
508
509   // Set the edge counter value for the unknown edge -- there should be only
510   // one unknown edge.
511   void setEdgeCount(DirectEdges &Edges, uint64_t Value);
512
513   // Return FuncName string;
514   const std::string getFuncName() const { return FuncInfo.FuncName; }
515
516   // Set the hot/cold inline hints based on the count values.
517   // FIXME: This function should be removed once the functionality in
518   // the inliner is implemented.
519   void markFunctionAttributes(uint64_t EntryCount, uint64_t MaxCount) {
520     if (ProgramMaxCount == 0)
521       return;
522     // Threshold of the hot functions.
523     const BranchProbability HotFunctionThreshold(1, 100);
524     // Threshold of the cold functions.
525     const BranchProbability ColdFunctionThreshold(2, 10000);
526     if (EntryCount >= HotFunctionThreshold.scale(ProgramMaxCount))
527       FreqAttr = FFA_Hot;
528     else if (MaxCount <= ColdFunctionThreshold.scale(ProgramMaxCount))
529       FreqAttr = FFA_Cold;
530   }
531 };
532
533 // Visit all the edges and assign the count value for the instrumented
534 // edges and the BB.
535 void PGOUseFunc::setInstrumentedCounts(
536     const std::vector<uint64_t> &CountFromProfile) {
537
538   // Use a worklist as we will update the vector during the iteration.
539   std::vector<PGOUseEdge *> WorkList;
540   for (auto &E : FuncInfo.MST.AllEdges)
541     WorkList.push_back(E.get());
542
543   uint32_t I = 0;
544   for (auto &E : WorkList) {
545     BasicBlock *InstrBB = FuncInfo.getInstrBB(E);
546     if (!InstrBB)
547       continue;
548     uint64_t CountValue = CountFromProfile[I++];
549     if (!E->Removed) {
550       getBBInfo(InstrBB).setBBInfoCount(CountValue);
551       E->setEdgeCount(CountValue);
552       continue;
553     }
554
555     // Need to add two new edges.
556     BasicBlock *SrcBB = const_cast<BasicBlock *>(E->SrcBB);
557     BasicBlock *DestBB = const_cast<BasicBlock *>(E->DestBB);
558     // Add new edge of SrcBB->InstrBB.
559     PGOUseEdge &NewEdge = FuncInfo.MST.addEdge(SrcBB, InstrBB, 0);
560     NewEdge.setEdgeCount(CountValue);
561     // Add new edge of InstrBB->DestBB.
562     PGOUseEdge &NewEdge1 = FuncInfo.MST.addEdge(InstrBB, DestBB, 0);
563     NewEdge1.setEdgeCount(CountValue);
564     NewEdge1.InMST = true;
565     getBBInfo(InstrBB).setBBInfoCount(CountValue);
566   }
567 }
568
569 // Set the count value for the unknown edge. There should be one and only one
570 // unknown edge in Edges vector.
571 void PGOUseFunc::setEdgeCount(DirectEdges &Edges, uint64_t Value) {
572   for (auto &E : Edges) {
573     if (E->CountValid)
574       continue;
575     E->setEdgeCount(Value);
576
577     getBBInfo(E->SrcBB).UnknownCountOutEdge--;
578     getBBInfo(E->DestBB).UnknownCountInEdge--;
579     return;
580   }
581   llvm_unreachable("Cannot find the unknown count edge");
582 }
583
584 // Read the profile from ProfileFileName and assign the value to the
585 // instrumented BB and the edges. This function also updates ProgramMaxCount.
586 // Return true if the profile are successfully read, and false on errors.
587 bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader) {
588   auto &Ctx = M->getContext();
589   Expected<InstrProfRecord> Result =
590       PGOReader->getInstrProfRecord(FuncInfo.FuncName, FuncInfo.FunctionHash);
591   if (Error E = Result.takeError()) {
592     handleAllErrors(std::move(E), [&](const InstrProfError &IPE) {
593       auto Err = IPE.get();
594       bool SkipWarning = false;
595       if (Err == instrprof_error::unknown_function) {
596         NumOfPGOMissing++;
597         SkipWarning = NoPGOWarnMissing;
598       } else if (Err == instrprof_error::hash_mismatch ||
599                  Err == instrprof_error::malformed) {
600         NumOfPGOMismatch++;
601         SkipWarning = NoPGOWarnMismatch;
602       }
603
604       if (SkipWarning)
605         return;
606
607       std::string Msg = IPE.message() + std::string(" ") + F.getName().str();
608       Ctx.diagnose(
609           DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning));
610     });
611     return false;
612   }
613   ProfileRecord = std::move(Result.get());
614   std::vector<uint64_t> &CountFromProfile = ProfileRecord.Counts;
615
616   NumOfPGOFunc++;
617   DEBUG(dbgs() << CountFromProfile.size() << " counts\n");
618   uint64_t ValueSum = 0;
619   for (unsigned I = 0, S = CountFromProfile.size(); I < S; I++) {
620     DEBUG(dbgs() << "  " << I << ": " << CountFromProfile[I] << "\n");
621     ValueSum += CountFromProfile[I];
622   }
623
624   DEBUG(dbgs() << "SUM =  " << ValueSum << "\n");
625
626   getBBInfo(nullptr).UnknownCountOutEdge = 2;
627   getBBInfo(nullptr).UnknownCountInEdge = 2;
628
629   setInstrumentedCounts(CountFromProfile);
630   ProgramMaxCount = PGOReader->getMaximumFunctionCount();
631   return true;
632 }
633
634 // Populate the counters from instrumented BBs to all BBs.
635 // In the end of this operation, all BBs should have a valid count value.
636 void PGOUseFunc::populateCounters() {
637   // First set up Count variable for all BBs.
638   for (auto &E : FuncInfo.MST.AllEdges) {
639     if (E->Removed)
640       continue;
641
642     const BasicBlock *SrcBB = E->SrcBB;
643     const BasicBlock *DestBB = E->DestBB;
644     UseBBInfo &SrcInfo = getBBInfo(SrcBB);
645     UseBBInfo &DestInfo = getBBInfo(DestBB);
646     SrcInfo.OutEdges.push_back(E.get());
647     DestInfo.InEdges.push_back(E.get());
648     SrcInfo.UnknownCountOutEdge++;
649     DestInfo.UnknownCountInEdge++;
650
651     if (!E->CountValid)
652       continue;
653     DestInfo.UnknownCountInEdge--;
654     SrcInfo.UnknownCountOutEdge--;
655   }
656
657   bool Changes = true;
658   unsigned NumPasses = 0;
659   while (Changes) {
660     NumPasses++;
661     Changes = false;
662
663     // For efficient traversal, it's better to start from the end as most
664     // of the instrumented edges are at the end.
665     for (auto &BB : reverse(F)) {
666       UseBBInfo &Count = getBBInfo(&BB);
667       if (!Count.CountValid) {
668         if (Count.UnknownCountOutEdge == 0) {
669           Count.CountValue = sumEdgeCount(Count.OutEdges);
670           Count.CountValid = true;
671           Changes = true;
672         } else if (Count.UnknownCountInEdge == 0) {
673           Count.CountValue = sumEdgeCount(Count.InEdges);
674           Count.CountValid = true;
675           Changes = true;
676         }
677       }
678       if (Count.CountValid) {
679         if (Count.UnknownCountOutEdge == 1) {
680           uint64_t Total = Count.CountValue - sumEdgeCount(Count.OutEdges);
681           setEdgeCount(Count.OutEdges, Total);
682           Changes = true;
683         }
684         if (Count.UnknownCountInEdge == 1) {
685           uint64_t Total = Count.CountValue - sumEdgeCount(Count.InEdges);
686           setEdgeCount(Count.InEdges, Total);
687           Changes = true;
688         }
689       }
690     }
691   }
692
693   DEBUG(dbgs() << "Populate counts in " << NumPasses << " passes.\n");
694 #ifndef NDEBUG
695   // Assert every BB has a valid counter.
696   for (auto &BB : F)
697     assert(getBBInfo(&BB).CountValid && "BB count is not valid");
698 #endif
699   uint64_t FuncEntryCount = getBBInfo(&*F.begin()).CountValue;
700   F.setEntryCount(FuncEntryCount);
701   uint64_t FuncMaxCount = FuncEntryCount;
702   for (auto &BB : F)
703     FuncMaxCount = std::max(FuncMaxCount, getBBInfo(&BB).CountValue);
704   markFunctionAttributes(FuncEntryCount, FuncMaxCount);
705
706   DEBUG(FuncInfo.dumpInfo("after reading profile."));
707 }
708
709 // Assign the scaled count values to the BB with multiple out edges.
710 void PGOUseFunc::setBranchWeights() {
711   // Generate MD_prof metadata for every branch instruction.
712   DEBUG(dbgs() << "\nSetting branch weights.\n");
713   MDBuilder MDB(M->getContext());
714   for (auto &BB : F) {
715     TerminatorInst *TI = BB.getTerminator();
716     if (TI->getNumSuccessors() < 2)
717       continue;
718     if (!isa<BranchInst>(TI) && !isa<SwitchInst>(TI))
719       continue;
720     if (getBBInfo(&BB).CountValue == 0)
721       continue;
722
723     // We have a non-zero Branch BB.
724     const UseBBInfo &BBCountInfo = getBBInfo(&BB);
725     unsigned Size = BBCountInfo.OutEdges.size();
726     SmallVector<unsigned, 2> EdgeCounts(Size, 0);
727     uint64_t MaxCount = 0;
728     for (unsigned s = 0; s < Size; s++) {
729       const PGOUseEdge *E = BBCountInfo.OutEdges[s];
730       const BasicBlock *SrcBB = E->SrcBB;
731       const BasicBlock *DestBB = E->DestBB;
732       if (DestBB == nullptr)
733         continue;
734       unsigned SuccNum = GetSuccessorNumber(SrcBB, DestBB);
735       uint64_t EdgeCount = E->CountValue;
736       if (EdgeCount > MaxCount)
737         MaxCount = EdgeCount;
738       EdgeCounts[SuccNum] = EdgeCount;
739     }
740     assert(MaxCount > 0 && "Bad max count");
741     uint64_t Scale = calculateCountScale(MaxCount);
742     SmallVector<unsigned, 4> Weights;
743     for (const auto &ECI : EdgeCounts)
744       Weights.push_back(scaleBranchCount(ECI, Scale));
745
746     TI->setMetadata(llvm::LLVMContext::MD_prof,
747                     MDB.createBranchWeights(Weights));
748     DEBUG(dbgs() << "Weight is: ";
749           for (const auto &W : Weights) { dbgs() << W << " "; }
750           dbgs() << "\n";);
751   }
752 }
753
754 // Traverse all the indirect callsites and annotate the instructions.
755 void PGOUseFunc::annotateIndirectCallSites() {
756   if (DisableValueProfiling)
757     return;
758
759   // Create the PGOFuncName meta data.
760   createPGOFuncNameMetadata(F, FuncInfo.FuncName);
761
762   unsigned IndirectCallSiteIndex = 0;
763   auto IndirectCallSites = findIndirectCallSites(F);
764   unsigned NumValueSites =
765       ProfileRecord.getNumValueSites(IPVK_IndirectCallTarget);
766   if (NumValueSites != IndirectCallSites.size()) {
767     std::string Msg =
768         std::string("Inconsistent number of indirect call sites: ") +
769         F.getName().str();
770     auto &Ctx = M->getContext();
771     Ctx.diagnose(
772         DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning));
773     return;
774   }
775
776   for (auto &I : IndirectCallSites) {
777     DEBUG(dbgs() << "Read one indirect call instrumentation: Index="
778                  << IndirectCallSiteIndex << " out of " << NumValueSites
779                  << "\n");
780     annotateValueSite(*M, *I, ProfileRecord, IPVK_IndirectCallTarget,
781                       IndirectCallSiteIndex, MaxNumAnnotations);
782     IndirectCallSiteIndex++;
783   }
784 }
785 } // end anonymous namespace
786
787 // Create a COMDAT variable IR_LEVEL_PROF_VARNAME to make the runtime
788 // aware this is an ir_level profile so it can set the version flag.
789 static void createIRLevelProfileFlagVariable(Module &M) {
790   Type *IntTy64 = Type::getInt64Ty(M.getContext());
791   uint64_t ProfileVersion = (INSTR_PROF_RAW_VERSION | VARIANT_MASK_IR_PROF);
792   auto IRLevelVersionVariable = new GlobalVariable(
793       M, IntTy64, true, GlobalVariable::ExternalLinkage,
794       Constant::getIntegerValue(IntTy64, APInt(64, ProfileVersion)),
795       INSTR_PROF_QUOTE(IR_LEVEL_PROF_VERSION_VAR));
796   IRLevelVersionVariable->setVisibility(GlobalValue::DefaultVisibility);
797   Triple TT(M.getTargetTriple());
798   if (!TT.supportsCOMDAT())
799     IRLevelVersionVariable->setLinkage(GlobalValue::WeakAnyLinkage);
800   else
801     IRLevelVersionVariable->setComdat(M.getOrInsertComdat(
802         StringRef(INSTR_PROF_QUOTE(IR_LEVEL_PROF_VERSION_VAR))));
803 }
804
805 static bool InstrumentAllFunctions(
806     Module &M, function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
807     function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) {
808   createIRLevelProfileFlagVariable(M);
809   for (auto &F : M) {
810     if (F.isDeclaration())
811       continue;
812     auto *BPI = LookupBPI(F);
813     auto *BFI = LookupBFI(F);
814     instrumentOneFunc(F, &M, BPI, BFI);
815   }
816   return true;
817 }
818
819 bool PGOInstrumentationGenLegacyPass::runOnModule(Module &M) {
820   if (skipModule(M))
821     return false;
822
823   auto LookupBPI = [this](Function &F) {
824     return &this->getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI();
825   };
826   auto LookupBFI = [this](Function &F) {
827     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
828   };
829   return InstrumentAllFunctions(M, LookupBPI, LookupBFI);
830 }
831
832 PreservedAnalyses PGOInstrumentationGen::run(Module &M,
833                                              AnalysisManager<Module> &AM) {
834
835   auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
836   auto LookupBPI = [&FAM](Function &F) {
837     return &FAM.getResult<BranchProbabilityAnalysis>(F);
838   };
839
840   auto LookupBFI = [&FAM](Function &F) {
841     return &FAM.getResult<BlockFrequencyAnalysis>(F);
842   };
843
844   if (!InstrumentAllFunctions(M, LookupBPI, LookupBFI))
845     return PreservedAnalyses::all();
846
847   return PreservedAnalyses::none();
848 }
849
850 static bool annotateAllFunctions(
851     Module &M, StringRef ProfileFileName,
852     function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
853     function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) {
854   DEBUG(dbgs() << "Read in profile counters: ");
855   auto &Ctx = M.getContext();
856   // Read the counter array from file.
857   auto ReaderOrErr = IndexedInstrProfReader::create(ProfileFileName);
858   if (Error E = ReaderOrErr.takeError()) {
859     handleAllErrors(std::move(E), [&](const ErrorInfoBase &EI) {
860       Ctx.diagnose(
861           DiagnosticInfoPGOProfile(ProfileFileName.data(), EI.message()));
862     });
863     return false;
864   }
865
866   std::unique_ptr<IndexedInstrProfReader> PGOReader =
867       std::move(ReaderOrErr.get());
868   if (!PGOReader) {
869     Ctx.diagnose(DiagnosticInfoPGOProfile(ProfileFileName.data(),
870                                           StringRef("Cannot get PGOReader")));
871     return false;
872   }
873   // TODO: might need to change the warning once the clang option is finalized.
874   if (!PGOReader->isIRLevelProfile()) {
875     Ctx.diagnose(DiagnosticInfoPGOProfile(
876         ProfileFileName.data(), "Not an IR level instrumentation profile"));
877     return false;
878   }
879
880   std::vector<Function *> HotFunctions;
881   std::vector<Function *> ColdFunctions;
882   for (auto &F : M) {
883     if (F.isDeclaration())
884       continue;
885     auto *BPI = LookupBPI(F);
886     auto *BFI = LookupBFI(F);
887     PGOUseFunc Func(F, &M, BPI, BFI);
888     if (!Func.readCounters(PGOReader.get()))
889       continue;
890     Func.populateCounters();
891     Func.setBranchWeights();
892     Func.annotateIndirectCallSites();
893     PGOUseFunc::FuncFreqAttr FreqAttr = Func.getFuncFreqAttr();
894     if (FreqAttr == PGOUseFunc::FFA_Cold)
895       ColdFunctions.push_back(&F);
896     else if (FreqAttr == PGOUseFunc::FFA_Hot)
897       HotFunctions.push_back(&F);
898   }
899   M.setProfileSummary(PGOReader->getSummary().getMD(M.getContext()));
900   // Set function hotness attribute from the profile.
901   // We have to apply these attributes at the end because their presence
902   // can affect the BranchProbabilityInfo of any callers, resulting in an
903   // inconsistent MST between prof-gen and prof-use.
904   for (auto &F : HotFunctions) {
905     F->addFnAttr(llvm::Attribute::InlineHint);
906     DEBUG(dbgs() << "Set inline attribute to function: " << F->getName()
907                  << "\n");
908   }
909   for (auto &F : ColdFunctions) {
910     F->addFnAttr(llvm::Attribute::Cold);
911     DEBUG(dbgs() << "Set cold attribute to function: " << F->getName() << "\n");
912   }
913
914   return true;
915 }
916
917 PGOInstrumentationUse::PGOInstrumentationUse(std::string Filename)
918     : ProfileFileName(std::move(Filename)) {
919   if (!PGOTestProfileFile.empty())
920     ProfileFileName = PGOTestProfileFile;
921 }
922
923 PreservedAnalyses PGOInstrumentationUse::run(Module &M,
924                                              AnalysisManager<Module> &AM) {
925
926   auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
927   auto LookupBPI = [&FAM](Function &F) {
928     return &FAM.getResult<BranchProbabilityAnalysis>(F);
929   };
930
931   auto LookupBFI = [&FAM](Function &F) {
932     return &FAM.getResult<BlockFrequencyAnalysis>(F);
933   };
934
935   if (!annotateAllFunctions(M, ProfileFileName, LookupBPI, LookupBFI))
936     return PreservedAnalyses::all();
937
938   return PreservedAnalyses::none();
939 }
940
941 bool PGOInstrumentationUseLegacyPass::runOnModule(Module &M) {
942   if (skipModule(M))
943     return false;
944
945   auto LookupBPI = [this](Function &F) {
946     return &this->getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI();
947   };
948   auto LookupBFI = [this](Function &F) {
949     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
950   };
951
952   return annotateAllFunctions(M, ProfileFileName, LookupBPI, LookupBFI);
953 }