1 //===----------------------- AlignmentFromAssumptions.cpp -----------------===//
2 // Set Load/Store Alignments From Assumptions
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 //===----------------------------------------------------------------------===//
10 // This file implements a ScalarEvolution-based transformation to set
11 // the alignments of load, stores and memory intrinsics based on the truth
12 // expressions of assume intrinsics. The primary motivation is to handle
13 // complex alignment assumptions that apply to vector loads and stores that
14 // appear after vectorization and unrolling.
16 //===----------------------------------------------------------------------===//
18 #include "llvm/InitializePasses.h"
19 #define AA_NAME "alignment-from-assumptions"
20 #define DEBUG_TYPE AA_NAME
21 #include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/ADT/Statistic.h"
24 #include "llvm/Analysis/AliasAnalysis.h"
25 #include "llvm/Analysis/AssumptionCache.h"
26 #include "llvm/Analysis/GlobalsModRef.h"
27 #include "llvm/Analysis/LoopInfo.h"
28 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
29 #include "llvm/Analysis/ValueTracking.h"
30 #include "llvm/IR/Constant.h"
31 #include "llvm/IR/Dominators.h"
32 #include "llvm/IR/Instruction.h"
33 #include "llvm/IR/IntrinsicInst.h"
34 #include "llvm/IR/Intrinsics.h"
35 #include "llvm/IR/Module.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include "llvm/Transforms/Scalar.h"
41 STATISTIC(NumLoadAlignChanged,
42 "Number of loads changed by alignment assumptions");
43 STATISTIC(NumStoreAlignChanged,
44 "Number of stores changed by alignment assumptions");
45 STATISTIC(NumMemIntAlignChanged,
46 "Number of memory intrinsics changed by alignment assumptions");
49 struct AlignmentFromAssumptions : public FunctionPass {
50 static char ID; // Pass identification, replacement for typeid
51 AlignmentFromAssumptions() : FunctionPass(ID) {
52 initializeAlignmentFromAssumptionsPass(*PassRegistry::getPassRegistry());
55 bool runOnFunction(Function &F) override;
57 void getAnalysisUsage(AnalysisUsage &AU) const override {
58 AU.addRequired<AssumptionCacheTracker>();
59 AU.addRequired<ScalarEvolutionWrapperPass>();
60 AU.addRequired<DominatorTreeWrapperPass>();
63 AU.addPreserved<AAResultsWrapperPass>();
64 AU.addPreserved<GlobalsAAWrapperPass>();
65 AU.addPreserved<LoopInfoWrapperPass>();
66 AU.addPreserved<DominatorTreeWrapperPass>();
67 AU.addPreserved<ScalarEvolutionWrapperPass>();
70 AlignmentFromAssumptionsPass Impl;
74 char AlignmentFromAssumptions::ID = 0;
75 static const char aip_name[] = "Alignment from assumptions";
76 INITIALIZE_PASS_BEGIN(AlignmentFromAssumptions, AA_NAME,
77 aip_name, false, false)
78 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
79 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
80 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
81 INITIALIZE_PASS_END(AlignmentFromAssumptions, AA_NAME,
82 aip_name, false, false)
84 FunctionPass *llvm::createAlignmentFromAssumptionsPass() {
85 return new AlignmentFromAssumptions();
88 // Given an expression for the (constant) alignment, AlignSCEV, and an
89 // expression for the displacement between a pointer and the aligned address,
90 // DiffSCEV, compute the alignment of the displaced pointer if it can be reduced
91 // to a constant. Using SCEV to compute alignment handles the case where
92 // DiffSCEV is a recurrence with constant start such that the aligned offset
93 // is constant. e.g. {16,+,32} % 32 -> 16.
94 static MaybeAlign getNewAlignmentDiff(const SCEV *DiffSCEV,
95 const SCEV *AlignSCEV,
96 ScalarEvolution *SE) {
97 // DiffUnits = Diff % int64_t(Alignment)
98 const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV);
100 LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is "
101 << *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");
103 if (const SCEVConstant *ConstDUSCEV =
104 dyn_cast<SCEVConstant>(DiffUnitsSCEV)) {
105 int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue();
107 // If the displacement is an exact multiple of the alignment, then the
108 // displaced pointer has the same alignment as the aligned pointer, so
109 // return the alignment value.
111 return cast<SCEVConstant>(AlignSCEV)->getValue()->getAlignValue();
113 // If the displacement is not an exact multiple, but the remainder is a
114 // constant, then return this remainder (but only if it is a power of 2).
115 uint64_t DiffUnitsAbs = std::abs(DiffUnits);
116 if (isPowerOf2_64(DiffUnitsAbs))
117 return Align(DiffUnitsAbs);
123 // There is an address given by an offset OffSCEV from AASCEV which has an
124 // alignment AlignSCEV. Use that information, if possible, to compute a new
125 // alignment for Ptr.
126 static Align getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
127 const SCEV *OffSCEV, Value *Ptr,
128 ScalarEvolution *SE) {
129 const SCEV *PtrSCEV = SE->getSCEV(Ptr);
130 // On a platform with 32-bit allocas, but 64-bit flat/global pointer sizes
131 // (*cough* AMDGPU), the effective SCEV type of AASCEV and PtrSCEV
132 // may disagree. Trunc/extend so they agree.
133 PtrSCEV = SE->getTruncateOrZeroExtend(
134 PtrSCEV, SE->getEffectiveSCEVType(AASCEV->getType()));
135 const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);
137 // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
138 // sign-extended OffSCEV to i64, so make sure they agree again.
139 DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType());
141 // What we really want to know is the overall offset to the aligned
142 // address. This address is displaced by the provided offset.
143 DiffSCEV = SE->getMinusSCEV(DiffSCEV, OffSCEV);
145 LLVM_DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to "
146 << *AlignSCEV << " and offset " << *OffSCEV
147 << " using diff " << *DiffSCEV << "\n");
149 if (MaybeAlign NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE)) {
150 LLVM_DEBUG(dbgs() << "\tnew alignment: " << DebugStr(NewAlignment) << "\n");
151 return *NewAlignment;
154 if (const SCEVAddRecExpr *DiffARSCEV = dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
155 // The relative offset to the alignment assumption did not yield a constant,
156 // but we should try harder: if we assume that a is 32-byte aligned, then in
157 // for (i = 0; i < 1024; i += 4) r += a[i]; not all of the loads from a are
158 // 32-byte aligned, but instead alternate between 32 and 16-byte alignment.
159 // As a result, the new alignment will not be a constant, but can still
160 // be improved over the default (of 4) to 16.
162 const SCEV *DiffStartSCEV = DiffARSCEV->getStart();
163 const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE);
165 LLVM_DEBUG(dbgs() << "\ttrying start/inc alignment using start "
166 << *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n");
168 // Now compute the new alignment using the displacement to the value in the
169 // first iteration, and also the alignment using the per-iteration delta.
170 // If these are the same, then use that answer. Otherwise, use the smaller
171 // one, but only if it divides the larger one.
172 MaybeAlign NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
173 MaybeAlign NewIncAlignment =
174 getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);
176 LLVM_DEBUG(dbgs() << "\tnew start alignment: " << DebugStr(NewAlignment)
178 LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << DebugStr(NewIncAlignment)
181 if (!NewAlignment || !NewIncAlignment)
184 const Align NewAlign = *NewAlignment;
185 const Align NewIncAlign = *NewIncAlignment;
186 if (NewAlign > NewIncAlign) {
187 LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: "
188 << DebugStr(NewIncAlign) << "\n");
191 if (NewIncAlign > NewAlign) {
192 LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
196 assert(NewIncAlign == NewAlign);
197 LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
205 bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
207 const SCEV *&AlignSCEV,
208 const SCEV *&OffSCEV) {
209 // An alignment assume must be a statement about the least-significant
210 // bits of the pointer being zero, possibly with some offset.
211 ICmpInst *ICI = dyn_cast<ICmpInst>(I->getArgOperand(0));
215 // This must be an expression of the form: x & m == 0.
216 if (ICI->getPredicate() != ICmpInst::ICMP_EQ)
219 // Swap things around so that the RHS is 0.
220 Value *CmpLHS = ICI->getOperand(0);
221 Value *CmpRHS = ICI->getOperand(1);
222 const SCEV *CmpLHSSCEV = SE->getSCEV(CmpLHS);
223 const SCEV *CmpRHSSCEV = SE->getSCEV(CmpRHS);
224 if (CmpLHSSCEV->isZero())
225 std::swap(CmpLHS, CmpRHS);
226 else if (!CmpRHSSCEV->isZero())
229 BinaryOperator *CmpBO = dyn_cast<BinaryOperator>(CmpLHS);
230 if (!CmpBO || CmpBO->getOpcode() != Instruction::And)
233 // Swap things around so that the right operand of the and is a constant
234 // (the mask); we cannot deal with variable masks.
235 Value *AndLHS = CmpBO->getOperand(0);
236 Value *AndRHS = CmpBO->getOperand(1);
237 const SCEV *AndLHSSCEV = SE->getSCEV(AndLHS);
238 const SCEV *AndRHSSCEV = SE->getSCEV(AndRHS);
239 if (isa<SCEVConstant>(AndLHSSCEV)) {
240 std::swap(AndLHS, AndRHS);
241 std::swap(AndLHSSCEV, AndRHSSCEV);
244 const SCEVConstant *MaskSCEV = dyn_cast<SCEVConstant>(AndRHSSCEV);
248 // The mask must have some trailing ones (otherwise the condition is
249 // trivial and tells us nothing about the alignment of the left operand).
250 unsigned TrailingOnes = MaskSCEV->getAPInt().countTrailingOnes();
254 // Cap the alignment at the maximum with which LLVM can deal (and make sure
255 // we don't overflow the shift).
257 TrailingOnes = std::min(TrailingOnes,
258 unsigned(sizeof(unsigned) * CHAR_BIT - 1));
259 Alignment = std::min(1u << TrailingOnes, +Value::MaximumAlignment);
261 Type *Int64Ty = Type::getInt64Ty(I->getParent()->getParent()->getContext());
262 AlignSCEV = SE->getConstant(Int64Ty, Alignment);
264 // The LHS might be a ptrtoint instruction, or it might be the pointer
268 if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(AndLHS)) {
269 AAPtr = PToI->getPointerOperand();
270 OffSCEV = SE->getZero(Int64Ty);
271 } else if (const SCEVAddExpr* AndLHSAddSCEV =
272 dyn_cast<SCEVAddExpr>(AndLHSSCEV)) {
273 // Try to find the ptrtoint; subtract it and the rest is the offset.
274 for (SCEVAddExpr::op_iterator J = AndLHSAddSCEV->op_begin(),
275 JE = AndLHSAddSCEV->op_end(); J != JE; ++J)
276 if (const SCEVUnknown *OpUnk = dyn_cast<SCEVUnknown>(*J))
277 if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(OpUnk->getValue())) {
278 AAPtr = PToI->getPointerOperand();
279 OffSCEV = SE->getMinusSCEV(AndLHSAddSCEV, *J);
287 // Sign extend the offset to 64 bits (so that it is like all of the other
289 unsigned OffSCEVBits = OffSCEV->getType()->getPrimitiveSizeInBits();
290 if (OffSCEVBits < 64)
291 OffSCEV = SE->getSignExtendExpr(OffSCEV, Int64Ty);
292 else if (OffSCEVBits > 64)
295 AAPtr = AAPtr->stripPointerCasts();
299 bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
301 const SCEV *AlignSCEV, *OffSCEV;
302 if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV))
305 // Skip ConstantPointerNull and UndefValue. Assumptions on these shouldn't
306 // affect other users.
307 if (isa<ConstantData>(AAPtr))
310 const SCEV *AASCEV = SE->getSCEV(AAPtr);
312 // Apply the assumption to all other users of the specified pointer.
313 SmallPtrSet<Instruction *, 32> Visited;
314 SmallVector<Instruction*, 16> WorkList;
315 for (User *J : AAPtr->users()) {
319 if (Instruction *K = dyn_cast<Instruction>(J))
320 if (isValidAssumeForContext(ACall, K, DT))
321 WorkList.push_back(K);
324 while (!WorkList.empty()) {
325 Instruction *J = WorkList.pop_back_val();
326 if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
327 Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
328 LI->getPointerOperand(), SE);
329 if (NewAlignment > LI->getAlign()) {
330 LI->setAlignment(NewAlignment);
331 ++NumLoadAlignChanged;
333 } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
334 Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
335 SI->getPointerOperand(), SE);
336 if (NewAlignment > SI->getAlign()) {
337 SI->setAlignment(NewAlignment);
338 ++NumStoreAlignChanged;
340 } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
341 Align NewDestAlignment =
342 getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE);
344 LLVM_DEBUG(dbgs() << "\tmem inst: " << DebugStr(NewDestAlignment)
346 if (NewDestAlignment > *MI->getDestAlign()) {
347 MI->setDestAlignment(NewDestAlignment);
348 ++NumMemIntAlignChanged;
351 // For memory transfers, there is also a source alignment that
353 if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
354 Align NewSrcAlignment =
355 getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MTI->getSource(), SE);
357 LLVM_DEBUG(dbgs() << "\tmem trans: " << DebugStr(NewSrcAlignment)
360 if (NewSrcAlignment > *MTI->getSourceAlign()) {
361 MTI->setSourceAlignment(NewSrcAlignment);
362 ++NumMemIntAlignChanged;
367 // Now that we've updated that use of the pointer, look for other uses of
368 // the pointer to update.
370 for (User *UJ : J->users()) {
371 Instruction *K = cast<Instruction>(UJ);
372 if (!Visited.count(K) && isValidAssumeForContext(ACall, K, DT))
373 WorkList.push_back(K);
380 bool AlignmentFromAssumptions::runOnFunction(Function &F) {
384 auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
385 ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
386 DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
388 return Impl.runImpl(F, AC, SE, DT);
391 bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
392 ScalarEvolution *SE_,
393 DominatorTree *DT_) {
397 bool Changed = false;
398 for (auto &AssumeVH : AC.assumptions())
400 Changed |= processAssumption(cast<CallInst>(AssumeVH));
406 AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) {
408 AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
409 ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
410 DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
411 if (!runImpl(F, AC, &SE, &DT))
412 return PreservedAnalyses::all();
414 PreservedAnalyses PA;
415 PA.preserveSet<CFGAnalyses>();
416 PA.preserve<AAManager>();
417 PA.preserve<ScalarEvolutionAnalysis>();
418 PA.preserve<GlobalsAA>();