//===- AssumptionCache.cpp - Cache finding @llvm.assume calls -------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
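//
// This file implements AssumptionCache, which keeps track of the @llvm.assume
// intrinsics in a function and of the values each assumption may provide
// information about. It also implements AssumptionCacheTracker, an immutable
// pass that owns one AssumptionCache per function of a module, and
// AssumptionPrinterPass, which prints the condition of every cached assumption.
//
//===----------------------------------------------------------------------===//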
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include <algorithm>
#include <utility>

using namespace llvm;
using namespace llvm::PatternMatch;
27 SmallVector<WeakVH, 1> &AssumptionCache::getAffectedValues(Value *V) {
28 // Try using find_as first to avoid creating extra value handles just for the
29 // purpose of doing the lookup.
30 auto AVI = AffectedValues.find_as(V);
31 if (AVI != AffectedValues.end())
34 auto AVIP = AffectedValues.insert({
35 AffectedValueCallbackVH(V, this), SmallVector<WeakVH, 1>()});
36 return AVIP.first->second;
39 void AssumptionCache::updateAffectedValues(CallInst *CI) {
40 // Note: This code must be kept in-sync with the code in
41 // computeKnownBitsFromAssume in ValueTracking.
43 SmallVector<Value *, 16> Affected;
44 auto AddAffected = [&Affected](Value *V) {
45 if (isa<Argument>(V)) {
46 Affected.push_back(V);
47 } else if (auto *I = dyn_cast<Instruction>(V)) {
48 Affected.push_back(I);
50 if (I->getOpcode() == Instruction::BitCast ||
51 I->getOpcode() == Instruction::PtrToInt) {
52 auto *Op = I->getOperand(0);
53 if (isa<Instruction>(Op) || isa<Argument>(Op))
54 Affected.push_back(Op);
59 Value *Cond = CI->getArgOperand(0), *A, *B;
62 CmpInst::Predicate Pred;
63 if (match(Cond, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
67 if (Pred == ICmpInst::ICMP_EQ) {
68 // For equality comparisons, we handle the case of bit inversion.
69 auto AddAffectedFromEq = [&AddAffected](Value *V) {
71 if (match(V, m_Not(m_Value(A)))) {
78 // (A & B) or (A | B) or (A ^ B).
80 m_CombineOr(m_And(m_Value(A), m_Value(B)),
81 m_CombineOr(m_Or(m_Value(A), m_Value(B)),
82 m_Xor(m_Value(A), m_Value(B)))))) {
85 // (A << C) or (A >>_s C) or (A >>_u C) where C is some constant.
87 m_CombineOr(m_Shl(m_Value(A), m_ConstantInt(C)),
88 m_CombineOr(m_LShr(m_Value(A), m_ConstantInt(C)),
90 m_ConstantInt(C)))))) {
100 for (auto &AV : Affected) {
101 auto &AVV = getAffectedValues(AV);
102 if (std::find(AVV.begin(), AVV.end(), CI) == AVV.end())
107 void AssumptionCache::AffectedValueCallbackVH::deleted() {
108 auto AVI = AC->AffectedValues.find(getValPtr());
109 if (AVI != AC->AffectedValues.end())
110 AC->AffectedValues.erase(AVI);
111 // 'this' now dangles!
114 void AssumptionCache::AffectedValueCallbackVH::allUsesReplacedWith(Value *NV) {
115 if (!isa<Instruction>(NV) && !isa<Argument>(NV))
118 // Any assumptions that affected this value now affect the new value.
120 auto &NAVV = AC->getAffectedValues(NV);
121 auto AVI = AC->AffectedValues.find(getValPtr());
122 if (AVI == AC->AffectedValues.end())
125 for (auto &A : AVI->second)
126 if (std::find(NAVV.begin(), NAVV.end(), A) == NAVV.end())
130 void AssumptionCache::scanFunction() {
131 assert(!Scanned && "Tried to scan the function twice!");
132 assert(AssumeHandles.empty() && "Already have assumes when scanning!");
134 // Go through all instructions in all blocks, add all calls to @llvm.assume
136 for (BasicBlock &B : F)
137 for (Instruction &II : B)
138 if (match(&II, m_Intrinsic<Intrinsic::assume>()))
139 AssumeHandles.push_back(&II);
141 // Mark the scan as complete.
144 // Update affected values.
145 for (auto &A : AssumeHandles)
146 updateAffectedValues(cast<CallInst>(A));
149 void AssumptionCache::registerAssumption(CallInst *CI) {
150 assert(match(CI, m_Intrinsic<Intrinsic::assume>()) &&
151 "Registered call does not call @llvm.assume");
153 // If we haven't scanned the function yet, just drop this assumption. It will
154 // be found when we scan later.
158 AssumeHandles.push_back(CI);
161 assert(CI->getParent() &&
162 "Cannot register @llvm.assume call not in a basic block");
163 assert(&F == CI->getParent()->getParent() &&
164 "Cannot register @llvm.assume call not in this function");
166 // We expect the number of assumptions to be small, so in an asserts build
167 // check that we don't accumulate duplicates and that all assumptions point
168 // to the same function.
169 SmallPtrSet<Value *, 16> AssumptionSet;
170 for (auto &VH : AssumeHandles) {
174 assert(&F == cast<Instruction>(VH)->getParent()->getParent() &&
175 "Cached assumption not inside this function!");
176 assert(match(cast<CallInst>(VH), m_Intrinsic<Intrinsic::assume>()) &&
177 "Cached something other than a call to @llvm.assume!");
178 assert(AssumptionSet.insert(VH).second &&
179 "Cache contains multiple copies of a call!");
183 updateAffectedValues(CI);
186 AnalysisKey AssumptionAnalysis::Key;
188 PreservedAnalyses AssumptionPrinterPass::run(Function &F,
189 FunctionAnalysisManager &AM) {
190 AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
192 OS << "Cached assumptions for function: " << F.getName() << "\n";
193 for (auto &VH : AC.assumptions())
195 OS << " " << *cast<CallInst>(VH)->getArgOperand(0) << "\n";
197 return PreservedAnalyses::all();
200 void AssumptionCacheTracker::FunctionCallbackVH::deleted() {
201 auto I = ACT->AssumptionCaches.find_as(cast<Function>(getValPtr()));
202 if (I != ACT->AssumptionCaches.end())
203 ACT->AssumptionCaches.erase(I);
204 // 'this' now dangles!
207 AssumptionCache &AssumptionCacheTracker::getAssumptionCache(Function &F) {
208 // We probe the function map twice to try and avoid creating a value handle
209 // around the function in common cases. This makes insertion a bit slower,
210 // but if we have to insert we're going to scan the whole function so that
212 auto I = AssumptionCaches.find_as(&F);
213 if (I != AssumptionCaches.end())
216 // Ok, build a new cache by scanning the function, insert it and the value
217 // handle into our map, and return the newly populated cache.
218 auto IP = AssumptionCaches.insert(std::make_pair(
219 FunctionCallbackVH(&F, this), llvm::make_unique<AssumptionCache>(F)));
220 assert(IP.second && "Scanning function already in the map?");
221 return *IP.first->second;
224 void AssumptionCacheTracker::verifyAnalysis() const {
226 SmallPtrSet<const CallInst *, 4> AssumptionSet;
227 for (const auto &I : AssumptionCaches) {
228 for (auto &VH : I.second->assumptions())
230 AssumptionSet.insert(cast<CallInst>(VH));
232 for (const BasicBlock &B : cast<Function>(*I.first))
233 for (const Instruction &II : B)
234 if (match(&II, m_Intrinsic<Intrinsic::assume>()))
235 assert(AssumptionSet.count(cast<CallInst>(&II)) &&
236 "Assumption in scanned function not in cache");
241 AssumptionCacheTracker::AssumptionCacheTracker() : ImmutablePass(ID) {
242 initializeAssumptionCacheTrackerPass(*PassRegistry::getPassRegistry());
245 AssumptionCacheTracker::~AssumptionCacheTracker() {}
247 INITIALIZE_PASS(AssumptionCacheTracker, "assumption-cache-tracker",
248 "Assumption Cache Tracker", false, true)
249 char AssumptionCacheTracker::ID = 0;