//===- MLInlineAdvisor.h - ML-based InlineAdvisor factories -----*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H
10 #define LLVM_ANALYSIS_MLINLINEADVISOR_H
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/InlineAdvisor.h"
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/IR/PassManager.h"

#include <map>
#include <memory>
#include <unordered_map>
24 class MLInlineAdvisor : public InlineAdvisor {
26 MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM,
27 std::unique_ptr<MLModelRunner> ModelRunner);
29 CallGraph *callGraph() const { return CG.get(); }
30 virtual ~MLInlineAdvisor() = default;
32 void onPassEntry() override;
34 std::unique_ptr<InlineAdvice> getAdvice(CallBase &CB) override;
36 int64_t getIRSize(const Function &F) const { return F.getInstructionCount(); }
37 void onSuccessfulInlining(const MLInlineAdvice &Advice,
38 bool CalleeWasDeleted);
40 bool isForcedToStop() const { return ForceStop; }
41 int64_t getLocalCalls(Function &F);
42 const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); }
45 virtual std::unique_ptr<MLInlineAdvice>
46 getMandatoryAdvice(CallBase &CB, OptimizationRemarkEmitter &ORE);
48 virtual std::unique_ptr<MLInlineAdvice>
49 getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE);
52 std::unique_ptr<MLModelRunner> ModelRunner;
55 int64_t getModuleIRSize() const;
57 std::unique_ptr<CallGraph> CG;
59 int64_t NodeCount = 0;
60 int64_t EdgeCount = 0;
61 std::map<const Function *, unsigned> FunctionLevels;
62 const int32_t InitialIRSize = 0;
63 int32_t CurrentIRSize = 0;
65 bool ForceStop = false;
68 /// InlineAdvice that tracks changes post inlining. For that reason, it only
69 /// overrides the "successful inlining" extension points.
70 class MLInlineAdvice : public InlineAdvice {
72 MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
73 OptimizationRemarkEmitter &ORE, bool Recommendation)
74 : InlineAdvice(Advisor, CB, ORE, Recommendation),
75 CallerIRSize(Advisor->isForcedToStop() ? 0
76 : Advisor->getIRSize(*Caller)),
77 CalleeIRSize(Advisor->isForcedToStop() ? 0
78 : Advisor->getIRSize(*Callee)),
79 CallerAndCalleeEdges(Advisor->isForcedToStop()
81 : (Advisor->getLocalCalls(*Caller) +
82 Advisor->getLocalCalls(*Callee))) {}
83 virtual ~MLInlineAdvice() = default;
85 void recordInliningImpl() override;
86 void recordInliningWithCalleeDeletedImpl() override;
87 void recordUnsuccessfulInliningImpl(const InlineResult &Result) override;
88 void recordUnattemptedInliningImpl() override;
90 Function *getCaller() const { return Caller; }
91 Function *getCallee() const { return Callee; }
93 const int64_t CallerIRSize;
94 const int64_t CalleeIRSize;
95 const int64_t CallerAndCalleeEdges;
98 void reportContextForRemark(DiagnosticInfoOptimizationBase &OR);
100 MLInlineAdvisor *getAdvisor() const {
101 return static_cast<MLInlineAdvisor *>(Advisor);
107 #endif // LLVM_ANALYSIS_MLINLINEADVISOR_H