1 //===----------------------- AlignmentFromAssumptions.cpp -----------------===//
2 // Set Load/Store Alignments From Assumptions
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 //===----------------------------------------------------------------------===//
10 // This file implements a ScalarEvolution-based transformation to set
11 // the alignments of loads, stores and memory intrinsics based on the truth
12 // expressions of assume intrinsics. The primary motivation is to handle
13 // complex alignment assumptions that apply to vector loads and stores that
14 // appear after vectorization and unrolling.
16 //===----------------------------------------------------------------------===//
18 #include "llvm/IR/Instructions.h"
19 #include "llvm/InitializePasses.h"
20 #include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/ADT/Statistic.h"
23 #include "llvm/Analysis/AliasAnalysis.h"
24 #include "llvm/Analysis/AssumptionCache.h"
25 #include "llvm/Analysis/GlobalsModRef.h"
26 #include "llvm/Analysis/LoopInfo.h"
27 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
28 #include "llvm/Analysis/ValueTracking.h"
29 #include "llvm/IR/Constant.h"
30 #include "llvm/IR/Dominators.h"
31 #include "llvm/IR/Instruction.h"
32 #include "llvm/IR/IntrinsicInst.h"
33 #include "llvm/IR/Intrinsics.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include "llvm/Transforms/Scalar.h"
39 #define AA_NAME "alignment-from-assumptions"
40 #define DEBUG_TYPE AA_NAME
// Counters for the number of loads, stores and memory intrinsics whose
// alignment this pass has changed. They are incremented in processAssumption
// right after the corresponding alignment setter is called.
43 STATISTIC(NumLoadAlignChanged,
44 "Number of loads changed by alignment assumptions");
45 STATISTIC(NumStoreAlignChanged,
46 "Number of stores changed by alignment assumptions");
47 STATISTIC(NumMemIntAlignChanged,
48 "Number of memory intrinsics changed by alignment assumptions");
// FunctionPass wrapper around AlignmentFromAssumptionsPass: it declares the
// analyses the transformation needs and holds the implementation object
// (Impl) that runOnFunction forwards to.
// NOTE(review): several lines of this definition (closing braces among them)
// are not visible in this excerpt; the code lines below are kept as found.
51 struct AlignmentFromAssumptions : public FunctionPass {
52 static char ID; // Pass identification, replacement for typeid
// The constructor registers this pass with the global PassRegistry.
53 AlignmentFromAssumptions() : FunctionPass(ID) {
54 initializeAlignmentFromAssumptionsPass(*PassRegistry::getPassRegistry());
57 bool runOnFunction(Function &F) override;
// Analyses required before the pass runs: the assumption cache, scalar
// evolution and the dominator tree.
59 void getAnalysisUsage(AnalysisUsage &AU) const override {
60 AU.addRequired<AssumptionCacheTracker>();
61 AU.addRequired<ScalarEvolutionWrapperPass>();
62 AU.addRequired<DominatorTreeWrapperPass>();
// Analyses declared as still valid after the pass has run.
65 AU.addPreserved<AAResultsWrapperPass>();
66 AU.addPreserved<GlobalsAAWrapperPass>();
67 AU.addPreserved<LoopInfoWrapperPass>();
68 AU.addPreserved<DominatorTreeWrapperPass>();
69 AU.addPreserved<ScalarEvolutionWrapperPass>();
// The implementation object that performs the actual transformation.
72 AlignmentFromAssumptionsPass Impl;
// Definition of the static pass ID, followed by registration of the pass
// under the name AA_NAME with the description string aip_name. The three
// INITIALIZE_PASS_DEPENDENCY entries name the same analyses that
// getAnalysisUsage marks as required.
76 char AlignmentFromAssumptions::ID = 0;
77 static const char aip_name[] = "Alignment from assumptions";
78 INITIALIZE_PASS_BEGIN(AlignmentFromAssumptions, AA_NAME,
79 aip_name, false, false)
80 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
81 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
82 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
83 INITIALIZE_PASS_END(AlignmentFromAssumptions, AA_NAME,
84 aip_name, false, false)
// Factory function: returns a newly heap-allocated AlignmentFromAssumptions
// pass as a FunctionPass pointer.
// NOTE(review): presumably the caller (a pass manager) takes ownership of the
// returned object — confirm. The closing brace is not visible in this excerpt.
86 FunctionPass *llvm::createAlignmentFromAssumptionsPass() {
87 return new AlignmentFromAssumptions();
90 // Given an expression for the (constant) alignment, AlignSCEV, and an
91 // expression for the displacement between a pointer and the aligned address,
92 // DiffSCEV, compute the alignment of the displaced pointer if it can be reduced
93 // to a constant. Using SCEV to compute alignment handles the case where
94 // DiffSCEV is a recurrence with constant start such that the aligned offset
95 // is constant. e.g. {16,+,32} % 32 -> 16.
//
// The result is a MaybeAlign; callers test it for emptiness before using it.
// NOTE(review): the fall-through return for the "no constant remainder" and
// "remainder is not a power of 2" cases is not visible in this excerpt —
// presumably an empty MaybeAlign; confirm against the full file.
96 static MaybeAlign getNewAlignmentDiff(const SCEV *DiffSCEV,
97 const SCEV *AlignSCEV,
98 ScalarEvolution *SE) {
99 // DiffUnits = Diff % int64_t(Alignment)
// The remainder is built with getURemExpr, so SCEV may fold it to a constant.
100 const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV);
102 LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is "
103 << *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");
// Only a remainder that folded to a SCEVConstant can yield an alignment.
105 if (const SCEVConstant *ConstDUSCEV =
106 dyn_cast<SCEVConstant>(DiffUnitsSCEV)) {
// The constant is read as a sign-extended 64-bit value.
107 int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue();
109 // If the displacement is an exact multiple of the alignment, then the
110 // displaced pointer has the same alignment as the aligned pointer, so
111 // return the alignment value.
// NOTE(review): the condition guarding this return (presumably a check that
// DiffUnits is zero) is not visible in this excerpt — confirm. The cast<>
// relies on AlignSCEV being a SCEVConstant.
113 return cast<SCEVConstant>(AlignSCEV)->getValue()->getAlignValue();
115 // If the displacement is not an exact multiple, but the remainder is a
116 // constant, then return this remainder (but only if it is a power of 2).
// NOTE(review): std::abs is undefined for the most negative int64_t value;
// presumably unreachable because the remainder is smaller than the
// alignment — confirm.
117 uint64_t DiffUnitsAbs = std::abs(DiffUnits);
118 if (isPowerOf2_64(DiffUnitsAbs))
119 return Align(DiffUnitsAbs);
125 // There is an address given by an offset OffSCEV from AASCEV which has an
126 // alignment AlignSCEV. Use that information, if possible, to compute a new
127 // alignment for Ptr.
//
// Returns an Align value that callers compare against the current alignment
// of the memory access.
// NOTE(review): every fallback return of this function (after the
// could-not-compute check, after the start/inc checks and at the end) is
// missing from this excerpt — presumably a minimal alignment is returned when
// nothing better can be proven; confirm against the full file.
128 static Align getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
129 const SCEV *OffSCEV, Value *Ptr,
130 ScalarEvolution *SE) {
131 const SCEV *PtrSCEV = SE->getSCEV(Ptr);
132 // On a platform with 32-bit allocas, but 64-bit flat/global pointer sizes
133 // (*cough* AMDGPU), the effective SCEV type of AASCEV and PtrSCEV
134 // may disagree. Trunc/extend so they agree.
135 PtrSCEV = SE->getTruncateOrZeroExtend(
136 PtrSCEV, SE->getEffectiveSCEVType(AASCEV->getType()));
// Displacement of Ptr from the pointer named in the assumption.
137 const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);
// NOTE(review): the statement executed when the difference could not be
// computed is not visible in this excerpt.
138 if (isa<SCEVCouldNotCompute>(DiffSCEV))
141 // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
142 // sign-extended OffSCEV to i64, so make sure they agree again.
143 DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType());
145 // What we really want to know is the overall offset to the aligned
146 // address. This address is displaced by the provided offset.
147 DiffSCEV = SE->getAddExpr(DiffSCEV, OffSCEV);
149 LLVM_DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to "
150 << *AlignSCEV << " and offset " << *OffSCEV
151 << " using diff " << *DiffSCEV << "\n");
// First attempt: the total displacement itself yields a constant alignment.
153 if (MaybeAlign NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE)) {
154 LLVM_DEBUG(dbgs() << "\tnew alignment: " << DebugStr(NewAlignment) << "\n");
155 return *NewAlignment;
// Second attempt: the displacement is an add-recurrence, so reason about its
// start value and its per-iteration step separately.
158 if (const SCEVAddRecExpr *DiffARSCEV = dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
159 // The relative offset to the alignment assumption did not yield a constant,
160 // but we should try harder: if we assume that a is 32-byte aligned, then in
161 // for (i = 0; i < 1024; i += 4) r += a[i]; not all of the loads from a are
162 // 32-byte aligned, but instead alternate between 32 and 16-byte alignment.
163 // As a result, the new alignment will not be a constant, but can still
164 // be improved over the default (of 4) to 16.
166 const SCEV *DiffStartSCEV = DiffARSCEV->getStart();
167 const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE);
169 LLVM_DEBUG(dbgs() << "\ttrying start/inc alignment using start "
170 << *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n");
172 // Now compute the new alignment using the displacement to the value in the
173 // first iteration, and also the alignment using the per-iteration delta.
174 // If these are the same, then use that answer. Otherwise, use the smaller
175 // one, but only if it divides the larger one.
176 MaybeAlign NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
177 MaybeAlign NewIncAlignment =
178 getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);
180 LLVM_DEBUG(dbgs() << "\tnew start alignment: " << DebugStr(NewAlignment)
182 LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << DebugStr(NewIncAlignment)
// Both the start and the step must produce an alignment to continue.
// NOTE(review): the statement taken when either is empty is not visible here.
185 if (!NewAlignment || !NewIncAlignment)
188 const Align NewAlign = *NewAlignment;
189 const Align NewIncAlign = *NewIncAlignment;
// Three cases follow: start alignment larger, step alignment larger, or both
// equal. The debug output in each case prints the smaller of the two.
// NOTE(review): the return statements of these branches are not visible in
// this excerpt — presumably the printed value is what is returned; confirm.
190 if (NewAlign > NewIncAlign) {
191 LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: "
192 << DebugStr(NewIncAlign) << "\n");
195 if (NewIncAlign > NewAlign) {
196 LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
200 assert(NewIncAlign == NewAlign);
201 LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
// Decodes the operand bundle at index Idx of the call I. If the bundle's tag
// is "align", the pointer (AAPtr), the alignment (AlignSCEV) and the offset
// (OffSCEV) are extracted; the alignment and offset are converted to 64-bit
// SCEV expressions. The bool result is tested by processAssumption, which
// only proceeds when it is true.
// NOTE(review): the declarations of the Idx and AAPtr parameters and all
// return statements are not visible in this excerpt; the code lines below are
// kept as found.
209 bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
212 const SCEV *&AlignSCEV,
213 const SCEV *&OffSCEV) {
214 Type *Int64Ty = Type::getInt64Ty(I->getContext());
215 OperandBundleUse AlignOB = I->getOperandBundleAt(Idx);
// Only bundles tagged "align" are of interest.
216 if (AlignOB.getTagName() != "align")
// An align bundle carries at least a pointer and an alignment.
218 assert(AlignOB.Inputs.size() >= 2);
219 AAPtr = AlignOB.Inputs[0].get();
220 // TODO: Consider accumulating the offset to the base.
221 AAPtr = AAPtr->stripPointerCastsSameRepresentation();
// The alignment operand is normalized to i64 by truncation or zero-extension.
222 AlignSCEV = SE->getSCEV(AlignOB.Inputs[1].get());
223 AlignSCEV = SE->getTruncateOrZeroExtend(AlignSCEV, Int64Ty);
224 if (!isa<SCEVConstant>(AlignSCEV))
225 // Added to suppress a crash because consumer doesn't expect non-constant
226 // alignments in the assume bundle. TODO: Consider generalizing caller.
// A third bundle input, when present, is the offset; the other visible
// assignment sets the offset to zero.
// NOTE(review): the line separating the two assignments (presumably an
// `else`) is not visible in this excerpt — confirm.
228 if (AlignOB.Inputs.size() == 3)
229 OffSCEV = SE->getSCEV(AlignOB.Inputs[2].get());
231 OffSCEV = SE->getZero(Int64Ty);
232 OffSCEV = SE->getTruncateOrZeroExtend(OffSCEV, Int64Ty);
// Applies the alignment assumption in operand bundle Idx of ACall. Starting
// from the users of the assumed pointer, it walks a worklist of instructions
// and, for each load, store or memory intrinsic found, computes a new
// alignment with getNewAlignment and installs it when it is larger than the
// current one. The bool result is OR-ed into runImpl's Changed flag.
// NOTE(review): the Idx parameter, the AAPtr declaration, the statements
// guarded by several `if`s (presumably early returns / `continue`s), the
// insertion into Visited and the final return are not visible in this
// excerpt; the code lines below are kept as found.
236 bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall,
239 const SCEV *AlignSCEV, *OffSCEV;
240 if (!extractAlignmentInfo(ACall, Idx, AAPtr, AlignSCEV, OffSCEV))
243 // Skip ConstantPointerNull and UndefValue. Assumptions on these shouldn't
244 // affect other users.
245 if (isa<ConstantData>(AAPtr))
248 const SCEV *AASCEV = SE->getSCEV(AAPtr);
250 // Apply the assumption to all other users of the specified pointer.
251 SmallPtrSet<Instruction *, 32> Visited;
252 SmallVector<Instruction*, 16> WorkList;
// Seed the worklist with the instruction users of the assumed pointer.
253 for (User *J : AAPtr->users()) {
257 if (Instruction *K = dyn_cast<Instruction>(J))
258 WorkList.push_back(K);
261 while (!WorkList.empty()) {
262 Instruction *J = WorkList.pop_back_val();
// Each memory-access kind below first checks isValidAssumeForContext for the
// assume call at this instruction, using the dominator tree.
263 if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
264 if (!isValidAssumeForContext(ACall, J, DT))
266 Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
267 LI->getPointerOperand(), SE);
// Alignment is only ever raised, never lowered.
268 if (NewAlignment > LI->getAlign()) {
269 LI->setAlignment(NewAlignment);
270 ++NumLoadAlignChanged;
272 } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
273 if (!isValidAssumeForContext(ACall, J, DT))
275 Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
276 SI->getPointerOperand(), SE);
277 if (NewAlignment > SI->getAlign()) {
278 SI->setAlignment(NewAlignment);
279 ++NumStoreAlignChanged;
281 } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
282 if (!isValidAssumeForContext(ACall, J, DT))
284 Align NewDestAlignment =
285 getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE);
287 LLVM_DEBUG(dbgs() << "\tmem inst: " << DebugStr(NewDestAlignment)
// NOTE(review): the result of getDestAlign() is dereferenced without a
// visible check; presumably a destination alignment is always present —
// confirm.
289 if (NewDestAlignment > *MI->getDestAlign()) {
290 MI->setDestAlignment(NewDestAlignment);
291 ++NumMemIntAlignChanged;
294 // For memory transfers, there is also a source alignment that
// (the rest of the comment above is missing from this excerpt) — the code
// below computes and installs the source alignment the same way as the
// destination alignment.
296 if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
297 Align NewSrcAlignment =
298 getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MTI->getSource(), SE);
300 LLVM_DEBUG(dbgs() << "\tmem trans: " << DebugStr(NewSrcAlignment)
// NOTE(review): same unchecked dereference as for the destination alignment.
303 if (NewSrcAlignment > *MTI->getSourceAlign()) {
304 MTI->setSourceAlignment(NewSrcAlignment);
305 ++NumMemIntAlignChanged;
310 // Now that we've updated that use of the pointer, look for other uses of
311 // the pointer to update.
// Users of J that are not yet in Visited are queued for processing. cast<>
// asserts that every user of an instruction is itself an instruction.
313 for (User *UJ : J->users()) {
314 Instruction *K = cast<Instruction>(UJ);
315 if (!Visited.count(K))
316 WorkList.push_back(K);
// FunctionPass entry point: fetches the assumption cache for F, the scalar
// evolution analysis and the dominator tree, then forwards to Impl.runImpl
// and returns its result.
// NOTE(review): lines between the signature and the first statement, and the
// closing brace, are not visible in this excerpt.
323 bool AlignmentFromAssumptions::runOnFunction(Function &F) {
327 auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
328 ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
329 DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
331 return Impl.runImpl(F, AC, SE, DT);
// Shared implementation used by both pass entry points. Iterates over every
// assumption recorded in the AssumptionCache and calls processAssumption once
// per operand bundle of each assume call, OR-ing the results into Changed.
// NOTE(review): the statements that use SE_ and DT_ (presumably storing them
// in the SE / DT members used elsewhere), any check on the assumption handle
// before the cast, and the final return are not visible in this excerpt —
// confirm against the full file.
334 bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
335 ScalarEvolution *SE_,
336 DominatorTree *DT_) {
340 bool Changed = false;
341 for (auto &AssumeVH : AC.assumptions())
// cast<> asserts that the cached assumption is a CallInst.
343 CallInst *Call = cast<CallInst>(AssumeVH);
344 for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++)
345 Changed |= processAssumption(Call, Idx);
// Entry point taking a FunctionAnalysisManager: obtains the assumption cache,
// scalar evolution and dominator tree results for F and runs runImpl. When
// runImpl reports no change, all analyses are reported as preserved;
// otherwise a PreservedAnalyses set is built that keeps the CFG analyses and
// ScalarEvolutionAnalysis.
// NOTE(review): the return type line and the final return (presumably of PA)
// are not visible in this excerpt.
352 AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) {
354 AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
355 ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
356 DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
357 if (!runImpl(F, AC, &SE, &DT))
358 return PreservedAnalyses::all();
360 PreservedAnalyses PA;
361 PA.preserveSet<CFGAnalyses>();
362 PA.preserve<ScalarEvolutionAnalysis>();