//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the scalar evolution analysis
// engine, which is used primarily to analyze expressions involving induction
// variables in loops.
//
// There are several aspects to this library.  First is the representation of
// scalar expressions, which are represented as subclasses of the SCEV class.
// These classes are used to represent certain types of subexpressions that we
// can handle.  We only create one SCEV of a particular shape, so
// pointer-comparisons for equality are legal.
//
// One important aspect of the SCEV objects is that they are never cyclic, even
// if there is a cycle in the dataflow for an expression (ie, a PHI node).  If
// the PHI node is one of the idioms that we can represent (e.g., a polynomial
// recurrence) then we represent it directly as a recurrence node, otherwise we
// represent it as a SCEVUnknown node.
//
// In addition to being able to represent expressions of various types, we also
// have folders that are used to build the *canonical* representation for a
// particular expression.  These folders are capable of using a variety of
// rewrite rules to simplify the expressions.
//
// Once the folders are defined, we can implement the more interesting
// higher-level code, such as the code that recognizes PHI nodes of various
// types, computes the execution count of a loop, etc.
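//
// For example, an induction variable that starts at zero and advances by
// four on each iteration of loop %loop is represented (and printed) as the
// add recurrence {0,+,4}<%loop>; a second-order recurrence such as
// {0,+,1,+,1}<%loop> evaluates to the triangular numbers 0, 1, 3, 6, ...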
//
// TODO: We should use these routines and value representations to implement
// dependence analysis!
//
//===----------------------------------------------------------------------===//
//
// There are several good references for the techniques used in this analysis.
//
//  Chains of recurrences -- a method to expedite the evaluation
//  of closed-form functions
//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
//
//  On computational properties of chains of recurrences
//  Eugene V. Zima
//
//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
//  Robert A. van Engelen
//
//  Efficient Symbolic Analysis for Optimizing Compilers
//  Robert A. van Engelen
//
//  Using the chains of recurrences algebra for data dependence testing and
//  induction variable substitution
//  MS Thesis, Johnie Birch
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionDivision.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/Verifier.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstdint>
#include <cstdlib>
#include <map>
#include <memory>
#include <numeric>
#include <tuple>
#include <utility>
#include <vector>
using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "scalar-evolution"

STATISTIC(NumTripCountsComputed,
          "Number of loops with predictable loop counts");
STATISTIC(NumTripCountsNotComputed,
          "Number of loops without predictable loop counts");
STATISTIC(NumBruteForceTripCountsComputed,
          "Number of loops with trip counts computed by force");
static cl::opt<unsigned>
    MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                            cl::desc("Maximum number of iterations SCEV will "
                                     "symbolically execute a constant "
                                     "derived loop"),
                            cl::init(100));

// FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean.
static cl::opt<bool> VerifySCEV(
    "verify-scev", cl::Hidden,
    cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
static cl::opt<bool> VerifySCEVStrict(
    "verify-scev-strict", cl::Hidden,
    cl::desc("Enable stricter verification when -verify-scev is passed"));
static cl::opt<bool>
    VerifySCEVMap("verify-scev-maps", cl::Hidden,
                  cl::desc("Verify no dangling value in ScalarEvolution's "
                           "ExprValueMap (slow)"));

static cl::opt<bool> VerifyIR(
    "scev-verify-ir", cl::Hidden,
    cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
    cl::init(false));

static cl::opt<unsigned> MulOpsInlineThreshold(
    "scev-mulops-inline-threshold", cl::Hidden,
    cl::desc("Threshold for inlining multiplication operands into a SCEV"),
    cl::init(32));

static cl::opt<unsigned> AddOpsInlineThreshold(
    "scev-addops-inline-threshold", cl::Hidden,
    cl::desc("Threshold for inlining addition operands into a SCEV"),
    cl::init(500));

static cl::opt<unsigned> MaxSCEVCompareDepth(
    "scalar-evolution-max-scev-compare-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
    cl::init(32));

static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
    "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
    cl::init(2));

static cl::opt<unsigned> MaxValueCompareDepth(
    "scalar-evolution-max-value-compare-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive value complexity comparisons"),
    cl::init(2));

static cl::opt<unsigned>
    MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
                  cl::desc("Maximum depth of recursive arithmetics"),
                  cl::init(32));

static cl::opt<unsigned> MaxConstantEvolvingDepth(
    "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));

static cl::opt<unsigned>
    MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
                 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
                 cl::init(8));

static cl::opt<unsigned>
    MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
                  cl::desc("Max coefficients in AddRec during evolving"),
                  cl::init(16));

static cl::opt<unsigned>
    HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
                      cl::desc("Size of the expression which is considered huge"),
                      cl::init(4096));

static cl::opt<bool>
ClassifyExpressions("scalar-evolution-classify-expressions",
    cl::Hidden, cl::init(true),
    cl::desc("When printing analysis, include information on every instruction"));

static cl::opt<bool> UseExpensiveRangeSharpening(
    "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
    cl::init(false),
    cl::desc("Use more powerful methods of sharpening expression ranges. May "
             "be costly in terms of compile time"));
//===----------------------------------------------------------------------===//
//                           SCEV class definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Implementation of the SCEV class.
//

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
LLVM_DUMP_METHOD void SCEV::dump() const {
  print(dbgs());
  dbgs() << '\n';
}
#endif
void SCEV::print(raw_ostream &OS) const {
  switch (getSCEVType()) {
  case scConstant:
    cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
    return;
  case scPtrToInt: {
    const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
    const SCEV *Op = PtrToInt->getOperand();
    OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
       << *PtrToInt->getType() << ")";
    return;
  }
  case scTruncate: {
    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
    const SCEV *Op = Trunc->getOperand();
    OS << "(trunc " << *Op->getType() << " " << *Op << " to "
       << *Trunc->getType() << ")";
    return;
  }
  case scZeroExtend: {
    const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
    const SCEV *Op = ZExt->getOperand();
    OS << "(zext " << *Op->getType() << " " << *Op << " to "
       << *ZExt->getType() << ")";
    return;
  }
  case scSignExtend: {
    const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
    const SCEV *Op = SExt->getOperand();
    OS << "(sext " << *Op->getType() << " " << *Op << " to "
       << *SExt->getType() << ")";
    return;
  }
  case scAddRecExpr: {
    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
    OS << "{" << *AR->getOperand(0);
    for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
      OS << ",+," << *AR->getOperand(i);
    OS << "}<";
    if (AR->hasNoUnsignedWrap())
      OS << "nuw><";
    if (AR->hasNoSignedWrap())
      OS << "nsw><";
    if (AR->hasNoSelfWrap() &&
        !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
      OS << "nw><";
    AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
    OS << ">";
    return;
  }
  case scAddExpr:
  case scMulExpr:
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
  case scSequentialUMinExpr: {
    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
    const char *OpStr = nullptr;
    switch (NAry->getSCEVType()) {
    case scAddExpr: OpStr = " + "; break;
    case scMulExpr: OpStr = " * "; break;
    case scUMaxExpr: OpStr = " umax "; break;
    case scSMaxExpr: OpStr = " smax "; break;
    case scUMinExpr:
      OpStr = " umin ";
      break;
    case scSMinExpr:
      OpStr = " smin ";
      break;
    case scSequentialUMinExpr:
      OpStr = " umin_seq ";
      break;
    default:
      llvm_unreachable("There are no other nary expression types.");
    }
    OS << "(";
    ListSeparator LS(OpStr);
    for (const SCEV *Op : NAry->operands())
      OS << LS << *Op;
    OS << ")";
    switch (NAry->getSCEVType()) {
    case scAddExpr:
    case scMulExpr:
      if (NAry->hasNoUnsignedWrap())
        OS << "<nuw>";
      if (NAry->hasNoSignedWrap())
        OS << "<nsw>";
      break;
    default:
      // Nothing to print for other nary expressions.
      break;
    }
    return;
  }
  case scUDivExpr: {
    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
    OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
    return;
  }
  case scUnknown: {
    const SCEVUnknown *U = cast<SCEVUnknown>(this);
    Type *AllocTy;
    if (U->isSizeOf(AllocTy)) {
      OS << "sizeof(" << *AllocTy << ")";
      return;
    }
    if (U->isAlignOf(AllocTy)) {
      OS << "alignof(" << *AllocTy << ")";
      return;
    }

    Type *CTy;
    Constant *FieldNo;
    if (U->isOffsetOf(CTy, FieldNo)) {
      OS << "offsetof(" << *CTy << ", ";
      FieldNo->printAsOperand(OS, false);
      OS << ")";
      return;
    }

    // Otherwise just print it normally.
    U->getValue()->printAsOperand(OS, false);
    return;
  }
  case scCouldNotCompute:
    OS << "***COULDNOTCOMPUTE***";
    return;
  }
  llvm_unreachable("Unknown SCEV kind!");
}
Type *SCEV::getType() const {
  switch (getSCEVType()) {
  case scConstant:
    return cast<SCEVConstant>(this)->getType();
  case scPtrToInt:
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
    return cast<SCEVCastExpr>(this)->getType();
  case scAddRecExpr:
    return cast<SCEVAddRecExpr>(this)->getType();
  case scMulExpr:
    return cast<SCEVMulExpr>(this)->getType();
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
    return cast<SCEVMinMaxExpr>(this)->getType();
  case scSequentialUMinExpr:
    return cast<SCEVSequentialMinMaxExpr>(this)->getType();
  case scAddExpr:
    return cast<SCEVAddExpr>(this)->getType();
  case scUDivExpr:
    return cast<SCEVUDivExpr>(this)->getType();
  case scUnknown:
    return cast<SCEVUnknown>(this)->getType();
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}
bool SCEV::isZero() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isZero();
  return false;
}

bool SCEV::isOne() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isOne();
  return false;
}

bool SCEV::isAllOnesValue() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isMinusOne();
  return false;
}

bool SCEV::isNonConstantNegative() const {
  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
  if (!Mul) return false;

  // If there is a constant factor, it will be first.
  const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
  if (!SC) return false;

  // Return true if the value is negative; this matches things like (-42 * V).
  return SC->getAPInt().isNegative();
}
SCEVCouldNotCompute::SCEVCouldNotCompute() :
  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}

bool SCEVCouldNotCompute::classof(const SCEV *S) {
  return S->getSCEVType() == scCouldNotCompute;
}

const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
  FoldingSetNodeID ID;
  ID.AddInteger(scConstant);
  ID.AddPointer(V);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
  return getConstant(ConstantInt::get(getContext(), Val));
}

const SCEV *
ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
  return getConstant(ConstantInt::get(ITy, V, isSigned));
}
SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
                           const SCEV *op, Type *ty)
    : SCEV(ID, SCEVTy, computeExpressionSize(op)), Ty(ty) {
  Operands[0] = op;
}

SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
                                   Type *ITy)
    : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
  assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
         "Must be a non-bit-width-changing pointer-to-integer cast!");
}

SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
                                           SCEVTypes SCEVTy, const SCEV *op,
                                           Type *ty)
    : SCEVCastExpr(ID, SCEVTy, op, ty) {}

SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
                                   Type *ty)
    : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot truncate non-integer value!");
}

SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, Type *ty)
    : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot zero extend non-integer value!");
}

SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, Type *ty)
    : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot sign extend non-integer value!");
}
void SCEVUnknown::deleted() {
  // Clear this SCEVUnknown from various maps.
  SE->forgetMemoizedResults(this);

  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Release the value.
  setValPtr(nullptr);
}

void SCEVUnknown::allUsesReplacedWith(Value *New) {
  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Update this SCEVUnknown to point to the new value. This is needed
  // because there may still be outstanding SCEVs which still point to
  // this SCEVUnknown.
  setValPtr(New);
}
bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue() &&
            CE->getNumOperands() == 2)
          if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
            if (CI->isOne()) {
              AllocTy = cast<GEPOperator>(CE)->getSourceElementType();
              return true;
            }

  return false;
}

bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue()) {
          Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
          if (StructType *STy = dyn_cast<StructType>(Ty))
            if (!STy->isPacked() &&
                CE->getNumOperands() == 3 &&
                CE->getOperand(1)->isNullValue()) {
              if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
                if (CI->isOne() &&
                    STy->getNumElements() == 2 &&
                    STy->getElementType(0)->isIntegerTy(1)) {
                  AllocTy = STy->getElementType(1);
                  return true;
                }
            }
        }

  return false;
}

bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getNumOperands() == 3 &&
            CE->getOperand(0)->isNullValue() &&
            CE->getOperand(1)->isNullValue()) {
          Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
          // Ignore vector types here so that ScalarEvolutionExpander doesn't
          // emit getelementptrs that index into vectors.
          if (Ty->isStructTy() || Ty->isArrayTy()) {
            CTy = Ty;
            FieldNo = CE->getOperand(2);
            return true;
          }
        }

  return false;
}
//===----------------------------------------------------------------------===//
//                               SCEV Utilities
//===----------------------------------------------------------------------===//

/// Compare the two values \p LV and \p RV in terms of their "complexity" where
/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
/// operands in SCEV expressions.  \p EqCache is a set of pairs of values that
/// have been previously deemed to be "equally complex" by this routine.  It is
/// intended to avoid exponential time complexity in cases like:
///
///   %a = f(%x, %y)
///   %b = f(%x, %y)
///   %c = f(%a, %b)
///   %d = f(%a, %b)
///   %e = f(%c, %d)
///   %f = f(%c, %d)
///
///   CompareValueComplexity(%f, %c)
///
/// Since we do not continue running this routine on expression trees once we
/// have seen unequal values, there is no need to track them in the cache.
static int
CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
                       const LoopInfo *const LI, Value *LV, Value *RV,
                       unsigned Depth) {
  if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
    return 0;

  // Order pointer values after integer values. This helps SCEVExpander form
  // GEPs.
  bool LIsPointer = LV->getType()->isPointerTy(),
       RIsPointer = RV->getType()->isPointerTy();
  if (LIsPointer != RIsPointer)
    return (int)LIsPointer - (int)RIsPointer;

  // Compare getValueID values.
  unsigned LID = LV->getValueID(), RID = RV->getValueID();
  if (LID != RID)
    return (int)LID - (int)RID;

  // Sort arguments by their position.
  if (const auto *LA = dyn_cast<Argument>(LV)) {
    const auto *RA = cast<Argument>(RV);
    unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
    return (int)LArgNo - (int)RArgNo;
  }

  if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
    const auto *RGV = cast<GlobalValue>(RV);

    const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
      auto LT = GV->getLinkage();
      return !(GlobalValue::isPrivateLinkage(LT) ||
               GlobalValue::isInternalLinkage(LT));
    };

    // Use the names to distinguish the two values, but only if the
    // names are semantically important.
    if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
      return LGV->getName().compare(RGV->getName());
  }

  // For instructions, compare their loop depth, and their operand count.  This
  // is pretty loose.
  if (const auto *LInst = dyn_cast<Instruction>(LV)) {
    const auto *RInst = cast<Instruction>(RV);

    // Compare loop depths.
    const BasicBlock *LParent = LInst->getParent(),
                     *RParent = RInst->getParent();
    if (LParent != RParent) {
      unsigned LDepth = LI->getLoopDepth(LParent),
               RDepth = LI->getLoopDepth(RParent);
      if (LDepth != RDepth)
        return (int)LDepth - (int)RDepth;
    }

    // Compare the number of operands.
    unsigned LNumOps = LInst->getNumOperands(),
             RNumOps = RInst->getNumOperands();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    for (unsigned Idx : seq(0u, LNumOps)) {
      int Result =
          CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
                                 RInst->getOperand(Idx), Depth + 1);
      if (Result != 0)
        return Result;
    }
  }

  EqCacheValue.unionSets(LV, RV);
  return 0;
}
// Return negative, zero, or positive, if LHS is less than, equal to, or greater
// than RHS, respectively. A three-way result allows recursive comparisons to be
// memoized.
// If the max analysis depth was reached, return None, assuming we do not know
// if they are equivalent for sure.
static Optional<int>
CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
                      EquivalenceClasses<const Value *> &EqCacheValue,
                      const LoopInfo *const LI, const SCEV *LHS,
                      const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
  // Fast-path: SCEVs are uniqued so we can do a quick equality check.
  if (LHS == RHS)
    return 0;

  // Primarily, sort the SCEVs by their getSCEVType().
  SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
  if (LType != RType)
    return (int)LType - (int)RType;

  if (EqCacheSCEV.isEquivalent(LHS, RHS))
    return 0;

  if (Depth > MaxSCEVCompareDepth)
    return None;

  // Aside from the getSCEVType() ordering, the particular ordering
  // isn't very important except that it's beneficial to be consistent,
  // so that (a + b) and (b + a) don't end up as different expressions.
  switch (LType) {
  case scUnknown: {
    const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
    const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);

    int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
                                   RU->getValue(), Depth + 1);
    if (X == 0)
      EqCacheSCEV.unionSets(LHS, RHS);
    return X;
  }

  case scConstant: {
    const SCEVConstant *LC = cast<SCEVConstant>(LHS);
    const SCEVConstant *RC = cast<SCEVConstant>(RHS);

    // Compare constant values.
    const APInt &LA = LC->getAPInt();
    const APInt &RA = RC->getAPInt();
    unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
    if (LBitWidth != RBitWidth)
      return (int)LBitWidth - (int)RBitWidth;
    return LA.ult(RA) ? -1 : 1;
  }

  case scAddRecExpr: {
    const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
    const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);

    // There is always a dominance between two recs that are used by one SCEV,
    // so we can safely sort recs by loop header dominance. We require such
    // order in getAddExpr.
    const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
    if (LLoop != RLoop) {
      const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
      assert(LHead != RHead && "Two loops share the same header?");
      if (DT.dominates(LHead, RHead))
        return 1;
      else
        assert(DT.dominates(RHead, LHead) &&
               "No dominance between recurrences used by one SCEV?");
      return -1;
    }

    // Addrec complexity grows with operand count.
    unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    // Lexicographically compare.
    for (unsigned i = 0; i != LNumOps; ++i) {
      auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
                                     LA->getOperand(i), RA->getOperand(i), DT,
                                     Depth + 1);
      if (X != 0)
        return X;
    }
    EqCacheSCEV.unionSets(LHS, RHS);
    return 0;
  }

  case scAddExpr:
  case scMulExpr:
  case scSMaxExpr:
  case scUMaxExpr:
  case scSMinExpr:
  case scUMinExpr:
  case scSequentialUMinExpr: {
    const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
    const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);

    // Lexicographically compare n-ary expressions.
    unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    for (unsigned i = 0; i != LNumOps; ++i) {
      auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
                                     LC->getOperand(i), RC->getOperand(i), DT,
                                     Depth + 1);
      if (X != 0)
        return X;
    }
    EqCacheSCEV.unionSets(LHS, RHS);
    return 0;
  }

  case scUDivExpr: {
    const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
    const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);

    // Lexicographically compare udiv expressions.
    auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
                                   RC->getLHS(), DT, Depth + 1);
    if (X != 0)
      return X;
    X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
                              RC->getRHS(), DT, Depth + 1);
    if (X == 0)
      EqCacheSCEV.unionSets(LHS, RHS);
    return X;
  }

  case scPtrToInt:
  case scTruncate:
  case scZeroExtend:
  case scSignExtend: {
    const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
    const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);

    // Compare cast expressions by operand.
    auto X =
        CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getOperand(),
                              RC->getOperand(), DT, Depth + 1);
    if (X == 0)
      EqCacheSCEV.unionSets(LHS, RHS);
    return X;
  }

  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}
/// Given a list of SCEV objects, order them by their complexity, and group
/// objects of the same complexity together by value.  When this routine is
/// finished, we know that any duplicates in the vector are consecutive and that
/// complexity is monotonically increasing.
///
/// Note that we take special precautions to ensure that we get deterministic
/// results from this routine.  In other words, we don't want the results of
/// this to depend on where the addresses of various SCEV objects happened to
/// end up in memory.
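///
/// As a hypothetical illustration, the operand list (%x + 2 + %y + 2) would
/// be regrouped as (2 + 2 + %x + %y): constants sort before SCEVUnknowns,
/// and the duplicate constants become adjacent so that the folders can
/// combine them.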
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
                              LoopInfo *LI, DominatorTree &DT) {
  if (Ops.size() < 2) return;  // Noop

  EquivalenceClasses<const SCEV *> EqCacheSCEV;
  EquivalenceClasses<const Value *> EqCacheValue;

  // Whether LHS has provably less complexity than RHS.
  auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
    auto Complexity =
        CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
    return Complexity && *Complexity < 0;
  };
  if (Ops.size() == 2) {
    // This is the common case, which also happens to be trivially simple.
    // Special case it.
    const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
    if (IsLessComplex(RHS, LHS))
      std::swap(LHS, RHS);
    return;
  }

  // Do the rough sort by complexity.
  llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
    return IsLessComplex(LHS, RHS);
  });

  // Now that we are sorted by complexity, group elements of the same
  // complexity.  Note that this is, at worst, N^2, but the vector is likely to
  // be extremely short in practice.  Note that we take this approach because we
  // do not want to depend on the addresses of the objects we are grouping.
  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
    const SCEV *S = Ops[i];
    unsigned Complexity = S->getSCEVType();

    // If there are any objects of the same complexity and same value as this
    // one, group them.
    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
      if (Ops[j] == S) { // Found a duplicate.
        // Move it to immediately after i'th element.
        std::swap(Ops[i+1], Ops[j]);
        ++i;   // no need to rescan it.
        if (i == e-2) return;  // Done!
      }
    }
  }
}
/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
/// least HugeExprThreshold nodes).
static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
  return any_of(Ops, [](const SCEV *S) {
    return S->getExpressionSize() >= HugeExprThreshold;
  });
}

//===----------------------------------------------------------------------===//
//                      Simple SCEV method implementations
//===----------------------------------------------------------------------===//
/// Compute BC(It, K).  The result has width W.  Assume, K > 0.
static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
                                       ScalarEvolution &SE,
                                       Type *ResultTy) {
  // Handle the simplest case efficiently.
  if (K == 1)
    return SE.getTruncateOrZeroExtend(It, ResultTy);

  // We are using the following formula for BC(It, K):
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
  //
  // Suppose, W is the bitwidth of the return value.  We must be prepared for
  // overflow.  Hence, we must assure that the result of our computation is
  // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
  // safe in modular arithmetic.
  //
  // However, this code doesn't use exactly that formula; the formula it uses
  // is something like the following, where T is the number of factors of 2 in
  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
  // exponentiation:
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
  //
  // This formula is trivially equivalent to the previous formula.  However,
  // this formula can be implemented much more efficiently.  The trick is that
  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
  // arithmetic.  To do exact division in modular arithmetic, all we have
  // to do is multiply by the inverse.  Therefore, this step can be done at
  // width W.
  //
  // The next issue is how to safely do the division by 2^T.  The way this
  // is done is by doing the multiplication step at a width of at least W + T
  // bits.  This way, the bottom W+T bits of the product are accurate. Then,
  // when we perform the division by 2^T (which is equivalent to a right shift
  // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
  // truncated out after the division by 2^T.
  //
  // In comparison to just directly using the first formula, this technique
  // is much more efficient; using the first formula requires W * K bits,
  // but this formula needs less than W + K bits. Also, the first formula
  // requires a division step, whereas this formula only requires multiplies
  // and shifts.
  //
  // It doesn't matter whether the subtraction step is done in the calculation
  // width or the input iteration count's width; if the subtraction overflows,
  // the result must be zero anyway.  We prefer here to do it in the width of
  // the induction variable because it helps a lot for certain cases; CodeGen
  // isn't smart enough to ignore the overflow, which leads to much less
  // efficient code if the width of the subtraction is wider than the native
  // machine width.
  //
  // (It's possible to not widen at all by pulling out factors of 2 before
  // the multiplication; for example, K=2 can be calculated as
  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
  // extra arithmetic, so it's not an obvious win, and it gets
  // much more complicated for K > 3.)
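  //
  // As a concrete illustration: for K = 3 and W = 32, K! = 6 = 2^1 * 3, so
  // T = 1 and K!/2^T = 3.  The product It*(It-1)*(It-2) is computed at width
  // W+T = 33 bits and divided by 2^T (a right shift), and the result is
  // multiplied by the inverse of 3 modulo 2^32, which is 0xAAAAAAAB, yielding
  // BC(It, 3) in the bottom 32 bits.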
  // Protection from insane SCEVs; this bound is conservative,
  // but it probably doesn't matter.
  if (K > 1000)
    return SE.getCouldNotCompute();

  unsigned W = SE.getTypeSizeInBits(ResultTy);

  // Calculate K! / 2^T and T; we divide out the factors of two before
  // multiplying for calculating K! / 2^T to avoid overflow.
  // Other overflow doesn't matter because we only care about the bottom
  // W bits of the result.
  APInt OddFactorial(W, 1);
  unsigned T = 1;
  for (unsigned i = 3; i <= K; ++i) {
    APInt Mult(W, i);
    unsigned TwoFactors = Mult.countTrailingZeros();
    T += TwoFactors;
    Mult.lshrInPlace(TwoFactors);
    OddFactorial *= Mult;
  }

  // We need at least W + T bits for the multiplication step
  unsigned CalculationBits = W + T;

  // Calculate 2^T, at width T+W.
  APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);

  // Calculate the multiplicative inverse of K! / 2^T;
  // this multiplication factor will perform the exact division by
  // K! / 2^T.
  APInt Mod = APInt::getSignedMinValue(W+1);
  APInt MultiplyFactor = OddFactorial.zext(W+1);
  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
  MultiplyFactor = MultiplyFactor.trunc(W);

  // Calculate the product, at width T+W
  IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
                                                CalculationBits);
  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
  for (unsigned i = 1; i != K; ++i) {
    const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
    Dividend = SE.getMulExpr(Dividend,
                             SE.getTruncateOrZeroExtend(S, CalculationTy));
  }

  // Divide by 2^T
  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));

  // Truncate the result, and divide by K! / 2^T.
  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}
/// Return the value of this chain of recurrences at the specified iteration
/// number.  We can evaluate this recurrence by multiplying each element in the
/// chain by the binomial coefficient corresponding to it.  In other words, we
/// can evaluate {A,+,B,+,C,+,D} as:
///
///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
///
/// where BC(It, k) stands for binomial coefficient.
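///
/// For example, the affine recurrence {5,+,3} evaluates at iteration It to
/// 5*BC(It, 0) + 3*BC(It, 1) = 5 + 3*It, and the quadratic recurrence
/// {0,+,1,+,2} evaluates to 0 + 1*It + 2*(It*(It-1)/2) = It^2.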
const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
                                                ScalarEvolution &SE) const {
  return evaluateAtIteration(makeArrayRef(op_begin(), op_end()), It, SE);
}

const SCEV *
SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
                                    const SCEV *It, ScalarEvolution &SE) {
  assert(Operands.size() > 0);
  const SCEV *Result = Operands[0];
  for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
    // The computation is correct in the face of overflow provided that the
    // multiplication is performed _after_ the evaluation of the binomial
    // coefficient.
    const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
    if (isa<SCEVCouldNotCompute>(Coeff))
      return Coeff;

    Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
  }
  return Result;
}
//===----------------------------------------------------------------------===//
//                    SCEV Expression folder implementations
//===----------------------------------------------------------------------===//
const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
                                                     unsigned Depth) {
  assert(Depth <= 1 &&
         "getLosslessPtrToIntExpr() should self-recurse at most once.");

  // We could be called with an integer-typed operand during SCEV rewrites.
  // Since the operand is an integer already, just perform zext/trunc/self cast.
  if (!Op->getType()->isPointerTy())
    return Op;

  // What would be an ID for such a SCEV cast expression?
  FoldingSetNodeID ID;
  ID.AddInteger(scPtrToInt);
  ID.AddPointer(Op);

  void *IP = nullptr;

  // Is there already an expression for such a cast?
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
    return S;

  // It isn't legal for optimizations to construct new ptrtoint expressions
  // for non-integral pointers.
  if (getDataLayout().isNonIntegralPointerType(Op->getType()))
    return getCouldNotCompute();

  Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());

  // We can only trivially model ptrtoint if SCEV's effective (integer) type
  // is sufficiently wide to represent all possible pointer values.
  // We could theoretically teach SCEV to truncate wider pointers, but
  // that isn't implemented for now.
  if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
      getDataLayout().getTypeSizeInBits(IntPtrTy))
    return getCouldNotCompute();

  // If not, is this expression something we can't reduce any further?
  if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
    // Perform some basic constant folding. If the operand of the ptr2int cast
    // is a null pointer, don't create a ptr2int SCEV expression (that will be
    // left as-is), but produce a zero constant.
    // NOTE: We could handle a more general case, but lack motivational cases.
    if (isa<ConstantPointerNull>(U->getValue()))
      return getZero(IntPtrTy);

    // Create an explicit cast node.
    // We can reuse the existing insert position since if we get here,
    // we won't have made any changes which would invalidate it.
    SCEV *S = new (SCEVAllocator)
        SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }

  assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
                       "non-SCEVUnknown's.");

  // Otherwise, we've got some expression that is more complex than just a
  // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
  // arbitrary expression; we want to have SCEVPtrToIntExpr of an SCEVUnknown
  // only, and the expressions must otherwise be integer-typed.
  // So sink the cast down to the SCEVUnknown's.
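  //
  // For instance, the pointer-typed expression (%ptr + 8) becomes
  // ((ptrtoint %ptr) + 8): the addition is then carried out on integers and
  // the only pointer-typed node left is the SCEVUnknown %ptr.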
  /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
  /// which computes a pointer-typed value, and rewrites the whole expression
  /// tree so that *all* the computations are done on integers, and the only
  /// pointer-typed operands in the expression are SCEVUnknown.
  class SCEVPtrToIntSinkingRewriter
      : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
    using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;

  public:
    SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}

    static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
      SCEVPtrToIntSinkingRewriter Rewriter(SE);
      return Rewriter.visit(Scev);
    }

    const SCEV *visit(const SCEV *S) {
      Type *STy = S->getType();
      // If the expression is not pointer-typed, just keep it as-is.
      if (!STy->isPointerTy())
        return S;
      // Else, recursively sink the cast down into it.
      return Base::visit(S);
    }

    const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
      SmallVector<const SCEV *, 2> Operands;
      bool Changed = false;
      for (auto *Op : Expr->operands()) {
        Operands.push_back(visit(Op));
        Changed |= Op != Operands.back();
      }
      return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
    }

    const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
      SmallVector<const SCEV *, 2> Operands;
      bool Changed = false;
      for (auto *Op : Expr->operands()) {
        Operands.push_back(visit(Op));
        Changed |= Op != Operands.back();
      }
      return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
    }

    const SCEV *visitUnknown(const SCEVUnknown *Expr) {
      assert(Expr->getType()->isPointerTy() &&
             "Should only reach pointer-typed SCEVUnknown's.");
      return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
    }
  };

  // And actually perform the cast sinking.
  const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
  assert(IntOp->getType()->isIntegerTy() &&
         "We must have succeeded in sinking the cast, "
         "and ending up with an integer-typed expression!");
  return IntOp;
}
const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
  assert(Ty->isIntegerTy() && "Target type must be an integer type!");

  const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
  if (isa<SCEVCouldNotCompute>(IntOp))
    return IntOp;

  return getTruncateOrZeroExtend(IntOp, Ty);
}
const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
                                             unsigned Depth) {
  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
         "This is not a truncating conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
  Ty = getEffectiveSCEVType(Ty);

  FoldingSetNodeID ID;
  ID.AddInteger(scTruncate);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));

  // trunc(trunc(x)) --> trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
    return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);

  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);

  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);

  if (Depth > MaxCastDepth) {
    SCEV *S =
        new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }

  // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
  // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
  // if after transforming we have at most one truncate, not counting truncates
  // that replace other casts.
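  //
  // For example, truncating (5 + %x) from i64 to i32 gives
  // (5 + (trunc i64 %x to i32)), which contains a single truncate, whereas
  // distributing over (%a + %b) would create two truncates and is therefore
  // not performed.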
  if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
    auto *CommOp = cast<SCEVCommutativeExpr>(Op);
    SmallVector<const SCEV *, 4> Operands;
    unsigned numTruncs = 0;
    for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
         ++i) {
      const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
      if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
          isa<SCEVTruncateExpr>(S))
        numTruncs++;
      Operands.push_back(S);
    }
    if (numTruncs < 2) {
      if (isa<SCEVAddExpr>(Op))
        return getAddExpr(Operands);
      else if (isa<SCEVMulExpr>(Op))
        return getMulExpr(Operands);
      else
        llvm_unreachable("Unexpected SCEV type for Op.");
    }
    // Although we checked in the beginning that ID is not in the cache, it is
    // possible that during recursion and other modifications the ID was
    // inserted into the cache. So if we find it, just return it.
    if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
      return S;
  }

  // If the input value is a chrec scev, truncate the chrec's operands.
  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    for (const SCEV *Op : AddRec->operands())
      Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
    return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
  }

  // Return zero if truncating to known zeros.
  uint32_t MinTrailingZeros = GetMinTrailingZeros(Op);
  if (MinTrailingZeros >= getTypeSizeInBits(Ty))
    return getZero(Ty);

  // The cast wasn't folded; create an explicit cast node. We can reuse
  // the existing insert position since if we get here, we won't have
  // made any changes which would invalidate it.
  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                 Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, Op);
  return S;
}
// Get the limit of a recurrence such that incrementing by Step cannot cause
// signed overflow as long as the value of the recurrence within the
// loop does not exceed this limit before incrementing.
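// For a known-positive Step this produces SINT_MIN - max(Step), which in the
// wrapping arithmetic of APInt equals SINT_MAX - max(Step) + 1; the
// recurrence therefore cannot sign-overflow while its value is ICMP_SLT that
// limit.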
static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
                                                 ICmpInst::Predicate *Pred,
                                                 ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  if (SE->isKnownPositive(Step)) {
    *Pred = ICmpInst::ICMP_SLT;
    return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
                           SE->getSignedRangeMax(Step));
  }
  if (SE->isKnownNegative(Step)) {
    *Pred = ICmpInst::ICMP_SGT;
    return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
                           SE->getSignedRangeMin(Step));
  }
  return nullptr;
}
// Get the limit of a recurrence such that incrementing by Step cannot cause
// unsigned overflow as long as the value of the recurrence within the loop does
// not exceed this limit before incrementing.
static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
                                                   ICmpInst::Predicate *Pred,
                                                   ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  *Pred = ICmpInst::ICMP_ULT;

  return SE->getConstant(APInt::getMinValue(BitWidth) -
                         SE->getUnsignedRangeMax(Step));
}
namespace {

struct ExtendOpTraitsBase {
  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
                                                          unsigned);
};

// Used to make code generic over signed and unsigned overflow.
template <typename ExtendOp> struct ExtendOpTraits {
  // Members present:
  //
  // static const SCEV::NoWrapFlags WrapType;
  //
  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
  //
  // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
  //                                            ICmpInst::Predicate *Pred,
  //                                            ScalarEvolution *SE);
};

template <>
struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getSignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;

template <>
struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getUnsignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;

} // end anonymous namespace
// The recurrence AR has been shown to have no signed/unsigned wrap or something
// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
// easily prove NSW/NUW for its preincrement or postincrement sibling. This
// allows normalizing a sign/zero extended AddRec as such:
// {sext/zext(Step + Start),+,Step} => {Step + sext/zext(Start),+,Step}
// As a result, the expression "Step + sext/zext(PreIncAR)" is congruent with
// "sext/zext(PostIncAR)".
template <typename ExtendOpTy>
static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE, unsigned Depth) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  const Loop *L = AR->getLoop();
  const SCEV *Start = AR->getStart();
  const SCEV *Step = AR->getStepRecurrence(*SE);

  // Check for a simple looking step prior to loop entry.
  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
  if (!SA)
    return nullptr;

  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
  // subtraction is expensive. For this purpose, perform a quick and dirty
  // difference, by checking for Step in the operand list.
  SmallVector<const SCEV *, 4> DiffOps;
  for (const SCEV *Op : SA->operands())
    if (Op != Step)
      DiffOps.push_back(Op);

  if (DiffOps.size() == SA->getNumOperands())
    return nullptr;

  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
  // `Step`:

  // 1. NSW/NUW flags on the step increment.
  auto PreStartFlags =
    ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
      SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));

  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
  // "S+X does not sign/unsign-overflow".
  //
  const SCEV *BECount = SE->getBackedgeTakenCount(L);
  if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
      !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
    return PreStart;

  // 2. Direct overflow check on the step operation's expression.
  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
  const SCEV *OperandExtendedStart =
      SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
                     (SE->*GetExtendExpr)(Step, WideTy, Depth));
  if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
    if (PreAR && AR->getNoWrapFlags(WrapType)) {
      // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
      // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
      // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`.  Cache this fact.
      SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
    }
    return PreStart;
  }

  // 3. Loop precondition.
  ICmpInst::Predicate Pred;
  const SCEV *OverflowLimit =
      ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);

  if (OverflowLimit &&
      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
    return PreStart;

  return nullptr;
}

// Get the normalized zero or sign extended expression for this AddRec's Start.
template <typename ExtendOpTy>
static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE,
                                        unsigned Depth) {
  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
  if (!PreStart)
    return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);

  return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
                                             Depth),
                        (SE->*GetExtendExpr)(PreStart, Ty, Depth));
}
// Try to prove away overflow by looking at "nearby" add recurrences.  A
// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
//
// Formally:
//
//     {S,+,X} == {S-T,+,X} + T
//  => Ext({S,+,X}) == Ext({S-T,+,X} + T)
//
//  If ({S-T,+,X} + T) does not overflow  ... (1)
//
//  RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
//
//  If {S-T,+,X} does not overflow  ... (2)
//
//  RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
//      == {Ext(S-T)+Ext(T),+,Ext(X)}
//
//  If (S-T)+T does not overflow  ... (3)
//
//  RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
//      == {Ext(S),+,Ext(X)} == LHS
//
// Thus, if (1), (2) and (3) are true for some T, then
//   Ext({S,+,X}) == {Ext(S),+,Ext(X)}
//
// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
// does not overflow" restricted to the 0th iteration.  Therefore we only need
// to check for (1) and (2).
//
// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
// is `Delta` (defined below).
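//
// Instantiating the motivating example above: S = 1, X = 4, T = 1 gives
// {1,+,4} == {0,+,4} + 1; "{0,+,4} + 1 does not overflow" is (1) (it follows
// from {0,+,4} staying ult -1), "{0,+,4} does not wrap" is (2), and together
// they let us conclude zext({1,+,4}) == {zext(1),+,zext(4)}.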
template <typename ExtendOpTy>
bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
                                                const SCEV *Step,
                                                const Loop *L) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;

  // We restrict `Start` to a constant to prevent SCEV from spending too much
  // time here.  It is correct (but more expensive) to continue with a
  // non-constant `Start` and do a general SCEV subtraction to compute
  // `PreStart` below.
  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
  if (!StartC)
    return false;

  APInt StartAI = StartC->getAPInt();

  for (unsigned Delta : {-2, -1, 1, 2}) {
    const SCEV *PreStart = getConstant(StartAI - Delta);

    FoldingSetNodeID ID;
    ID.AddInteger(scAddRecExpr);
    ID.AddPointer(PreStart);
    ID.AddPointer(Step);
    ID.AddPointer(L);
    void *IP = nullptr;
    const auto *PreAR =
        static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));

    // Give up if we don't already have the add recurrence we need because
    // actually constructing an add recurrence is relatively expensive.
    if (PreAR && PreAR->getNoWrapFlags(WrapType)) {  // proves (2)
      const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
      ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
      const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
          DeltaS, &Pred, this);
      if (Limit && isKnownPredicate(Pred, PreAR, Limit))  // proves (1)
        return true;
    }
  }

  return false;
}
// Finds an integer D for an expression (C + x + y + ...) such that the top
// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
// the (C + x + y + ...) expression is \p WholeAddExpr.
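//
// As a numeric illustration: with C = 0b10110 and (x + y) known to have at
// least three trailing zero bits, D is the low three bits of C, i.e. 0b110;
// then (C - D + x + y) = (0b10000 + x + y) keeps three trailing zeros, and
// adding D back cannot carry out of those low bits, so the top-level
// addition cannot wrap.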
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
                                            const SCEVConstant *ConstantTerm,
                                            const SCEVAddExpr *WholeAddExpr) {
  const APInt &C = ConstantTerm->getAPInt();
  const unsigned BitWidth = C.getBitWidth();
  // Find number of trailing zeros of (x + y + ...) w/o the C first:
  uint32_t TZ = BitWidth;
  for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
    TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I)));
  if (TZ) {
    // Set D to be as many least significant bits of C as possible while still
    // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
    return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
  }
  return APInt(BitWidth, 0);
}

// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
                                            const APInt &ConstantStart,
                                            const SCEV *Step) {
  const unsigned BitWidth = ConstantStart.getBitWidth();
  const uint32_t TZ = SE.GetMinTrailingZeros(Step);
  if (TZ)
    return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
                         : ConstantStart;
  return APInt(BitWidth, 0);
}
const SCEV *
ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
  Ty = getEffectiveSCEVType(Ty);

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));

  // zext(zext(x)) --> zext(x)
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);

  // Before doing any expensive analysis, check to see if we've already
  // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scZeroExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  if (Depth > MaxCastDepth) {
    SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                     Op, Ty);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }

  // zext(trunc(x)) --> zext(x) or x or trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
    // It's possible the bits taken off by the truncate were all zero bits. If
    // so, we should be able to simplify this further.
    const SCEV *X = ST->getOperand();
    ConstantRange CR = getUnsignedRange(X);
    unsigned TruncBits = getTypeSizeInBits(ST->getType());
    unsigned NewBits = getTypeSizeInBits(Ty);
    if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
            CR.zextOrTrunc(NewBits)))
      return getTruncateOrZeroExtend(X, Ty, Depth);
  }
1628 // If the input value is a chrec scev, and we can prove that the value
1629 // did not overflow the old, smaller, value, we can zero extend all of the
1630 // operands (often constants). This allows analysis of something like
1631 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1632 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1633 if (AR->isAffine()) {
1634 const SCEV *Start = AR->getStart();
1635 const SCEV *Step = AR->getStepRecurrence(*this);
1636 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1637 const Loop *L = AR->getLoop();
1639 if (!AR->hasNoUnsignedWrap()) {
1640 auto NewFlags = proveNoWrapViaConstantRanges(AR);
1641 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1644 // If we have special knowledge that this addrec won't overflow,
1645 // we don't need to do any further analysis.
1646 if (AR->hasNoUnsignedWrap())
1647 return getAddRecExpr(
1648 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
1649 getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
1651 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1652 // Note that this serves two purposes: It filters out loops that are
1653 // simply not analyzable, and it covers the case where this code is
1654 // being called from within backedge-taken count analysis, such that
1655 // attempting to ask for the backedge-taken count would likely result
1656 // in infinite recursion. In the later case, the analysis code will
1657 // cope with a conservative value, and it will take care to purge
1658 // that value once it has finished.
1659 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1660 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1661 // Manually compute the final value for AR, checking for overflow.
1663 // Check whether the backedge-taken count can be losslessly casted to
1664 // the addrec's type. The count is always unsigned.
1665 const SCEV *CastedMaxBECount =
1666 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1667 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1668 CastedMaxBECount, MaxBECount->getType(), Depth);
1669 if (MaxBECount == RecastedMaxBECount) {
1670 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1671 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1672 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1673 SCEV::FlagAnyWrap, Depth + 1);
1674 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1675 SCEV::FlagAnyWrap,
1676 Depth + 1),
1677 WideTy, Depth + 1);
1678 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1679 const SCEV *WideMaxBECount =
1680 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1681 const SCEV *OperandExtendedAdd =
1682 getAddExpr(WideStart,
1683 getMulExpr(WideMaxBECount,
1684 getZeroExtendExpr(Step, WideTy, Depth + 1),
1685 SCEV::FlagAnyWrap, Depth + 1),
1686 SCEV::FlagAnyWrap, Depth + 1);
1687 if (ZAdd == OperandExtendedAdd) {
1688 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1689 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1690 // Return the expression with the addrec on the outside.
1691 return getAddRecExpr(
1692 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1693 Depth + 1),
1694 getZeroExtendExpr(Step, Ty, Depth + 1), L,
1695 AR->getNoWrapFlags());
1696 }
1697 // Similar to above, only this time treat the step value as signed.
1698 // This covers loops that count down.
1699 OperandExtendedAdd =
1700 getAddExpr(WideStart,
1701 getMulExpr(WideMaxBECount,
1702 getSignExtendExpr(Step, WideTy, Depth + 1),
1703 SCEV::FlagAnyWrap, Depth + 1),
1704 SCEV::FlagAnyWrap, Depth + 1);
1705 if (ZAdd == OperandExtendedAdd) {
1706 // Cache knowledge of AR NW, which is propagated to this AddRec.
1707 // Negative step causes unsigned wrap, but it still can't self-wrap.
1708 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1709 // Return the expression with the addrec on the outside.
1710 return getAddRecExpr(
1711 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1712 Depth + 1),
1713 getSignExtendExpr(Step, Ty, Depth + 1), L,
1714 AR->getNoWrapFlags());
1715 }
1716 }
1717 }
1719 // Normally, in the cases we can prove no-overflow via a
1720 // backedge guarding condition, we can also compute a backedge
1721 // taken count for the loop. The exceptions are assumptions and
1722 // guards present in the loop -- SCEV is not great at exploiting
1723 // these to compute max backedge taken counts, but can still use
1724 // these to prove lack of overflow. Use this fact to avoid
1725 // doing extra work that may not pay off.
1726 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1727 !AC.assumptions().empty()) {
1729 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1730 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1731 if (AR->hasNoUnsignedWrap()) {
1732 // Same as nuw case above - duplicated here to avoid a compile time
1733 // issue. It's not clear that the order of checks matters, but it's
1734 // one of two possible causes for a change which was reverted.
1735 // Be conservative for the moment.
1736 return getAddRecExpr(
1737 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1738 Depth + 1),
1739 getZeroExtendExpr(Step, Ty, Depth + 1), L,
1740 AR->getNoWrapFlags());
1741 }
1742 }
1743 // For a negative step, we can extend the operands iff doing so only
1744 // traverses values in the range zext([0,UINT_MAX]).
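// Put differently: every value the decreasing addrec takes must already be
// representable in the narrow unsigned type, so each iterate zero-extends
// losslessly and the extension can be pushed into the operands.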
1745 if (isKnownNegative(Step)) {
1746 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1747 getSignedRangeMin(Step));
1748 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1749 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1750 // Cache knowledge of AR NW, which is propagated to this
1751 // AddRec. Negative step causes unsigned wrap, but it
1752 // still can't self-wrap.
1753 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1754 // Return the expression with the addrec on the outside.
1755 return getAddRecExpr(
1756 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1757 Depth + 1),
1758 getSignExtendExpr(Step, Ty, Depth + 1), L,
1759 AR->getNoWrapFlags());
1760 }
1761 }
1764 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1765 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1766 // where D maximizes the number of trailing zeros of (C - D + Step * n)
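// Worked example: with C = 5 and Step = 4, D = 1 and the fold rewrites
// zext({5,+,4}) as 1 + zext({4,+,4}), exposing the alignment of the
// residual addrec that the start constant 5 would otherwise hide.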
1767 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1768 const APInt &C = SC->getAPInt();
1769 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1770 if (D != 0) {
1771 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1772 const SCEV *SResidual =
1773 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1774 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1775 return getAddExpr(SZExtD, SZExtR,
1776 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1777 Depth + 1);
1778 }
1779 }
1781 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1782 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1783 return getAddRecExpr(
1784 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
1785 getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
1786 }
1787 }
1789 // zext(A % B) --> zext(A) % zext(B)
1790 {
1791 const SCEV *LHS;
1792 const SCEV *RHS;
1793 if (matchURem(Op, LHS, RHS))
1794 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1795 getZeroExtendExpr(RHS, Ty, Depth + 1));
1796 }
1798 // zext(A / B) --> zext(A) / zext(B).
1799 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1800 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1801 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1803 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1804 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1805 if (SA->hasNoUnsignedWrap()) {
1806 // If the addition does not unsign overflow then we can, by definition,
1807 // commute the zero extension with the addition operation.
1808 SmallVector<const SCEV *, 4> Ops;
1809 for (const auto *Op : SA->operands())
1810 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1811 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1812 }
1814 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1815 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1816 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1818 // Often address arithmetics contain expressions like
1819 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1820 // This transformation is useful while proving that such expressions are
1821 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
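// E.g. zext(5 + 4 * X) becomes 1 + zext(4 + 4 * X) under this rule, so it
// visibly differs from zext(4 + 4 * X) by the constant 1 once both are in
// canonical form.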
1822 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1823 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1824 if (D != 0) {
1825 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1826 const SCEV *SResidual =
1827 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1828 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1829 return getAddExpr(SZExtD, SZExtR,
1830 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1831 Depth + 1);
1832 }
1833 }
1834 }
1836 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1837 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1838 if (SM->hasNoUnsignedWrap()) {
1839 // If the multiply does not unsign overflow then we can, by definition,
1840 // commute the zero extension with the multiply operation.
1841 SmallVector<const SCEV *, 4> Ops;
1842 for (const auto *Op : SM->operands())
1843 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1844 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1845 }
1847 // zext(2^K * (trunc X to iN)) to iM ->
1848 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1849 //
1850 // Proof:
1851 //
1852 // zext(2^K * (trunc X to iN)) to iM
1853 // = zext((trunc X to iN) << K) to iM
1854 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1855 // (because shl removes the top K bits)
1856 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1857 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
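// Concrete instance (K = 2, N = 8, M = 32):
//   zext(4 * (trunc X to i8)) to i32 = 4 * (zext(trunc X to i6) to i32).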
1859 if (SM->getNumOperands() == 2)
1860 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1861 if (MulLHS->getAPInt().isPowerOf2())
1862 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1863 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1864 MulLHS->getAPInt().logBase2();
1865 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1866 return getMulExpr(
1867 getZeroExtendExpr(MulLHS, Ty),
1868 getZeroExtendExpr(
1869 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1870 SCEV::FlagNUW, Depth + 1);
1871 }
1872 }
1874 // The cast wasn't folded; create an explicit cast node.
1875 // Recompute the insert position, as it may have been invalidated.
1876 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1877 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1878 Op, Ty);
1879 UniqueSCEVs.InsertNode(S, IP);
1880 registerUser(S, Op);
1881 return S;
1882 }
1884 const SCEV *
1885 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1886 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1887 "This is not an extending conversion!");
1888 assert(isSCEVable(Ty) &&
1889 "This is not a conversion to a SCEVable type!");
1890 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1891 Ty = getEffectiveSCEVType(Ty);
1893 // Fold if the operand is constant.
1894 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1895 return getConstant(
1896 cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
1898 // sext(sext(x)) --> sext(x)
1899 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1900 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1902 // sext(zext(x)) --> zext(x)
1903 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1904 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1906 // Before doing any expensive analysis, check to see if we've already
1907 // computed a SCEV for this Op and Ty.
1908 FoldingSetNodeID ID;
1909 ID.AddInteger(scSignExtend);
1910 ID.AddPointer(Op);
1911 ID.AddPointer(Ty);
1912 void *IP = nullptr;
1913 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1914 // Limit recursion depth.
1915 if (Depth > MaxCastDepth) {
1916 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1917 Op, Ty);
1918 UniqueSCEVs.InsertNode(S, IP);
1919 registerUser(S, Op);
1920 return S;
1921 }
1923 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1924 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1925 // It's possible the bits taken off by the truncate were all sign bits. If
1926 // so, we should be able to simplify this further.
1927 const SCEV *X = ST->getOperand();
1928 ConstantRange CR = getSignedRange(X);
1929 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1930 unsigned NewBits = getTypeSizeInBits(Ty);
1931 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1932 CR.sextOrTrunc(NewBits)))
1933 return getTruncateOrSignExtend(X, Ty, Depth);
1934 }
1936 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1937 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1938 if (SA->hasNoSignedWrap()) {
1939 // If the addition does not sign overflow then we can, by definition,
1940 // commute the sign extension with the addition operation.
1941 SmallVector<const SCEV *, 4> Ops;
1942 for (const auto *Op : SA->operands())
1943 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1944 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1945 }
1947 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1948 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1949 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1951 // For instance, this will bring two seemingly different expressions:
1952 // 1 + sext(5 + 20 * %x + 24 * %y) and
1953 // sext(6 + 20 * %x + 24 * %y)
1954 // to the same form:
1955 // 2 + sext(4 + 20 * %x + 24 * %y)
1956 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1957 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1958 if (D != 0) {
1959 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1960 const SCEV *SResidual =
1961 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1962 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1963 return getAddExpr(SSExtD, SSExtR,
1964 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1965 Depth + 1);
1966 }
1967 }
1968 }
1969 // If the input value is a chrec scev, and we can prove that the value
1970 // did not overflow the old, smaller, value, we can sign extend all of the
1971 // operands (often constants). This allows analysis of something like
1972 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1973 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1974 if (AR->isAffine()) {
1975 const SCEV *Start = AR->getStart();
1976 const SCEV *Step = AR->getStepRecurrence(*this);
1977 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1978 const Loop *L = AR->getLoop();
1980 if (!AR->hasNoSignedWrap()) {
1981 auto NewFlags = proveNoWrapViaConstantRanges(AR);
1982 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1983 }
1985 // If we have special knowledge that this addrec won't overflow,
1986 // we don't need to do any further analysis.
1987 if (AR->hasNoSignedWrap())
1988 return getAddRecExpr(
1989 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
1990 getSignExtendExpr(Step, Ty, Depth + 1), L, SCEV::FlagNSW);
1992 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1993 // Note that this serves two purposes: It filters out loops that are
1994 // simply not analyzable, and it covers the case where this code is
1995 // being called from within backedge-taken count analysis, such that
1996 // attempting to ask for the backedge-taken count would likely result
1997 // in infinite recursion. In the latter case, the analysis code will
1998 // cope with a conservative value, and it will take care to purge
1999 // that value once it has finished.
2000 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2001 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2002 // Manually compute the final value for AR, checking for
2003 // overflow.
2005 // Check whether the backedge-taken count can be losslessly cast to
2006 // the addrec's type. The count is always unsigned.
2007 const SCEV *CastedMaxBECount =
2008 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2009 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2010 CastedMaxBECount, MaxBECount->getType(), Depth);
2011 if (MaxBECount == RecastedMaxBECount) {
2012 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2013 // Check whether Start+Step*MaxBECount has no signed overflow.
2014 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2015 SCEV::FlagAnyWrap, Depth + 1);
2016 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2017 SCEV::FlagAnyWrap,
2018 Depth + 1),
2019 WideTy, Depth + 1);
2020 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2021 const SCEV *WideMaxBECount =
2022 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2023 const SCEV *OperandExtendedAdd =
2024 getAddExpr(WideStart,
2025 getMulExpr(WideMaxBECount,
2026 getSignExtendExpr(Step, WideTy, Depth + 1),
2027 SCEV::FlagAnyWrap, Depth + 1),
2028 SCEV::FlagAnyWrap, Depth + 1);
2029 if (SAdd == OperandExtendedAdd) {
2030 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2031 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2032 // Return the expression with the addrec on the outside.
2033 return getAddRecExpr(
2034 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2035 Depth + 1),
2036 getSignExtendExpr(Step, Ty, Depth + 1), L,
2037 AR->getNoWrapFlags());
2038 }
2039 // Similar to above, only this time treat the step value as unsigned.
2040 // This covers loops that count up with an unsigned step.
2041 OperandExtendedAdd =
2042 getAddExpr(WideStart,
2043 getMulExpr(WideMaxBECount,
2044 getZeroExtendExpr(Step, WideTy, Depth + 1),
2045 SCEV::FlagAnyWrap, Depth + 1),
2046 SCEV::FlagAnyWrap, Depth + 1);
2047 if (SAdd == OperandExtendedAdd) {
2048 // If AR wraps around then
2050 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2051 // => SAdd != OperandExtendedAdd
2053 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2054 // (SAdd == OperandExtendedAdd => AR is NW)
2056 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2058 // Return the expression with the addrec on the outside.
2059 return getAddRecExpr(
2060 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2061 Depth + 1),
2062 getZeroExtendExpr(Step, Ty, Depth + 1), L,
2063 AR->getNoWrapFlags());
2064 }
2065 }
2066 }
2068 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2069 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2070 if (AR->hasNoSignedWrap()) {
2071 // Same as nsw case above - duplicated here to avoid a compile time
2072 // issue. It's not clear that the order of checks matters, but it's
2073 // one of two possible causes for a change which was reverted.
2074 // Be conservative for the moment.
2075 return getAddRecExpr(
2076 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
2077 getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
2078 }
2080 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2081 // if D + (C - D + Step * n) could be proven to not signed wrap
2082 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2083 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2084 const APInt &C = SC->getAPInt();
2085 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2086 if (D != 0) {
2087 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2088 const SCEV *SResidual =
2089 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2090 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2091 return getAddExpr(SSExtD, SSExtR,
2092 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2093 Depth + 1);
2094 }
2095 }
2097 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2098 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2099 return getAddRecExpr(
2100 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
2101 getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
2102 }
2103 }
2105 // If the input value is provably non-negative and we could not simplify
2106 // away the sext, build a zext instead.
2107 if (isKnownNonNegative(Op))
2108 return getZeroExtendExpr(Op, Ty, Depth + 1);
2110 // The cast wasn't folded; create an explicit cast node.
2111 // Recompute the insert position, as it may have been invalidated.
2112 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2113 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2114 Op, Ty);
2115 UniqueSCEVs.InsertNode(S, IP);
2116 registerUser(S, { Op });
2117 return S;
2118 }
2120 const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2121 Type *Ty) {
2122 switch (Kind) {
2123 case scTruncate:
2124 return getTruncateExpr(Op, Ty);
2125 case scZeroExtend:
2126 return getZeroExtendExpr(Op, Ty);
2127 case scSignExtend:
2128 return getSignExtendExpr(Op, Ty);
2129 case scPtrToInt:
2130 return getPtrToIntExpr(Op, Ty);
2131 default:
2132 llvm_unreachable("Not a SCEV cast expression!");
2133 }
2134 }
2136 /// getAnyExtendExpr - Return a SCEV for the given operand extended with
2137 /// unspecified bits out to the given type.
2138 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2139 Type *Ty) {
2140 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2141 "This is not an extending conversion!");
2142 assert(isSCEVable(Ty) &&
2143 "This is not a conversion to a SCEVable type!");
2144 Ty = getEffectiveSCEVType(Ty);
2146 // Sign-extend negative constants.
2147 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2148 if (SC->getAPInt().isNegative())
2149 return getSignExtendExpr(Op, Ty);
2151 // Peel off a truncate cast.
2152 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2153 const SCEV *NewOp = T->getOperand();
2154 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2155 return getAnyExtendExpr(NewOp, Ty);
2156 return getTruncateOrNoop(NewOp, Ty);
2157 }
2159 // Next try a zext cast. If the cast is folded, use it.
2160 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2161 if (!isa<SCEVZeroExtendExpr>(ZExt))
2162 return ZExt;
2164 // Next try a sext cast. If the cast is folded, use it.
2165 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2166 if (!isa<SCEVSignExtendExpr>(SExt))
2167 return SExt;
2169 // Force the cast to be folded into the operands of an addrec.
2170 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2171 SmallVector<const SCEV *, 4> Ops;
2172 for (const SCEV *Op : AR->operands())
2173 Ops.push_back(getAnyExtendExpr(Op, Ty));
2174 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2175 }
2177 // If the expression is obviously signed, use the sext cast value.
2178 if (isa<SCEVSMaxExpr>(Op))
2179 return SExt;
2181 // Absent any other information, use the zext cast value.
2182 return ZExt;
2183 }
2185 /// Process the given Ops list, which is a list of operands to be added under
2186 /// the given scale, and update the given map. This is a helper function for
2187 /// getAddRecExpr. As an example of what it does, given a sequence of operands
2188 /// that would form an add expression like this:
2190 /// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2192 /// where A and B are constants, update the map with these values:
2194 /// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2196 /// and add 13 + A*B*29 to AccumulatedConstant.
2197 /// This will allow getAddRecExpr to produce this:
2199 /// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2201 /// This form often exposes folding opportunities that are hidden in
2202 /// the original operand list.
2204 /// Return true iff it appears that any interesting folding opportunities
2205 /// may be exposed. This helps getAddRecExpr short-circuit extra work in
2206 /// the common case where no interesting opportunities are present, and
2207 /// is also used as a check to avoid infinite recursion.
2208 static bool
2209 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2210 SmallVectorImpl<const SCEV *> &NewOps,
2211 APInt &AccumulatedConstant,
2212 const SCEV *const *Ops, size_t NumOperands,
2213 const APInt &Scale,
2214 ScalarEvolution &SE) {
2215 bool Interesting = false;
2217 // Iterate over the add operands. They are sorted, with constants first.
2218 unsigned i = 0;
2219 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2220 ++i;
2221 // Pull a buried constant out to the outside.
2222 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2223 Interesting = true;
2224 AccumulatedConstant += Scale * C->getAPInt();
2225 }
2227 // Next comes everything else. We're especially interested in multiplies
2228 // here, but they're in the middle, so just visit the rest with one loop.
2229 for (; i != NumOperands; ++i) {
2230 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2231 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2232 APInt NewScale =
2233 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2234 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2235 // A multiplication of a constant with another add; recurse.
2236 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2237 Interesting |=
2238 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2239 Add->op_begin(), Add->getNumOperands(),
2240 NewScale, SE);
2241 } else {
2242 // A multiplication of a constant with some other value. Update
2243 // the map.
2244 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2245 const SCEV *Key = SE.getMulExpr(MulOps);
2246 auto Pair = M.insert({Key, NewScale});
2247 if (Pair.second) {
2248 NewOps.push_back(Pair.first->first);
2249 } else {
2250 Pair.first->second += NewScale;
2251 // The map already had an entry for this value, which may indicate
2252 // a folding opportunity.
2253 Interesting = true;
2254 }
2255 }
2256 } else {
2257 // An ordinary operand. Update the map.
2258 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2259 M.insert({Ops[i], Scale});
2260 if (Pair.second) {
2261 NewOps.push_back(Pair.first->first);
2262 } else {
2263 Pair.first->second += Scale;
2264 // The map already had an entry for this value, which may indicate
2265 // a folding opportunity.
2266 Interesting = true;
2267 }
2268 }
2269 }
2271 return Interesting;
2272 }
2274 bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2275 const SCEV *LHS, const SCEV *RHS) {
2276 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2277 SCEV::NoWrapFlags, unsigned);
2278 switch (BinOp) {
2279 default:
2280 llvm_unreachable("Unsupported binary op");
2281 case Instruction::Add:
2282 Operation = &ScalarEvolution::getAddExpr;
2283 break;
2284 case Instruction::Sub:
2285 Operation = &ScalarEvolution::getMinusSCEV;
2286 break;
2287 case Instruction::Mul:
2288 Operation = &ScalarEvolution::getMulExpr;
2289 break;
2290 }
2292 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2293 Signed ? &ScalarEvolution::getSignExtendExpr
2294 : &ScalarEvolution::getZeroExtendExpr;
2296 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
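// E.g. for an i8 add: sext(LHS + RHS) to i16 equals sext(LHS) + sext(RHS)
// exactly when the narrow add cannot sign-overflow, so comparing the two
// widened SCEVs decides whether the narrow operation wraps.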
2297 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2298 auto *WideTy =
2299 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2301 const SCEV *A = (this->*Extension)(
2302 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2303 const SCEV *B = (this->*Operation)((this->*Extension)(LHS, WideTy, 0),
2304 (this->*Extension)(RHS, WideTy, 0),
2305 SCEV::FlagAnyWrap, 0);
2306 return A == B;
2307 }
2309 std::pair<SCEV::NoWrapFlags, bool /*Deduced*/>
2310 ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2311 const OverflowingBinaryOperator *OBO) {
2312 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2314 if (OBO->hasNoUnsignedWrap())
2315 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2316 if (OBO->hasNoSignedWrap())
2317 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2319 bool Deduced = false;
2321 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2322 return {Flags, Deduced};
2324 if (OBO->getOpcode() != Instruction::Add &&
2325 OBO->getOpcode() != Instruction::Sub &&
2326 OBO->getOpcode() != Instruction::Mul)
2327 return {Flags, Deduced};
2329 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2330 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2332 if (!OBO->hasNoUnsignedWrap() &&
2333 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2334 /* Signed */ false, LHS, RHS)) {
2335 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2336 Deduced = true;
2337 }
2339 if (!OBO->hasNoSignedWrap() &&
2340 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2341 /* Signed */ true, LHS, RHS)) {
2342 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2343 Deduced = true;
2344 }
2346 return {Flags, Deduced};
2347 }
2349 // We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2350 // `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2351 // can't-overflow flags for the operation if possible.
2352 static SCEV::NoWrapFlags
2353 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2354 const ArrayRef<const SCEV *> Ops,
2355 SCEV::NoWrapFlags Flags) {
2356 using namespace std::placeholders;
2358 using OBO = OverflowingBinaryOperator;
2360 bool CanAnalyze =
2361 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2362 (void)CanAnalyze;
2363 assert(CanAnalyze && "don't call from other places!");
2365 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2366 SCEV::NoWrapFlags SignOrUnsignWrap =
2367 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2369 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
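// (On non-negative values the signed and unsigned interpretations agree,
// so an operation that cannot wrap the signed range cannot wrap the
// unsigned range either.)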
2370 auto IsKnownNonNegative = [&](const SCEV *S) {
2371 return SE->isKnownNonNegative(S);
2372 };
2374 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2375 Flags =
2376 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2378 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2380 if (SignOrUnsignWrap != SignOrUnsignMask &&
2381 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2382 isa<SCEVConstant>(Ops[0])) {
2384 auto Opcode = [&] {
2385 switch (Type) {
2386 case scAddExpr:
2387 return Instruction::Add;
2388 case scMulExpr:
2389 return Instruction::Mul;
2390 default:
2391 llvm_unreachable("Unexpected SCEV op.");
2392 }
2393 }();
2395 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2397 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
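// E.g. for (A + 20), the region check proves nsw whenever the signed range
// of A is known to stay at or below SINT_MAX - 20.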
2398 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2399 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2400 Opcode, C, OBO::NoSignedWrap);
2401 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2402 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2403 }
2405 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2406 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2407 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2408 Opcode, C, OBO::NoUnsignedWrap);
2409 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2410 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2411 }
2412 }
2414 // <0,+,nonnegative><nw> is also nuw
2415 // TODO: Add corresponding nsw case
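// (Rationale: starting at 0 with a non-negative step, the unsigned value
// could only wrap by first wrapping all the way around to its start, and
// that self-wrap is exactly what <nw> rules out.)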
2416 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2417 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2418 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2419 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2421 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
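// (Rationale: (X udiv Y) * Y rounds X down to a multiple of Y, so the
// product never exceeds X and cannot wrap the unsigned range.)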
2422 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2423 Ops.size() == 2) {
2424 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2425 if (UDiv->getOperand(1) == Ops[1])
2426 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2427 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2428 if (UDiv->getOperand(1) == Ops[0])
2429 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2430 }
2432 return Flags;
2433 }
2435 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2436 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2437 }
2439 /// Get a canonical add expression, or something simpler if possible.
2440 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2441 SCEV::NoWrapFlags OrigFlags,
2442 unsigned Depth) {
2443 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2444 "only nuw or nsw allowed");
2445 assert(!Ops.empty() && "Cannot get empty add!");
2446 if (Ops.size() == 1) return Ops[0];
2447 #ifndef NDEBUG
2448 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2449 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2450 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2451 "SCEVAddExpr operand types don't match!");
2452 unsigned NumPtrs = count_if(
2453 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2454 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2455 #endif
2457 // Sort by complexity, this groups all similar expression types together.
2458 GroupByComplexity(Ops, &LI, DT);
2460 // If there are any constants, fold them together.
2461 unsigned Idx = 0;
2462 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2463 ++Idx;
2464 assert(Idx < Ops.size());
2465 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2466 // We found two constants, fold them together!
2467 Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2468 if (Ops.size() == 2) return Ops[0];
2469 Ops.erase(Ops.begin()+1); // Erase the folded element
2470 LHSC = cast<SCEVConstant>(Ops[0]);
2471 }
2473 // If we are left with a constant zero being added, strip it off.
2474 if (LHSC->getValue()->isZero()) {
2475 Ops.erase(Ops.begin());
2476 --Idx;
2477 }
2479 if (Ops.size() == 1) return Ops[0];
2480 }
2482 // Delay expensive flag strengthening until necessary.
2483 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2484 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2485 };
2487 // Limit recursion calls depth.
2488 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2489 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2491 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2492 // Don't strengthen flags if we have no new information.
2493 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2494 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2495 Add->setNoWrapFlags(ComputeFlags(Ops));
2496 return S;
2497 }
2499 // Okay, check to see if the same value occurs in the operand list more than
2500 // once. If so, merge them together into a multiply expression. Since we
2501 // sorted the list, these values are required to be adjacent.
2502 Type *Ty = Ops[0]->getType();
2503 bool FoundMatch = false;
2504 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2505 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2506 // Scan ahead to count how many equal operands there are.
2507 unsigned Count = 2;
2508 while (i+Count != e && Ops[i+Count] == Ops[i])
2509 ++Count;
2510 // Merge the values into a multiply.
2511 const SCEV *Scale = getConstant(Ty, Count);
2512 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2513 if (Ops.size() == Count)
2514 return Mul;
2515 Ops[i] = Mul;
2516 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2517 --i; e -= Count - 1;
2518 FoundMatch = true;
2519 }
2520 if (FoundMatch)
2521 return getAddExpr(Ops, OrigFlags, Depth + 1);
2523 // Check for truncates. If all the operands are truncated from the same
2524 // type, see if factoring out the truncate would permit the result to be
2525 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2526 // if the contents of the resulting outer trunc fold to something simple.
2527 auto FindTruncSrcType = [&]() -> Type * {
2528 // We're ultimately looking to fold an addrec of truncs and muls of only
2529 // constants and truncs, so if we find any other types of SCEV
2530 // as operands of the addrec then we bail and return nullptr here.
2531 // Otherwise, we return the type of the operand of a trunc that we find.
2532 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2533 return T->getOperand()->getType();
2534 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2535 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2536 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2537 return T->getOperand()->getType();
2538 }
2539 return nullptr;
2540 };
2541 if (auto *SrcType = FindTruncSrcType()) {
2542 SmallVector<const SCEV *, 8> LargeOps;
2543 bool Ok = true;
2544 // Check all the operands to see if they can be represented in the
2545 // source type of the truncate.
2546 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
2547 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
2548 if (T->getOperand()->getType() != SrcType) {
2549 Ok = false;
2550 break;
2551 }
2552 LargeOps.push_back(T->getOperand());
2553 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2554 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2555 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
2556 SmallVector<const SCEV *, 8> LargeMulOps;
2557 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2558 if (const SCEVTruncateExpr *T =
2559 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2560 if (T->getOperand()->getType() != SrcType) {
2561 Ok = false;
2562 break;
2563 }
2564 LargeMulOps.push_back(T->getOperand());
2565 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2566 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2567 } else {
2568 Ok = false;
2569 break;
2570 }
2571 }
2572 if (Ok)
2573 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2574 } else {
2575 Ok = false;
2576 break;
2577 }
2578 }
2579 if (Ok) {
2580 // Evaluate the expression in the larger type.
2581 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2582 // If it folds to something simple, use it. Otherwise, don't.
2583 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2584 return getTruncateExpr(Fold, Ty);
2585 }
2586 }
2588 if (Ops.size() == 2) {
2589 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2590 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2591 // C1).
2592 const SCEV *A = Ops[0];
2593 const SCEV *B = Ops[1];
2594 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2595 auto *C = dyn_cast<SCEVConstant>(A);
2596 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2597 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2598 auto C2 = C->getAPInt();
2599 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2601 APInt ConstAdd = C1 + C2;
2602 auto AddFlags = AddExpr->getNoWrapFlags();
2603 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2604 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2605 ConstAdd.ule(C1)) {
2606 PreservedFlags =
2607 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2608 }
2610 // Adding a constant with the same sign and small magnitude is NSW, if the
2611 // original AddExpr was NSW.
2612 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2613 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2614 ConstAdd.abs().ule(C1.abs())) {
2615 PreservedFlags =
2616 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2617 }
2619 if (PreservedFlags != SCEV::FlagAnyWrap) {
2620 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2621 NewOps[0] = getConstant(ConstAdd);
2622 return getAddExpr(NewOps, PreservedFlags);
2623 }
2624 }
2625 }
2627 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
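// (This is just X - (X urem Y) == (X udiv Y) * Y, i.e. X rounded down to a
// multiple of Y.)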
2628 if (Ops.size() == 2) {
2629 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2630 if (Mul && Mul->getNumOperands() == 2 &&
2631 Mul->getOperand(0)->isAllOnesValue()) {
2632 const SCEV *X;
2633 const SCEV *Y;
2634 if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2635 return getMulExpr(Y, getUDivExpr(X, Y));
2636 }
2637 }
2638 }
2640 // Skip past any other cast SCEVs.
2641 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2642 ++Idx;
2644 // If there are add operands they would be next.
2645 if (Idx < Ops.size()) {
2646 bool DeletedAdd = false;
2647 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2648 // common NUW flag for expression after inlining. Other flags cannot be
2649 // preserved, because they may depend on the original order of operations.
2650 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2651 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2652 if (Ops.size() > AddOpsInlineThreshold ||
2653 Add->getNumOperands() > AddOpsInlineThreshold)
2654 break;
2655 // If we have an add, expand the add operands onto the end of the operands
2656 // list.
2657 Ops.erase(Ops.begin()+Idx);
2658 Ops.append(Add->op_begin(), Add->op_end());
2659 DeletedAdd = true;
2660 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2661 }
2663 // If we deleted at least one add, we added operands to the end of the list,
2664 // and they are not necessarily sorted. Recurse to resort and resimplify
2665 // any operands we just acquired.
2666 if (DeletedAdd)
2667 return getAddExpr(Ops, CommonFlags, Depth + 1);
2668 }
2670 // Skip over the add expression until we get to a multiply.
2671 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2672 ++Idx;
2674 // Check to see if there are any folding opportunities present with
2675 // operands multiplied by constant values.
2676 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2677 uint64_t BitWidth = getTypeSizeInBits(Ty);
2678 DenseMap<const SCEV *, APInt> M;
2679 SmallVector<const SCEV *, 8> NewOps;
2680 APInt AccumulatedConstant(BitWidth, 0);
2681 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2682 Ops.data(), Ops.size(),
2683 APInt(BitWidth, 1), *this)) {
2684 struct APIntCompare {
2685 bool operator()(const APInt &LHS, const APInt &RHS) const {
2686 return LHS.ult(RHS);
2687 }
2688 };
2690 // Some interesting folding opportunity is present, so it's worthwhile to
2691 // re-generate the operands list. Group the operands by constant scale,
2692 // to avoid multiplying by the same constant scale multiple times.
2693 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2694 for (const SCEV *NewOp : NewOps)
2695 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2696 // Re-generate the operands list.
2697 Ops.clear();
2698 if (AccumulatedConstant != 0)
2699 Ops.push_back(getConstant(AccumulatedConstant));
2700 for (auto &MulOp : MulOpLists) {
2701 if (MulOp.first == 1) {
2702 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2703 } else if (MulOp.first != 0) {
2704 Ops.push_back(getMulExpr(
2705 getConstant(MulOp.first),
2706 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2707 SCEV::FlagAnyWrap, Depth + 1));
2708 }
2709 }
2710 if (Ops.empty())
2711 return getZero(Ty);
2712 if (Ops.size() == 1)
2713 return Ops[0];
2714 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2715 }
2716 }
2718 // If we are adding something to a multiply expression, make sure the
2719 // something is not already an operand of the multiply. If so, merge it into
2720 // the multiply.
2721 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2722 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2723 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2724 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2725 if (isa<SCEVConstant>(MulOpSCEV))
2726 continue;
2727 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2728 if (MulOpSCEV == Ops[AddOp]) {
2729 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2730 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2731 if (Mul->getNumOperands() != 2) {
2732 // If the multiply has more than two operands, we must get the
2733 // other operands.
2734 SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2735 Mul->op_begin()+MulOp);
2736 MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2737 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2738 }
2739 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2740 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2741 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2742 SCEV::FlagAnyWrap, Depth + 1);
2743 if (Ops.size() == 2) return OuterMul;
2744 if (AddOp < Idx) {
2745 Ops.erase(Ops.begin()+AddOp);
2746 Ops.erase(Ops.begin()+Idx-1);
2747 } else {
2748 Ops.erase(Ops.begin()+Idx);
2749 Ops.erase(Ops.begin()+AddOp-1);
2750 }
2751 Ops.push_back(OuterMul);
2752 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2753 }
2755 // Check this multiply against other multiplies being added together.
2756 for (unsigned OtherMulIdx = Idx+1;
2757 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2758 ++OtherMulIdx) {
2759 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2760 // If MulOp occurs in OtherMul, we can fold the two multiplies
2761 // together.
2762 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2763 OMulOp != e; ++OMulOp)
2764 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2765 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2766 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2767 if (Mul->getNumOperands() != 2) {
2768 SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2769 Mul->op_begin()+MulOp);
2770 MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2771 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2772 }
2773 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2774 if (OtherMul->getNumOperands() != 2) {
2775 SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
2776 OtherMul->op_begin()+OMulOp);
2777 MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
2778 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2779 }
2780 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2781 const SCEV *InnerMulSum =
2782 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2783 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2784 SCEV::FlagAnyWrap, Depth + 1);
2785 if (Ops.size() == 2) return OuterMul;
2786 Ops.erase(Ops.begin()+Idx);
2787 Ops.erase(Ops.begin()+OtherMulIdx-1);
2788 Ops.push_back(OuterMul);
2789 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2790 }
2791 }
2792 }
2793 }
2795 // If there are any add recurrences in the operands list, see if any other
2796 // added values are loop invariant. If so, we can fold them into the
2797 // recurrence.
2798 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2799 ++Idx;
2801 // Scan over all recurrences, trying to fold loop invariants into them.
2802 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2803 // Scan all of the other operands to this add and add them to the vector if
2804 // they are loop invariant w.r.t. the recurrence.
2805 SmallVector<const SCEV *, 8> LIOps;
2806 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2807 const Loop *AddRecLoop = AddRec->getLoop();
2808 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2809 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2810 LIOps.push_back(Ops[i]);
2811 Ops.erase(Ops.begin()+i);
2812 --i; --e;
2813 }
2815 // If we found some loop invariants, fold them into the recurrence.
2816 if (!LIOps.empty()) {
2817 // Compute nowrap flags for the addition of the loop-invariant ops and
2818 // the addrec. Temporarily push it as an operand for that purpose. These
2819 // flags are valid in the scope of the addrec only.
2820 LIOps.push_back(AddRec);
2821 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2822 LIOps.pop_back();
2824 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2825 LIOps.push_back(AddRec->getStart());
2827 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2829 // It is not in general safe to propagate flags valid on an add within
2830 // the addrec scope to one outside it. We must prove that the inner
2831 // scope is guaranteed to execute if the outer one does to be able to
2832 // safely propagate. We know the program is undefined if poison is
2833 // produced on the inner scoped addrec. We also know that *for this use*
2834 // the outer scoped add can't overflow (because of the flags we just
2835 // computed for the inner scoped add) without the program being undefined.
2836 // Proving that entry to the outer scope necessitates entry to the inner
2837 // scope, thus proves the program undefined if the flags would be violated
2838 // in the outer scope.
2839 SCEV::NoWrapFlags AddFlags = Flags;
2840 if (AddFlags != SCEV::FlagAnyWrap) {
2841 auto *DefI = getDefiningScopeBound(LIOps);
2842 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2843 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2844 AddFlags = SCEV::FlagAnyWrap;
2845 }
2846 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2848 // Build the new addrec. Propagate the NUW and NSW flags if both the
2849 // outer add and the inner addrec are guaranteed to have no overflow.
2850 // Always propagate NW.
2851 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2852 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2854 // If all of the other operands were loop invariant, we are done.
2855 if (Ops.size() == 1) return NewRec;
2857 // Otherwise, add the folded AddRec by the non-invariant parts.
2858 for (unsigned i = 0;; ++i)
2859 if (Ops[i] == AddRec) {
2860 Ops[i] = NewRec;
2861 break;
2862 }
2863 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2864 }
2866 // Okay, if there weren't any loop invariants to be folded, check to see if
2867 // there are multiple AddRec's with the same loop induction variable being
2868 // added together. If so, we can fold them.
2869 for (unsigned OtherIdx = Idx+1;
2870 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2871 ++OtherIdx) {
2872 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2873 // so that the 1st found AddRecExpr is dominated by all others.
2874 assert(DT.dominates(
2875 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2876 AddRec->getLoop()->getHeader()) &&
2877 "AddRecExprs are not sorted in reverse dominance order?");
2878 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2879 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2880 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2881 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2882 ++OtherIdx) {
2883 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2884 if (OtherAddRec->getLoop() == AddRecLoop) {
2885 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2886 i != e; ++i) {
2887 if (i >= AddRecOps.size()) {
2888 AddRecOps.append(OtherAddRec->op_begin()+i,
2889 OtherAddRec->op_end());
2890 break;
2891 }
2892 SmallVector<const SCEV *, 2> TwoOps = {
2893 AddRecOps[i], OtherAddRec->getOperand(i)};
2894 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2895 }
2896 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2897 }
2898 }
2899 // Step size has changed, so we cannot guarantee no self-wraparound.
2900 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2901 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2902 }
2903 }
2905 // Otherwise couldn't fold anything into this recurrence. Move on to the
2906 // next one.
2907 }
2909 // Okay, it looks like we really DO need an add expr. Check to see if we
2910 // already have one, otherwise create a new one.
2911 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2912 }
2914 const SCEV *
2915 ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2916 SCEV::NoWrapFlags Flags) {
2917 FoldingSetNodeID ID;
2918 ID.AddInteger(scAddExpr);
2919 for (const SCEV *Op : Ops)
2920 ID.AddPointer(Op);
2921 void *IP = nullptr;
2922 SCEVAddExpr *S =
2923 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2924 if (!S) {
2925 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2926 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2927 S = new (SCEVAllocator)
2928 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2929 UniqueSCEVs.InsertNode(S, IP);
2930 registerUser(S, Ops);
2931 }
2932 S->setNoWrapFlags(Flags);
2933 return S;
2934 }
2936 const SCEV *
2937 ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
2938 const Loop *L, SCEV::NoWrapFlags Flags) {
2939 FoldingSetNodeID ID;
2940 ID.AddInteger(scAddRecExpr);
2941 for (const SCEV *Op : Ops)
2942 ID.AddPointer(Op);
2943 ID.AddPointer(L);
2944 void *IP = nullptr;
2945 SCEVAddRecExpr *S =
2946 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2947 if (!S) {
2948 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2949 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2950 S = new (SCEVAllocator)
2951 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
2952 UniqueSCEVs.InsertNode(S, IP);
2953 LoopUsers[L].push_back(S);
2954 registerUser(S, Ops);
2955 }
2956 setNoWrapFlags(S, Flags);
2957 return S;
2958 }
2960 const SCEV *
2961 ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
2962 SCEV::NoWrapFlags Flags) {
2963 FoldingSetNodeID ID;
2964 ID.AddInteger(scMulExpr);
2965 for (const SCEV *Op : Ops)
2966 ID.AddPointer(Op);
2967 void *IP = nullptr;
2968 SCEVMulExpr *S =
2969 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2970 if (!S) {
2971 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2972 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2973 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
2974 O, Ops.size());
2975 UniqueSCEVs.InsertNode(S, IP);
2976 registerUser(S, Ops);
2977 }
2978 S->setNoWrapFlags(Flags);
2979 return S;
2980 }
2982 static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
2983 uint64_t k = i*j;
2984 if (j > 1 && k / j != i) Overflow = true;
2985 return k;
2986 }
2988 /// Compute the result of "n choose k", the binomial coefficient. If an
2989 /// intermediate computation overflows, Overflow will be set and the return will
2990 /// be garbage. Overflow is not cleared on absence of overflow.
2991 static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
2992 // We use the multiplicative formula:
2993 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
2994 // At each iteration, we take the next term of the numerator and divide by
2995 // the next term of the denominator. This division will always produce an
2996 // integral result, and helps reduce the chance of overflow in the
2997 // intermediate computations. However, we can still overflow even when the
2998 // final result would fit.
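// Worked example: Choose(5, 2) computes r = 5/1 = 5 at i = 1, then
// r = (5*4)/2 = 10 at i = 2, matching C(5,2) = 10; each intermediate
// division is exact.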
3000 if (n == 0 || n == k) return 1;
3001 if (k > n) return 0;
3003 if (k > n/2)
3004 k = n-k;
3006 uint64_t r = 1;
3007 for (uint64_t i = 1; i <= k; ++i) {
3008 r = umul_ov(r, n-(i-1), Overflow);
3009 r /= i;
3010 }
3011 return r;
3012 }
3014 /// Determine if any of the operands in this SCEV are a constant or if
3015 /// any of the add or multiply expressions in this SCEV contain a constant.
3016 static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3017 struct FindConstantInAddMulChain {
3018 bool FoundConstant = false;
3020 bool follow(const SCEV *S) {
3021 FoundConstant |= isa<SCEVConstant>(S);
3022 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3023 }
3025 bool isDone() const {
3026 return FoundConstant;
3027 }
3028 };
3030 FindConstantInAddMulChain F;
3031 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3032 ST.visitAll(StartExpr);
3033 return F.FoundConstant;
3034 }
3036 /// Get a canonical multiply expression, or something simpler if possible.
3037 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3038 SCEV::NoWrapFlags OrigFlags,
3039 unsigned Depth) {
3040 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3041 "only nuw or nsw allowed");
3042 assert(!Ops.empty() && "Cannot get empty mul!");
3043 if (Ops.size() == 1) return Ops[0];
3044 #ifndef NDEBUG
3045 Type *ETy = Ops[0]->getType();
3046 assert(!ETy->isPointerTy());
3047 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3048 assert(Ops[i]->getType() == ETy &&
3049 "SCEVMulExpr operand types don't match!");
3050 #endif
3052 // Sort by complexity, this groups all similar expression types together.
3053 GroupByComplexity(Ops, &LI, DT);
3055 // If there are any constants, fold them together.
3056 unsigned Idx = 0;
3057 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3058 ++Idx;
3059 assert(Idx < Ops.size());
3060 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3061 // We found two constants, fold them together!
3062 Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
3063 if (Ops.size() == 2) return Ops[0];
3064 Ops.erase(Ops.begin()+1); // Erase the folded element
3065 LHSC = cast<SCEVConstant>(Ops[0]);
3066 }
3068 // If we have a multiply of zero, it will always be zero.
3069 if (LHSC->getValue()->isZero())
3070 return LHSC;
3072 // If we are left with a constant one being multiplied, strip it off.
3073 if (LHSC->getValue()->isOne()) {
3074 Ops.erase(Ops.begin());
3075 --Idx;
3076 }
3078 if (Ops.size() == 1)
3079 return Ops[0];
3080 }
3082 // Delay expensive flag strengthening until necessary.
3083 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3084 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3085 };
3087 // Limit recursion calls depth.
3088 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3089 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3091 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3092 // Don't strengthen flags if we have no new information.
3093 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3094 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3095 Mul->setNoWrapFlags(ComputeFlags(Ops));
3096 return S;
3097 }
3099 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3100 if (Ops.size() == 2) {
3101 // C1*(C2+V) -> C1*C2 + C1*V
3102 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3103 // If any of Add's ops are Adds or Muls with a constant, apply this
3104 // transformation as well.
3106 // TODO: There are some cases where this transformation is not
3107 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3108 // this transformation should be narrowed down.
3109 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add))
3110 return getAddExpr(getMulExpr(LHSC, Add->getOperand(0),
3111 SCEV::FlagAnyWrap, Depth + 1),
3112 getMulExpr(LHSC, Add->getOperand(1),
3113 SCEV::FlagAnyWrap, Depth + 1),
3114 SCEV::FlagAnyWrap, Depth + 1);
3116 if (Ops[0]->isAllOnesValue()) {
3117 // If we have a mul by -1 of an add, try distributing the -1 among the
3119 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3120 SmallVector<const SCEV *, 4> NewOps;
3121 bool AnyFolded = false;
3122 for (const SCEV *AddOp : Add->operands()) {
3123 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3124 Depth + 1);
3125 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3126 NewOps.push_back(Mul);
3127 }
3128 if (AnyFolded)
3129 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3129 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3130 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3131 // Negation preserves a recurrence's no self-wrap property.
3132 SmallVector<const SCEV *, 4> Operands;
3133 for (const SCEV *AddRecOp : AddRec->operands())
3134 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3135 Depth + 1));
3137 return getAddRecExpr(Operands, AddRec->getLoop(),
3138 AddRec->getNoWrapFlags(SCEV::FlagNW));
3139 }
3140 }
3141 }
3142 }
3144 // Skip over the add expression until we get to a multiply.
3145 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3146 ++Idx;
3148 // If there are mul operands, inline them all into this expression.
3149 if (Idx < Ops.size()) {
3150 bool DeletedMul = false;
3151 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3152 if (Ops.size() > MulOpsInlineThreshold)
3153 break;
3154 // If we have a mul, expand the mul operands onto the end of the
3155 // operands list.
3156 Ops.erase(Ops.begin()+Idx);
3157 Ops.append(Mul->op_begin(), Mul->op_end());
3158 DeletedMul = true;
3159 }
3161 // If we deleted at least one mul, we added operands to the end of the
3162 // list, and they are not necessarily sorted. Recurse to resort and
3163 // resimplify any operands we just acquired.
3164 if (DeletedMul)
3165 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3166 }
3168 // If there are any add recurrences in the operands list, see if any other
3169 // added values are loop invariant. If so, we can fold them into the
3170 // recurrence.
3171 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3172 ++Idx;
3174 // Scan over all recurrences, trying to fold loop invariants into them.
3175 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3176 // Scan all of the other operands to this mul and add them to the vector
3177 // if they are loop invariant w.r.t. the recurrence.
3178 SmallVector<const SCEV *, 8> LIOps;
3179 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3180 const Loop *AddRecLoop = AddRec->getLoop();
3181 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3182 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
3183 LIOps.push_back(Ops[i]);
3184 Ops.erase(Ops.begin()+i);
3185 --i; --e;
3186 }
3188 // If we found some loop invariants, fold them into the recurrence.
3189 if (!LIOps.empty()) {
3190 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3191 SmallVector<const SCEV *, 4> NewOps;
3192 NewOps.reserve(AddRec->getNumOperands());
3193 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3194 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
3195 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3196 SCEV::FlagAnyWrap, Depth + 1));
3198 // Build the new addrec. Propagate the NUW and NSW flags if both the
3199 // outer mul and the inner addrec are guaranteed to have no overflow.
3201 // No self-wrap cannot be guaranteed after changing the step size, but
3202 // will be inferred if either NUW or NSW is true.
3203 SCEV::NoWrapFlags Flags = ComputeFlags({Scale, AddRec});
3204 const SCEV *NewRec = getAddRecExpr(
3205 NewOps, AddRecLoop, AddRec->getNoWrapFlags(Flags));
3207 // If all of the other operands were loop invariant, we are done.
3208 if (Ops.size() == 1) return NewRec;
3210 // Otherwise, multiply the folded AddRec by the non-invariant parts.
      for (unsigned i = 0;; ++i)
        if (Ops[i] == AddRec) {
          Ops[i] = NewRec;
          break;
        }
      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
    }
3219 // Okay, if there weren't any loop invariants to be folded, check to see
3220 // if there are multiple AddRec's with the same loop induction variable
3221 // being multiplied together. If so, we can fold them.
3223 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3224 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3225 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3226 // ]]],+,...up to x=2n}.
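    // As a concrete affine instance of this formula:
    //   {A,+,B}<L> * {C,+,D}<L> = {A*C,+,A*D + B*C + B*D,+,2*B*D}<L>.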
3227 // Note that the arguments to choose() are always integers with values
3228 // known at compile time, never SCEV objects.
3230 // The implementation avoids pointless extra computations when the two
3231 // addrec's are of different length (mathematically, it's equivalent to
3232 // an infinite stream of zeros on the right).
3233 bool OpsModified = false;
3234 for (unsigned OtherIdx = Idx+1;
         OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
         ++OtherIdx) {
      const SCEVAddRecExpr *OtherAddRec =
          dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
      if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
        continue;
3242 // Limit max number of arguments to avoid creation of unreasonably big
3243 // SCEVAddRecs with very complex operands.
      if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
          MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
        continue;
3248 bool Overflow = false;
3249 Type *Ty = AddRec->getType();
3250 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3251 SmallVector<const SCEV*, 7> AddRecOps;
3252 for (int x = 0, xe = AddRec->getNumOperands() +
3253 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3254 SmallVector <const SCEV *, 7> SumOps;
3255 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3256 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3257 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3258 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3259 z < ze && !Overflow; ++z) {
            uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
            uint64_t Coeff;
            if (LargerThan64Bits)
              Coeff = umul_ov(Coeff1, Coeff2, Overflow);
            else
              Coeff = Coeff1*Coeff2;
3266 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3267 const SCEV *Term1 = AddRec->getOperand(y-z);
3268 const SCEV *Term2 = OtherAddRec->getOperand(z);
            SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
                                        SCEV::FlagAnyWrap, Depth + 1));
          }
        }
        if (SumOps.empty())
          SumOps.push_back(getZero(Ty));
        AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
      }
      const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRecLoop,
                                            SCEV::FlagAnyWrap);
      if (Ops.size() == 2) return NewAddRec;
      Ops[Idx] = NewAddRec;
      Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
      OpsModified = true;
      AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
      if (!AddRec)
        break;
    }
    if (OpsModified)
      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
    // Otherwise couldn't fold anything into this recurrence. Move onto the
    // next one.
  }
  // Okay, it looks like we really DO need a mul expr. Check to see if we
  // already have one, otherwise create a new one.
  return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
}
3301 /// Represents an unsigned remainder expression based on unsigned division.
const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
                                         const SCEV *RHS) {
3304 assert(getEffectiveSCEVType(LHS->getType()) ==
3305 getEffectiveSCEVType(RHS->getType()) &&
3306 "SCEVURemExpr operand types don't match!");
3308 // Short-circuit easy cases
3309 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3310 // If constant is one, the result is trivial
3311 if (RHSC->getValue()->isOne())
3312 return getZero(LHS->getType()); // X urem 1 --> 0
3314 // If constant is a power of two, fold into a zext(trunc(LHS)).
3315 if (RHSC->getAPInt().isPowerOf2()) {
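      // For example, i32 %x urem 8 keeps only the low log2(8) = 3 bits:
      // it becomes zext i3 (trunc %x to i3) to i32.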
3316 Type *FullTy = LHS->getType();
      Type *TruncTy =
          IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
      return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
    }
  }
  // Fall back to the general case: %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3324 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3325 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
  return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
}
/// Get a canonical unsigned division expression, or something simpler if
/// possible.
const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
                                         const SCEV *RHS) {
3333 assert(!LHS->getType()->isPointerTy() &&
3334 "SCEVUDivExpr operand can't be pointer!");
3335 assert(LHS->getType() == RHS->getType() &&
3336 "SCEVUDivExpr operand types don't match!");
3338 FoldingSetNodeID ID;
  ID.AddInteger(scUDivExpr);
  ID.AddPointer(LHS);
  ID.AddPointer(RHS);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
    return S;

  // 0 udiv Y == 0
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
    if (LHSC->getValue()->isZero())
      return LHS;
3351 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3352 if (RHSC->getValue()->isOne())
3353 return LHS; // X udiv 1 --> x
3354 // If the denominator is zero, the result of the udiv is undefined. Don't
3355 // try to analyze it, because the resolution chosen here may differ from
3356 // the resolution chosen in other parts of the compiler.
3357 if (!RHSC->getValue()->isZero()) {
      // Determine if the division can be folded into the operands of LHS.
      // TODO: Generalize this to non-constants by using known-bits information.
3361 Type *Ty = LHS->getType();
3362 unsigned LZ = RHSC->getAPInt().countLeadingZeros();
3363 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3364 // For non-power-of-two values, effectively round the value up to the
3365 // nearest power of two.
      if (!RHSC->getAPInt().isPowerOf2())
        ++MaxShiftAmt;
3368 IntegerType *ExtTy =
3369 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3370 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3371 if (const SCEVConstant *Step =
3372 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3373 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
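            // For example, {0,+,8}<L> /u 4 becomes {0,+,2}<L>; the zext
            // comparison below verifies the recurrence does not wrap, so the
            // division can distribute over its operands.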
3374 const APInt &StepInt = Step->getAPInt();
3375 const APInt &DivInt = RHSC->getAPInt();
3376 if (!StepInt.urem(DivInt) &&
3377 getZeroExtendExpr(AR, ExtTy) ==
3378 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3379 getZeroExtendExpr(Step, ExtTy),
3380 AR->getLoop(), SCEV::FlagAnyWrap)) {
3381 SmallVector<const SCEV *, 4> Operands;
3382 for (const SCEV *Op : AR->operands())
3383 Operands.push_back(getUDivExpr(Op, RHS));
            return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
          }
          // Get a canonical UDivExpr for a recurrence.
          // {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
          // We can currently only fold X%N if X is constant.
3389 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3390 if (StartC && !DivInt.urem(StepInt) &&
3391 getZeroExtendExpr(AR, ExtTy) ==
3392 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3393 getZeroExtendExpr(Step, ExtTy),
3394 AR->getLoop(), SCEV::FlagAnyWrap)) {
3395 const APInt &StartInt = StartC->getAPInt();
3396 const APInt &StartRem = StartInt.urem(StepInt);
3397 if (StartRem != 0) {
3398 const SCEV *NewLHS =
3399 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3400 AR->getLoop(), SCEV::FlagNW);
              if (LHS != NewLHS) {
                LHS = NewLHS;

                // Reset the ID to include the new LHS, and check if it is
                // already cached.
                ID.clear();
                ID.AddInteger(scUDivExpr);
                ID.AddPointer(LHS);
                ID.AddPointer(RHS);
                IP = nullptr;
                if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
                  return S;
              }
            }
          }
        }
3417 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3418 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3419 SmallVector<const SCEV *, 4> Operands;
3420 for (const SCEV *Op : M->operands())
3421 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3422 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3423 // Find an operand that's safely divisible.
3424 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3425 const SCEV *Op = M->getOperand(i);
3426 const SCEV *Div = getUDivExpr(Op, RHSC);
3427 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
              Operands = SmallVector<const SCEV *, 4>(M->operands());
              Operands[i] = Div;
              return getMulExpr(Operands);
3435 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3436 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3437 if (auto *DivisorConstant =
3438 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3439 bool Overflow = false;
          APInt NewRHS =
              DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
          if (Overflow) {
            return getConstant(RHSC->getType(), 0, false);
          }
          return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3449 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3450 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3451 SmallVector<const SCEV *, 4> Operands;
3452 for (const SCEV *Op : A->operands())
3453 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
        if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
          Operands.clear();
          for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3457 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
            if (isa<SCEVUDivExpr>(Op) ||
                getMulExpr(Op, RHS) != A->getOperand(i))
              break;
            Operands.push_back(Op);
          if (Operands.size() == A->getNumOperands())
            return getAddExpr(Operands);
        }
      }
3468 // Fold if both operands are constant.
3469 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
3470 Constant *LHSCV = LHSC->getValue();
3471 Constant *RHSCV = RHSC->getValue();
        return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
                                                                   RHSCV)));
3478 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3479 // changes). Make sure we get a new one.
3481 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
                                             LHS, RHS);
  UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, {LHS, RHS});
  return S;
}
3489 APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3490 APInt A = C1->getAPInt().abs();
3491 APInt B = C2->getAPInt().abs();
3492 uint32_t ABW = A.getBitWidth();
3493 uint32_t BBW = B.getBitWidth();
  if (ABW > BBW)
    B = B.zext(ABW);
  else if (ABW < BBW)
    A = A.zext(BBW);

  return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
}
3503 /// Get a canonical unsigned division expression, or something simpler if
3504 /// possible. There is no representation for an exact udiv in SCEV IR, but we
3505 /// can attempt to remove factors from the LHS and RHS. We can't do this when
3506 /// it's not exact because the udiv may be clearing bits.
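/// For example, (8 * %x) /u 4 can be simplified to 2 * %x when the multiply
/// is known not to wrap unsigned, whereas a plain udiv of unrelated values
/// cannot be decomposed this way.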
const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
                                              const SCEV *RHS) {
3509 // TODO: we could try to find factors in all sorts of things, but for now we
3510 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3511 // end of this file for inspiration.
3513 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3514 if (!Mul || !Mul->hasNoUnsignedWrap())
3515 return getUDivExpr(LHS, RHS);
3517 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3518 // If the mulexpr multiplies by a constant, then that constant must be the
3519 // first element of the mulexpr.
3520 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3521 if (LHSCst == RHSCst) {
3522 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
        return getMulExpr(Operands);
      }

      // We can't just assume that LHSCst divides RHSCst cleanly, it could be
      // that there's a factor provided by one of the other terms. We need to
      // check.
      APInt Factor = gcd(LHSCst, RHSCst);
3530 if (!Factor.isIntN(1)) {
        LHSCst =
            cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
        RHSCst =
            cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3535 SmallVector<const SCEV *, 2> Operands;
3536 Operands.push_back(LHSCst);
3537 Operands.append(Mul->op_begin() + 1, Mul->op_end());
        LHS = getMulExpr(Operands);
        RHS = getConstant(RHSCst);
        Mul = dyn_cast<SCEVMulExpr>(LHS);
        if (!Mul)
          return getUDivExactExpr(LHS, RHS);
      }
    }
  }
3547 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3548 if (Mul->getOperand(i) == RHS) {
3549 SmallVector<const SCEV *, 2> Operands;
3550 Operands.append(Mul->op_begin(), Mul->op_begin() + i);
3551 Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
      return getMulExpr(Operands);
    }
  }

  return getUDivExpr(LHS, RHS);
}
3559 /// Get an add recurrence expression for the specified loop. Simplify the
3560 /// expression as much as possible.
const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
                                           const Loop *L,
                                           SCEV::NoWrapFlags Flags) {
3564 SmallVector<const SCEV *, 4> Operands;
3565 Operands.push_back(Start);
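  // If the step is itself a recurrence over the same loop, flatten it into
  // this one; e.g. getAddRecExpr(S, {T,+,U}<L>, L) yields {S,+,T,+,U}<L>.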
3566 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3567 if (StepChrec->getLoop() == L) {
3568 Operands.append(StepChrec->op_begin(), StepChrec->op_end());
      return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
    }
3572 Operands.push_back(Step);
  return getAddRecExpr(Operands, L, Flags);
}
3576 /// Get an add recurrence expression for the specified loop. Simplify the
3577 /// expression as much as possible.
const SCEV *
ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3580 const Loop *L, SCEV::NoWrapFlags Flags) {
3581 if (Operands.size() == 1) return Operands[0];
#ifndef NDEBUG
  Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3584 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
3585 assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
3586 "SCEVAddRecExpr operand types don't match!");
    assert(!Operands[i]->getType()->isPointerTy() && "Step must be integer");
  }
3589 for (unsigned i = 0, e = Operands.size(); i != e; ++i)
3590 assert(isLoopInvariant(Operands[i], L) &&
3591 "SCEVAddRecExpr operand is not loop-invariant!");
3594 if (Operands.back()->isZero()) {
3595 Operands.pop_back();
    return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
  }

  // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3600 // use that information to infer NUW and NSW flags. However, computing a
3601 // BE count requires calling getAddRecExpr, so we may not yet have a
3602 // meaningful BE count at this point (and if we don't, we'd be stuck
3603 // with a SCEVCouldNotCompute as the cached BE count).
3605 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
  // Canonicalize nested AddRecs by nesting them in order of loop depth.
3608 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3609 const Loop *NestedLoop = NestedAR->getLoop();
3610 if (L->contains(NestedLoop)
3611 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3612 : (!NestedLoop->contains(L) &&
3613 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3614 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3615 Operands[0] = NestedAR->getStart();
3616 // AddRecs require their operands be loop-invariant with respect to their
      // loops. Don't perform this transformation if it would break this
      // requirement.
      bool AllInvariant = all_of(
          Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });

      if (AllInvariant) {
3623 // Create a recurrence for the outer loop with the same step size.
3625 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3626 // inner recurrence has the same property.
3627 SCEV::NoWrapFlags OuterFlags =
3628 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3630 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3631 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
          return isLoopInvariant(Op, NestedLoop);
        });

        if (AllInvariant) {
3636 // Ok, both add recurrences are valid after the transformation.
3638 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3639 // the outer recurrence has the same property.
3640 SCEV::NoWrapFlags InnerFlags =
3641 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
          return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
        }
      }
      // Reset Operands to its original state.
      Operands[0] = NestedAR;
    }
  }
3650 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3651 // already have one, otherwise create a new one.
  return getOrCreateAddRecExpr(Operands, L, Flags);
}
const SCEV *
ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3657 const SmallVectorImpl<const SCEV *> &IndexExprs) {
3658 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3659 // getSCEV(Base)->getType() has the same address space as Base->getType()
3660 // because SCEV::getType() preserves the address space.
3661 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3662 const bool AssumeInBoundsFlags = [&]() {
    if (!GEP->isInBounds())
      return false;
3666 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3667 // but to do that, we have to ensure that said flag is valid in the entire
3668 // defined scope of the SCEV.
3669 auto *GEPI = dyn_cast<Instruction>(GEP);
3670 // TODO: non-instructions have global scope. We might be able to prove
3671 // some global scope cases
    return GEPI && isSCEVExprNeverPoison(GEPI);
  }();
3675 SCEV::NoWrapFlags OffsetWrap =
3676 AssumeInBoundsFlags ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3678 Type *CurTy = GEP->getType();
3679 bool FirstIter = true;
3680 SmallVector<const SCEV *, 4> Offsets;
3681 for (const SCEV *IndexExpr : IndexExprs) {
3682 // Compute the (potentially symbolic) offset in bytes for this index.
3683 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3684 // For a struct, add the member offset.
3685 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3686 unsigned FieldNo = Index->getZExtValue();
3687 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3688 Offsets.push_back(FieldOffset);
      // Update CurTy to the type of the field at Index.
      CurTy = STy->getTypeAtIndex(Index);
    } else {
      // Update CurTy to its element type.
      if (FirstIter) {
        assert(isa<PointerType>(CurTy) &&
               "The first index of a GEP indexes a pointer");
        CurTy = GEP->getSourceElementType();
        FirstIter = false;
      } else {
        CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
      }
3702 // For an array, add the element offset, explicitly scaled.
3703 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3704 // Getelementptr indices are signed.
3705 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3707 // Multiply the index by the element size to compute the element offset.
3708 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
      Offsets.push_back(LocalOffset);
    }
  }

  // Handle degenerate case of GEP without offsets.
  if (Offsets.empty())
    return BaseExpr;
3717 // Add the offsets together, assuming nsw if inbounds.
3718 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3719 // Add the base address and the offset. We cannot use the nsw flag, as the
3720 // base address is unsigned. However, if we know that the offset is
3721 // non-negative, we can use nuw.
3722 SCEV::NoWrapFlags BaseWrap = AssumeInBoundsFlags && isKnownNonNegative(Offset)
3723 ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
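  // For example, `gep inbounds i32, ptr %p, i64 %i` becomes %p + 4 * %i,
  // with <nsw> on the scaled offset and, when the offset is known
  // non-negative, <nuw> on the base addition.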
3724 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
  assert(BaseExpr->getType() == GEPExpr->getType() &&
         "GEP should not change type mid-flight.");
  return GEPExpr;
}
3730 SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3731 ArrayRef<const SCEV *> Ops) {
3732 FoldingSetNodeID ID;
3733 ID.AddInteger(SCEVType);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
}
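/// Return the absolute value of Op as smax(Op, -Op); when IsNSW is set, the
/// negated operand is marked non-signed-wrapping.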
3740 const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3741 SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
  return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
}
3745 const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3746 SmallVectorImpl<const SCEV *> &Ops) {
3747 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3748 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3749 if (Ops.size() == 1) return Ops[0];
3751 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3752 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3753 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3754 "Operand types don't match!");
3755 assert(Ops[0]->getType()->isPointerTy() ==
3756 Ops[i]->getType()->isPointerTy() &&
3757 "min/max should be consistently pointerish");
3761 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3762 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3764 // Sort by complexity, this groups all similar expression types together.
3765 GroupByComplexity(Ops, &LI, DT);
3767 // Check if we have created the same expression before.
  if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
    return S;
  }
3772 // If there are any constants, fold them together.
  unsigned Idx = 0;
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    ++Idx;
    assert(Idx < Ops.size());
3777 auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3778 if (Kind == scSMaxExpr)
3779 return APIntOps::smax(LHS, RHS);
3780 else if (Kind == scSMinExpr)
3781 return APIntOps::smin(LHS, RHS);
3782 else if (Kind == scUMaxExpr)
3783 return APIntOps::umax(LHS, RHS);
3784 else if (Kind == scUMinExpr)
3785 return APIntOps::umin(LHS, RHS);
      llvm_unreachable("Unknown SCEV min/max opcode");
    };
3789 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3790 // We found two constants, fold them together!
3791 ConstantInt *Fold = ConstantInt::get(
3792 getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3793 Ops[0] = getConstant(Fold);
3794 Ops.erase(Ops.begin()+1); // Erase the folded element
3795 if (Ops.size() == 1) return Ops[0];
3796 LHSC = cast<SCEVConstant>(Ops[0]);
3799 bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3800 bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3802 if (IsMax ? IsMinV : IsMaxV) {
3803 // If we are left with a constant minimum(/maximum)-int, strip it off.
      Ops.erase(Ops.begin());
      --Idx;
    } else if (IsMax ? IsMaxV : IsMinV) {
      // If we have a max(/min) with a constant maximum(/minimum)-int,
      // it will always be the extremum.
      return LHSC;
    }
3812 if (Ops.size() == 1) return Ops[0];
3815 // Find the first operation of the same kind
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
    ++Idx;
3819 // Check to see if one of the operands is of the same kind. If so, expand its
3820 // operands onto our operand list, and recurse to simplify.
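  // For example, smax(A, smax(B, C)) is flattened to smax(A, B, C).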
3821 if (Idx < Ops.size()) {
3822 bool DeletedAny = false;
3823 while (Ops[Idx]->getSCEVType() == Kind) {
3824 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3825 Ops.erase(Ops.begin()+Idx);
      Ops.append(SMME->op_begin(), SMME->op_end());
      DeletedAny = true;
    }

    if (DeletedAny)
      return getMinMaxExpr(Kind, Ops);
  }
3834 // Okay, check to see if the same value occurs in the operand list twice. If
  // so, delete one. Since we sorted the list, these values are required to
  // be adjacent.
  llvm::CmpInst::Predicate GEPred =
3838 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3839 llvm::CmpInst::Predicate LEPred =
3840 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3841 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3842 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3843 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3844 if (Ops[i] == Ops[i + 1] ||
3845 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3846 // X op Y op Y --> X op Y
3847 // X op Y --> X, if we know X, Y are ordered appropriately
      Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
      --i; --e;
    } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
                                               Ops[i + 1])) {
      // X op Y --> Y, if we know X, Y are ordered appropriately
      Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
      --i; --e;
    }
  }
3860 if (Ops.size() == 1) return Ops[0];
3862 assert(!Ops.empty() && "Reduced smax down to nothing!");
3864 // Okay, it looks like we really DO need an expr. Check to see if we
3865 // already have one, otherwise create a new one.
3866 FoldingSetNodeID ID;
3867 ID.AddInteger(Kind);
3868 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3869 ID.AddPointer(Ops[i]);
  void *IP = nullptr;
  const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
  if (ExistingSCEV)
    return ExistingSCEV;
3874 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3875 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3876 SCEV *S = new (SCEVAllocator)
3877 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3879 UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, Ops);
  return S;
}

namespace {
3886 class SCEVSequentialMinMaxDeduplicatingVisitor final
3887 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3888 Optional<const SCEV *>> {
3889 using RetVal = Optional<const SCEV *>;
3890 using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
3892 ScalarEvolution &SE;
3893 const SCEVTypes RootKind; // Must be a sequential min/max expression.
3894 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3895 SmallPtrSet<const SCEV *, 16> SeenOps;
3897 bool canRecurseInto(SCEVTypes Kind) const {
3898 // We can only recurse into the SCEV expression of the same effective type
3899 // as the type of our root SCEV expression.
3900 return RootKind == Kind || NonSequentialRootKind == Kind;
3903 RetVal visitAnyMinMaxExpr(const SCEV *S) {
3904 assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
3905 "Only for min/max expressions.");
3906 SCEVTypes Kind = S->getSCEVType();
    if (!canRecurseInto(Kind))
      return S;
3911 auto *NAry = cast<SCEVNAryExpr>(S);
3912 SmallVector<const SCEV *> NewOps;
    bool Changed =
        visit(Kind, makeArrayRef(NAry->op_begin(), NAry->op_end()), NewOps);

    if (!Changed)
      return S;
    if (NewOps.empty())
      return None;

    return isa<SCEVSequentialMinMaxExpr>(S)
3922 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
3923 : SE.getMinMaxExpr(Kind, NewOps);
3926 RetVal visit(const SCEV *S) {
3927 // Has the whole operand been seen already?
    if (!SeenOps.insert(S).second)
      return None; // Was already seen.
    return Base::visit(S);
  SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
                                           SCEVTypes RootKind)
3936 : SE(SE), RootKind(RootKind),
3937 NonSequentialRootKind(
            SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
                RootKind)) {}
3941 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
3942 SmallVectorImpl<const SCEV *> &NewOps) {
3943 bool Changed = false;
3944 SmallVector<const SCEV *> Ops;
3945 Ops.reserve(OrigOps.size());
3947 for (const SCEV *Op : OrigOps) {
      RetVal NewOp = visit(Op);
      if (NewOp != Op)
        Changed = true;
      if (!NewOp)
        continue;
      Ops.emplace_back(*NewOp);
    }

    if (Changed)
      NewOps = std::move(Ops);
    return Changed;
  }
3960 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
3962 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
3964 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
3966 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
3968 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
3970 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
3972 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
3974 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
3976 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
3978 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
3979 return visitAnyMinMaxExpr(Expr);
3982 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
3983 return visitAnyMinMaxExpr(Expr);
3986 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
3987 return visitAnyMinMaxExpr(Expr);
3990 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
3991 return visitAnyMinMaxExpr(Expr);
3994 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
3995 return visitAnyMinMaxExpr(Expr);
3998 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4000 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
const SCEV *
ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4007 SmallVectorImpl<const SCEV *> &Ops) {
4008 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4009 "Not a SCEVSequentialMinMaxExpr!");
4010 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
  if (Ops.size() == 1)
    return Ops[0];
  if (Ops.size() == 2 &&
4014 any_of(Ops, [](const SCEV *Op) { return isa<SCEVConstant>(Op); }))
4015 return getMinMaxExpr(
        SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
        Ops);
4019 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4020 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4021 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4022 "Operand types don't match!");
4023 assert(Ops[0]->getType()->isPointerTy() ==
4024 Ops[i]->getType()->isPointerTy() &&
4025 "min/max should be consistently pointerish");
4029 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4030 // so we can *NOT* do any kind of sorting of the expressions!
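  // For example, umin_seq(0, poison) folds to 0 because the leading zero
  // stops evaluation, whereas umin_seq(poison, 0) is poison, so reordering
  // operands would change the result.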
4032 // Check if we have created the same expression before.
  if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
    return S;
4036 // FIXME: there are *some* simplifications that we can do here.
4038 // Keep only the first instance of an operand.
4040 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
    bool Changed = Deduplicator.visit(Kind, Ops, Ops);
    if (Changed)
      return getSequentialMinMaxExpr(Kind, Ops);
  }
4046 // Check to see if one of the operands is of the same kind. If so, expand its
4047 // operands onto our operand list, and recurse to simplify.
  {
    unsigned Idx = 0;
    bool DeletedAny = false;
    while (Idx < Ops.size()) {
      if (Ops[Idx]->getSCEVType() != Kind) {
        ++Idx;
        continue;
      }
      const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
      Ops.erase(Ops.begin() + Idx);
      Ops.insert(Ops.begin() + Idx, SMME->op_begin(), SMME->op_end());
      DeletedAny = true;
    }

    if (DeletedAny)
      return getSequentialMinMaxExpr(Kind, Ops);
  }
4066 // Okay, it looks like we really DO need an expr. Check to see if we
4067 // already have one, otherwise create a new one.
4068 FoldingSetNodeID ID;
4069 ID.AddInteger(Kind);
4070 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
4071 ID.AddPointer(Ops[i]);
  void *IP = nullptr;
  const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
  if (ExistingSCEV)
    return ExistingSCEV;
4077 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4078 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4079 SCEV *S = new (SCEVAllocator)
4080 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4082 UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, Ops);
  return S;
}
4087 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4088 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4089 return getSMaxExpr(Ops);
4092 const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4093 return getMinMaxExpr(scSMaxExpr, Ops);
4096 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4097 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4098 return getUMaxExpr(Ops);
4101 const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4102 return getMinMaxExpr(scUMaxExpr, Ops);
4105 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4107 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4108 return getSMinExpr(Ops);
4111 const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4112 return getMinMaxExpr(scSMinExpr, Ops);
4115 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4117 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4118 return getUMinExpr(Ops, Sequential);
4121 const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4123 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4124 : getMinMaxExpr(scUMinExpr, Ops);
const SCEV *
ScalarEvolution::getSizeOfScalableVectorExpr(Type *IntTy,
4129 ScalableVectorType *ScalableTy) {
4130 Constant *NullPtr = Constant::getNullValue(ScalableTy->getPointerTo());
4131 Constant *One = ConstantInt::get(IntTy, 1);
4132 Constant *GEP = ConstantExpr::getGetElementPtr(ScalableTy, NullPtr, One);
  // Note that the expression we created is the final expression; we don't
  // want to simplify it any further. Also, if we call a normal getSCEV(),
  // we'll end up in an endless recursion. So just create an SCEVUnknown.
4136 return getUnknown(ConstantExpr::getPtrToInt(GEP, IntTy));
4139 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4140 if (auto *ScalableAllocTy = dyn_cast<ScalableVectorType>(AllocTy))
4141 return getSizeOfScalableVectorExpr(IntTy, ScalableAllocTy);
4142 // We can bypass creating a target-independent constant expression and then
  // folding it back into a ConstantInt. This is just a compile-time
  // optimization.
  return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
}
4148 const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4149 if (auto *ScalableStoreTy = dyn_cast<ScalableVectorType>(StoreTy))
4150 return getSizeOfScalableVectorExpr(IntTy, ScalableStoreTy);
4151 // We can bypass creating a target-independent constant expression and then
  // folding it back into a ConstantInt. This is just a compile-time
  // optimization.
  return getConstant(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
}
const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
                                             StructType *STy,
                                             unsigned FieldNo) {
  // We can bypass creating a target-independent constant expression and then
  // folding it back into a ConstantInt. This is just a compile-time
  // optimization.
  return getConstant(
      IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
}
4167 const SCEV *ScalarEvolution::getUnknown(Value *V) {
4168 // Don't attempt to do anything other than create a SCEVUnknown object
4169 // here. createSCEV only calls getUnknown after checking for all other
4170 // interesting possibilities, and any other code that calls getUnknown
4171 // is doing so in order to hide a value from SCEV canonicalization.
4173 FoldingSetNodeID ID;
  ID.AddInteger(scUnknown);
  ID.AddPointer(V);
  void *IP = nullptr;
  if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4178 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4179 "Stale SCEVUnknown in uniquing map!");
4182 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4184 FirstUnknown = cast<SCEVUnknown>(S);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}
4189 //===----------------------------------------------------------------------===//
4190 // Basic SCEV Analysis and PHI Idiom Recognition Code
4193 /// Test if values of the given type are analyzable within the SCEV
4194 /// framework. This primarily includes integer types, and it can optionally
4195 /// include pointer types if the ScalarEvolution class has access to
4196 /// target-specific information.
4197 bool ScalarEvolution::isSCEVable(Type *Ty) const {
4198 // Integers and pointers are always SCEVable.
4199 return Ty->isIntOrPtrTy();
/// Return the size in bits of the specified type, for which isSCEVable must
/// return true.
4204 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4205 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4206 if (Ty->isPointerTy())
4207 return getDataLayout().getIndexTypeSizeInBits(Ty);
4208 return getDataLayout().getTypeSizeInBits(Ty);
4211 /// Return a type with the same bitwidth as the given type and which represents
4212 /// how SCEV will treat the given type, for which isSCEVable must return
4213 /// true. For pointer types, this is the pointer index sized integer type.
4214 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
4215 assert(isSCEVable(Ty) && "Type is not SCEVable!");
  if (Ty->isIntegerTy())
    return Ty;

  // The only other supported type is pointer.
4221 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4222 return getDataLayout().getIndexType(Ty);
4225 Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4226 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
                                                        const SCEV *B) {
4231 /// For a valid use point to exist, the defining scope of one operand
4232 /// must dominate the other.
4233 bool PreciseA, PreciseB;
4234 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4235 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
  if (!PreciseA || !PreciseB)
    // Can't tell.
    return false;
4239 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
         DT.dominates(ScopeB, ScopeA);
}
4244 const SCEV *ScalarEvolution::getCouldNotCompute() {
4245 return CouldNotCompute.get();
4248 bool ScalarEvolution::checkValidity(const SCEV *S) const {
4249 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4250 auto *SU = dyn_cast<SCEVUnknown>(S);
    return SU && SU->getValue() == nullptr;
  });

  return !ContainsNulls;
}
4257 bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4258 HasRecMapType::iterator I = HasRecMap.find(S);
  if (I != HasRecMap.end())
    return I->second;

  bool FoundAddRec =
      SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
  HasRecMap.insert({S, FoundAddRec});
  return FoundAddRec;
}
4268 /// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}.
4269 /// If \p S is a SCEVAddExpr and is composed of a sub SCEV S' and an
4270 /// offset I, then return {S', I}, else return {\p S, nullptr}.
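/// For example, (42 + %x) splits into {%x, 42}.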
4271 static std::pair<const SCEV *, ConstantInt *> splitAddExpr(const SCEV *S) {
  const auto *Add = dyn_cast<SCEVAddExpr>(S);
  if (!Add)
    return {S, nullptr};
4276 if (Add->getNumOperands() != 2)
4277 return {S, nullptr};
  auto *ConstOp = dyn_cast<SCEVConstant>(Add->getOperand(0));
  if (!ConstOp)
    return {S, nullptr};
4283 return {Add->getOperand(1), ConstOp->getValue()};
4286 /// Return the ValueOffsetPair set for \p S. \p S can be represented
4287 /// by the value and offset from any ValueOffsetPair in the set.
4288 ScalarEvolution::ValueOffsetPairSetVector *
4289 ScalarEvolution::getSCEVValues(const SCEV *S) {
4290 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
  if (SI == ExprValueMap.end())
    return nullptr;
#ifndef NDEBUG
  if (VerifySCEVMap) {
4295 // Check there is no dangling Value in the set returned.
4296 for (const auto &VE : SI->second)
      assert(ValueExprMap.count(VE.first));
  }
#endif
  return &SI->second;
}
4303 /// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4304 /// cannot be used separately. eraseValueFromMap should be used to remove
4305 /// V from ValueExprMap and ExprValueMap at the same time.
4306 void ScalarEvolution::eraseValueFromMap(Value *V) {
4307 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4308 if (I != ValueExprMap.end()) {
4309 const SCEV *S = I->second;
4310 // Remove {V, 0} from the set of ExprValueMap[S]
4311 if (auto *SV = getSCEVValues(S))
4312 SV->remove({V, nullptr});
4314 // Remove {V, Offset} from the set of ExprValueMap[Stripped]
4315 const SCEV *Stripped;
4316 ConstantInt *Offset;
4317 std::tie(Stripped, Offset) = splitAddExpr(S);
4318 if (Offset != nullptr) {
4319 if (auto *SV = getSCEVValues(Stripped))
        SV->remove({V, Offset});
    }
    ValueExprMap.erase(V);
  }
}
4326 void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4327 // A recursive query may have already computed the SCEV. It should be
4328 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4329 // inferred nowrap flags.
4330 auto It = ValueExprMap.find_as(V);
4331 if (It == ValueExprMap.end()) {
4332 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
    ExprValueMap[S].insert({V, nullptr});
  }
}
4337 /// Return an existing SCEV if it exists, otherwise analyze the expression and
4338 /// create a new one.
4339 const SCEV *ScalarEvolution::getSCEV(Value *V) {
4340 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
  const SCEV *S = getExistingSCEV(V);
  if (S == nullptr) {
    S = createSCEV(V);
4345 // During PHI resolution, it is possible to create two SCEVs for the same
4346 // V, so it is needed to double check whether V->S is inserted into
4347 // ValueExprMap before insert S->{V, 0} into ExprValueMap.
4348 std::pair<ValueExprMapType::iterator, bool> Pair =
4349 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
    if (Pair.second) {
      ExprValueMap[S].insert({V, nullptr});
4353 // If S == Stripped + Offset, add Stripped -> {V, Offset} into
4355 const SCEV *Stripped = S;
4356 ConstantInt *Offset = nullptr;
4357 std::tie(Stripped, Offset) = splitAddExpr(S);
4358 // If stripped is SCEVUnknown, don't bother to save
4359 // Stripped -> {V, offset}. It doesn't simplify and sometimes even
4360 // increase the complexity of the expansion code.
4361 // If V is GetElementPtrInst, don't save Stripped -> {V, offset}
4362 // because it may generate add/sub instead of GEP in SCEV expansion.
4363 if (Offset != nullptr && !isa<SCEVUnknown>(Stripped) &&
4364 !isa<GetElementPtrInst>(V))
        ExprValueMap[Stripped].insert({V, Offset});
    }
  }
  return S;
}
4371 const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4372 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4374 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4375 if (I != ValueExprMap.end()) {
4376 const SCEV *S = I->second;
4377 assert(checkValidity(S) &&
4378 "existing SCEV has not been properly invalidated");
4384 /// Return a SCEV corresponding to -V = -1*V
4385 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4386 SCEV::NoWrapFlags Flags) {
  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4391 Type *Ty = V->getType();
4392 Ty = getEffectiveSCEVType(Ty);
4393 return getMulExpr(V, getMinusOne(Ty), Flags);
4396 /// If Expr computes ~A, return A else return nullptr
4397 static const SCEV *MatchNotExpr(const SCEV *Expr) {
4398 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4399 if (!Add || Add->getNumOperands() != 2 ||
      !Add->getOperand(0)->isAllOnesValue())
    return nullptr;
4403 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4404 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
      !AddRHS->getOperand(0)->isAllOnesValue())
    return nullptr;
4408 return AddRHS->getOperand(1);
4411 /// Return a SCEV corresponding to ~V = -1-V
4412 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4413 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4415 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4417 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4419 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4420 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4421 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4422 SmallVector<const SCEV *, 2> MatchedOperands;
4423 for (const SCEV *Operand : MME->operands()) {
        const SCEV *Matched = MatchNotExpr(Operand);
        if (!Matched)
          return (const SCEV *)nullptr;
4427 MatchedOperands.push_back(Matched);
      return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
                           MatchedOperands);
    };
    if (const SCEV *Replaced = MatchMinMaxNegation(MME))
      return Replaced;
  }
4436 Type *Ty = V->getType();
4437 Ty = getEffectiveSCEVType(Ty);
4438 return getMinusSCEV(getMinusOne(Ty), V);
4441 const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4442 assert(P->getType()->isPointerTy());
4444 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4445 // The base of an AddRec is the first operand.
4446 SmallVector<const SCEV *> Ops{AddRec->operands()};
4447 Ops[0] = removePointerBase(Ops[0]);
4448 // Don't try to transfer nowrap flags for now. We could in some cases
4449 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4450 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4452 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4453 // The base of an Add is the pointer operand.
4454 SmallVector<const SCEV *> Ops{Add->operands()};
4455 const SCEV **PtrOp = nullptr;
4456 for (const SCEV *&AddOp : Ops) {
4457 if (AddOp->getType()->isPointerTy()) {
        assert(!PtrOp && "Cannot have multiple pointer ops");
        PtrOp = &AddOp;
      }
    }
    *PtrOp = removePointerBase(*PtrOp);
4463 // Don't try to transfer nowrap flags for now. We could in some cases
4464 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4465 return getAddExpr(Ops);
4467 // Any other expression must be a pointer base.
4468 return getZero(P->getType());
4471 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
                                          SCEV::NoWrapFlags Flags,
                                          unsigned Depth) {
4474 // Fast path: X - X --> 0.
  if (LHS == RHS)
    return getZero(LHS->getType());
4478 // If we subtract two pointers with different pointer bases, bail.
4479 // Eventually, we're going to add an assertion to getMulExpr that we
4480 // can't multiply by a pointer.
4481 if (RHS->getType()->isPointerTy()) {
4482 if (!LHS->getType()->isPointerTy() ||
4483 getPointerBase(LHS) != getPointerBase(RHS))
4484 return getCouldNotCompute();
4485 LHS = removePointerBase(LHS);
4486 RHS = removePointerBase(RHS);
4489 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4490 // makes it so that we cannot make much use of NUW.
4491 auto AddFlags = SCEV::FlagAnyWrap;
4492 const bool RHSIsNotMinSigned =
4493 !getSignedRangeMin(RHS).isMinSignedValue();
4494 if (hasFlags(Flags, SCEV::FlagNSW)) {
4495 // Let M be the minimum representable signed value. Then (-1)*RHS
4496 // signed-wraps if and only if RHS is M. That can happen even for
4497 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4498 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4499 // (-1)*RHS, we need to prove that RHS != M.
4501 // If LHS is non-negative and we know that LHS - RHS does not
4502 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4503 // either by proving that RHS > M or that LHS >= 0.
4504 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
      AddFlags = SCEV::FlagNSW;
    }
  }
4509 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4510 // RHS is NSW and LHS >= 0.
4512 // The difficulty here is that the NSW flag may have been proven
4513 // relative to a loop that is to be found in a recurrence in LHS and
4514 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4515 // larger scope than intended.
4516 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4518 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4521 const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4523 Type *SrcTy = V->getType();
4524 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4525 "Cannot truncate or zero extend with non-integer arguments!");
4526 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4527 return V; // No conversion
4528 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4529 return getTruncateExpr(V, Ty, Depth);
4530 return getZeroExtendExpr(V, Ty, Depth);
4533 const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4535 Type *SrcTy = V->getType();
4536 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4537 "Cannot truncate or zero extend with non-integer arguments!");
4538 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4539 return V; // No conversion
4540 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4541 return getTruncateExpr(V, Ty, Depth);
4542 return getSignExtendExpr(V, Ty, Depth);
const SCEV *
ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4547 Type *SrcTy = V->getType();
4548 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4549 "Cannot noop or zero extend with non-integer arguments!");
4550 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4551 "getNoopOrZeroExtend cannot truncate!");
4552 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4553 return V; // No conversion
4554 return getZeroExtendExpr(V, Ty);
const SCEV *
ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4559 Type *SrcTy = V->getType();
4560 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4561 "Cannot noop or sign extend with non-integer arguments!");
4562 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4563 "getNoopOrSignExtend cannot truncate!");
4564 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4565 return V; // No conversion
4566 return getSignExtendExpr(V, Ty);
const SCEV *
ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4571 Type *SrcTy = V->getType();
4572 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4573 "Cannot noop or any extend with non-integer arguments!");
4574 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4575 "getNoopOrAnyExtend cannot truncate!");
4576 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4577 return V; // No conversion
4578 return getAnyExtendExpr(V, Ty);
const SCEV *
ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4583 Type *SrcTy = V->getType();
4584 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4585 "Cannot truncate or noop with non-integer arguments!");
4586 assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
4587 "getTruncateOrNoop cannot extend!");
4588 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4589 return V; // No conversion
4590 return getTruncateExpr(V, Ty);
const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
                                                        const SCEV *RHS) {
4595 const SCEV *PromotedLHS = LHS;
4596 const SCEV *PromotedRHS = RHS;
4598 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4599 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
  else
    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());

  return getUMaxExpr(PromotedLHS, PromotedRHS);
}
const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
                                                        const SCEV *RHS,
                                                        bool Sequential) {
4609 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
  return getUMinFromMismatchedTypes(Ops, Sequential);
}
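/// Promote each operand to the widest of the operand types, then form the
/// (optionally sequential) umin; e.g. umin(i32 %a, i64 %b) becomes
/// umin(zext %a to i64, %b).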
const SCEV *
ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
                                            bool Sequential) {
4616 assert(!Ops.empty() && "At least one operand must be!");
  if (Ops.size() == 1)
    return Ops[0];
4621 // Find the max type first.
  Type *MaxType = nullptr;
  for (const auto *S : Ops)
    if (MaxType)
      MaxType = getWiderType(MaxType, S->getType());
    else
      MaxType = S->getType();
4628 assert(MaxType && "Failed to find maximum type!");
4630 // Extend all ops to max type.
4631 SmallVector<const SCEV *, 2> PromotedOps;
  for (const auto *S : Ops)
    PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
  return getUMinExpr(PromotedOps, Sequential);
}
4639 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4640 // A pointer operand may evaluate to a nonpointer expression, such as null.
  if (!V->getType()->isPointerTy())
    return V;

  while (true) {
    if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4646 V = AddRec->getStart();
4647 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4648 const SCEV *PtrOp = nullptr;
4649 for (const SCEV *AddOp : Add->operands()) {
4650 if (AddOp->getType()->isPointerTy()) {
          assert(!PtrOp && "Cannot have multiple pointer ops");
          PtrOp = AddOp;
        }
      }
      assert(PtrOp && "Must have pointer op");
      V = PtrOp;
    } else // Not something we can look further into.
      return V;
  }
}
4662 /// Push users of the given Instruction onto the given Worklist.
4663 static void PushDefUseChildren(Instruction *I,
4664 SmallVectorImpl<Instruction *> &Worklist,
4665 SmallPtrSetImpl<Instruction *> &Visited) {
4666 // Push the def-use children onto the Worklist stack.
4667 for (User *U : I->users()) {
4668 auto *UserInsn = cast<Instruction>(U);
4669 if (Visited.insert(UserInsn).second)
4670 Worklist.push_back(UserInsn);
4676 /// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4677 /// expression in case its Loop is L. If it is not L then
4678 /// if IgnoreOtherLoops is true then use AddRec itself
4679 /// otherwise rewrite cannot be done.
4680 /// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
4681 class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4683 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4684 bool IgnoreOtherLoops = true) {
4685 SCEVInitRewriter Rewriter(L, SE);
4686 const SCEV *Result = Rewriter.visit(S);
4687 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4688 return SE.getCouldNotCompute();
4689 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
               ? SE.getCouldNotCompute()
               : Result;
  }
4694 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4695 if (!SE.isLoopInvariant(Expr, L))
      SeenLoopVariantSCEVUnknown = true;
    return Expr;
  }
4700 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4701 // Only re-write AddRecExprs for this loop.
4702 if (Expr->getLoop() == L)
4703 return Expr->getStart();
    SeenOtherLoops = true;
    return Expr;
  }
4708 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
  bool hasSeenOtherLoops() { return SeenOtherLoops; }

private:
4713 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
      : SCEVRewriteVisitor(SE), L(L) {}

  const Loop *L;
4717 bool SeenLoopVariantSCEVUnknown = false;
  bool SeenOtherLoops = false;
};
4721 /// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
4722 /// increment expression in case its Loop is L. If it is not L then
4723 /// use AddRec itself.
4724 /// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
4725 class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4727 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4728 SCEVPostIncRewriter Rewriter(L, SE);
4729 const SCEV *Result = Rewriter.visit(S);
4730 return Rewriter.hasSeenLoopVariantSCEVUnknown()
               ? SE.getCouldNotCompute()
               : Result;
  }
4735 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4736 if (!SE.isLoopInvariant(Expr, L))
      SeenLoopVariantSCEVUnknown = true;
    return Expr;
  }
4741 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4742 // Only re-write AddRecExprs for this loop.
4743 if (Expr->getLoop() == L)
4744 return Expr->getPostIncExpr(SE);
    SeenOtherLoops = true;
    return Expr;
  }
4749 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
  bool hasSeenOtherLoops() { return SeenOtherLoops; }

private:
4754 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
      : SCEVRewriteVisitor(SE), L(L) {}

  const Loop *L;
4758 bool SeenLoopVariantSCEVUnknown = false;
  bool SeenOtherLoops = false;
};
4762 /// This class evaluates the compare condition by matching it against the
4763 /// condition of loop latch. If there is a match we assume a true value
4764 /// for the condition while building SCEV nodes.
4765 class SCEVBackedgeConditionFolder
4766 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4768 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4769 ScalarEvolution &SE) {
4770 bool IsPosBECond = false;
4771 Value *BECond = nullptr;
4772 if (BasicBlock *Latch = L->getLoopLatch()) {
4773 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4774 if (BI && BI->isConditional()) {
4775 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4776 "Both outgoing branches should not target same header!");
4777 BECond = BI->getCondition();
        IsPosBECond = BI->getSuccessor(0) == L->getHeader();
      }
    }
4783 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
    return Rewriter.visit(S);
  }
4787 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4788 const SCEV *Result = Expr;
    bool InvariantF = SE.isLoopInvariant(Expr, L);
    if (!InvariantF) {
      Instruction *I = cast<Instruction>(Expr->getValue());
4793 switch (I->getOpcode()) {
4794 case Instruction::Select: {
4795 SelectInst *SI = cast<SelectInst>(I);
4796 Optional<const SCEV *> Res =
4797 compareWithBackedgeCondition(SI->getCondition());
4798 if (Res.hasValue()) {
4799 bool IsOne = cast<SCEVConstant>(Res.getValue())->getValue()->isOne();
          Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
        }
        break;
      }
      default: {
        Optional<const SCEV *> Res = compareWithBackedgeCondition(I);
        if (Res.hasValue())
          Result = Res.getValue();
        break;
      }
      }
    }
    return Result;
  }

private:
4816 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
4817 bool IsPosBECond, ScalarEvolution &SE)
4818 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
4819 IsPositiveBECond(IsPosBECond) {}
  Optional<const SCEV *> compareWithBackedgeCondition(Value *IC);

  const Loop *L;
4824 /// Loop back condition.
4825 Value *BackedgeCond = nullptr;
4826 /// Set to true if loop back is on positive branch condition.
  bool IsPositiveBECond;
};
4830 Optional<const SCEV *>
4831 SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
4833 // If value matches the backedge condition for loop latch,
  // then return a constant evolution node based on loopback
  // branch condition.
4836 if (BackedgeCond == IC)
4837 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
                            : SE.getZero(Type::getInt1Ty(SE.getContext()));
  return None;
}
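/// Rewrite affine AddRecs for loop L back by one iteration, i.e.
/// {A,+,B}<L> becomes {A-B,+,B}<L>; any loop-variant SCEVUnknown marks the
/// rewrite invalid.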
4842 class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
4844 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4845 ScalarEvolution &SE) {
4846 SCEVShiftRewriter Rewriter(L, SE);
4847 const SCEV *Result = Rewriter.visit(S);
4848 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
4851 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4852 // Only allow AddRecExprs for this loop.
    if (!SE.isLoopInvariant(Expr, L))
      Valid = false;
    return Expr;
  }
4858 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4859 if (Expr->getLoop() == L && Expr->isAffine())
      return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
    Valid = false;
    return Expr;
  }
4865 bool isValid() { return Valid; }
4868 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
      : SCEVRewriteVisitor(SE), L(L) {}

  const Loop *L;
  bool Valid = true;
};
4875 } // end anonymous namespace
SCEV::NoWrapFlags
ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
4879 if (!AR->isAffine())
4880 return SCEV::FlagAnyWrap;
4882 using OBO = OverflowingBinaryOperator;
4884 SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
4886 if (!AR->hasNoSignedWrap()) {
4887 ConstantRange AddRecRange = getSignedRange(AR);
4888 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
4890 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
4891 Instruction::Add, IncRange, OBO::NoSignedWrap);
4892 if (NSWRegion.contains(AddRecRange))
4893 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
4896 if (!AR->hasNoUnsignedWrap()) {
4897 ConstantRange AddRecRange = getUnsignedRange(AR);
4898 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
4900 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
4901 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
4902 if (NUWRegion.contains(AddRecRange))
4903 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
4910 ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
4911 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
4913 if (AR->hasNoSignedWrap())
4916 if (!AR->isAffine())
4919 const SCEV *Step = AR->getStepRecurrence(*this);
4920 const Loop *L = AR->getLoop();
4922 // Check whether the backedge-taken count is SCEVCouldNotCompute.
4923 // Note that this serves two purposes: It filters out loops that are
4924 // simply not analyzable, and it covers the case where this code is
4925 // being called from within backedge-taken count analysis, such that
4926 // attempting to ask for the backedge-taken count would likely result
4927 // in infinite recursion. In the latter case, the analysis code will
4928 // cope with a conservative value, and it will take care to purge
4929 // that value once it has finished.
4930 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
4932 // Normally, in the cases we can prove no-overflow via a
4933 // backedge guarding condition, we can also compute a backedge
4934 // taken count for the loop. The exceptions are assumptions and
4935 // guards present in the loop -- SCEV is not great at exploiting
4936 // these to compute max backedge taken counts, but can still use
4937 // these to prove lack of overflow. Use this fact to avoid
4938 // doing extra work that may not pay off.
4940 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
4941 AC.assumptions().empty())
4944 // If the backedge is guarded by a comparison with the pre-inc value the
4945 // addrec is safe. Also, if the entry is guarded by a comparison with the
4946 // start value and the backedge is guarded by a comparison with the post-inc
4947 // value, the addrec is safe.
4948 ICmpInst::Predicate Pred;
4949 const SCEV *OverflowLimit =
4950 getSignedOverflowLimitForStep(Step, &Pred, this);
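// For example (hypothetical i8 addrec {%start,+,1}): the limit is
// SINT_MIN - 1, which wraps to 127, with Pred == ICMP_SLT; a latch guard
// of the form "iv <s 127" then shows iv + 1 never exceeds 127, proving NSW.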
4951 if (OverflowLimit &&
4952 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
4953 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
4954 Result = setFlags(Result, SCEV::FlagNSW);
4959 ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
4960 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
4962 if (AR->hasNoUnsignedWrap())
4965 if (!AR->isAffine())
4968 const SCEV *Step = AR->getStepRecurrence(*this);
4969 unsigned BitWidth = getTypeSizeInBits(AR->getType());
4970 const Loop *L = AR->getLoop();
4972 // Check whether the backedge-taken count is SCEVCouldNotCompute.
4973 // Note that this serves two purposes: It filters out loops that are
4974 // simply not analyzable, and it covers the case where this code is
4975 // being called from within backedge-taken count analysis, such that
4976 // attempting to ask for the backedge-taken count would likely result
4977 // in infinite recursion. In the latter case, the analysis code will
4978 // cope with a conservative value, and it will take care to purge
4979 // that value once it has finished.
4980 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
4982 // Normally, in the cases we can prove no-overflow via a
4983 // backedge guarding condition, we can also compute a backedge
4984 // taken count for the loop. The exceptions are assumptions and
4985 // guards present in the loop -- SCEV is not great at exploiting
4986 // these to compute max backedge taken counts, but can still use
4987 // these to prove lack of overflow. Use this fact to avoid
4988 // doing extra work that may not pay off.
4990 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
4991 AC.assumptions().empty())
4994 // If the backedge is guarded by a comparison with the pre-inc value the
4995 // addrec is safe. Also, if the entry is guarded by a comparison with the
4996 // start value and the backedge is guarded by a comparison with the post-inc
4997 // value, the addrec is safe.
4998 if (isKnownPositive(Step)) {
4999 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5000 getUnsignedRangeMax(Step));
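// For example (hypothetical i8 addrec with step 1): N == 0 - 1 == 255, so
// a latch guard "iv <u 255" bounds iv + 1 by 255 and rules out unsigned
// wrap, proving NUW.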
5001 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5002 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5003 Result = setFlags(Result, SCEV::FlagNUW);
5012 /// Represents an abstract binary operation. This may exist as a
5013 /// normal instruction or constant expression, or may have been
5014 /// derived from an expression tree.
5022 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5023 /// constant expression.
5024 Operator *Op = nullptr;
5026 explicit BinaryOp(Operator *Op)
5027 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5029 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5030 IsNSW = OBO->hasNoSignedWrap();
5031 IsNUW = OBO->hasNoUnsignedWrap();
5035 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5037 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5040 } // end anonymous namespace
5042 /// Try to map \p V into a BinaryOp, and return \c None on failure.
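/// For illustration (hypothetical IR): "%a = add nsw i32 %x, %y" maps to
/// BinaryOp{Add, %x, %y, IsNSW=true, IsNUW=false}; values that are not
/// binary operations yield None.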
5043 static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
5044 auto *Op = dyn_cast<Operator>(V);
5048 // Implementation detail: all the cleverness here should happen without
5049 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5050 // SCEV expressions when possible, and we should not break that.
5052 switch (Op->getOpcode()) {
5053 case Instruction::Add:
5054 case Instruction::Sub:
5055 case Instruction::Mul:
5056 case Instruction::UDiv:
5057 case Instruction::URem:
5058 case Instruction::And:
5059 case Instruction::Or:
5060 case Instruction::AShr:
5061 case Instruction::Shl:
5062 return BinaryOp(Op);
5064 case Instruction::Xor:
5065 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5066 // If the RHS of the xor is a signmask, then this is just an add.
5067 // Instcombine turns add of signmask into xor as a strength reduction step.
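// e.g., in i8: "xor %x, -128" equals "add %x, -128", because adding the
// sign mask only flips the sign bit (the carry out is discarded).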
5068 if (RHSC->getValue().isSignMask())
5069 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5070 return BinaryOp(Op);
5072 case Instruction::LShr:
5073 // Turn logical shift right of a constant into an unsigned divide.
5074 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5075 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5077 // If the shift count is not less than the bitwidth, the result of
5078 // the shift is undefined. Don't try to analyze it, because the
5079 // resolution chosen here may differ from the resolution chosen in
5080 // other parts of the compiler.
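// e.g., "lshr i32 %x, 3" is rewritten here as "udiv i32 %x, 8".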
5081 if (SA->getValue().ult(BitWidth)) {
5083 ConstantInt::get(SA->getContext(),
5084 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5085 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5088 return BinaryOp(Op);
5090 case Instruction::ExtractValue: {
5091 auto *EVI = cast<ExtractValueInst>(Op);
5092 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5095 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5099 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5100 bool Signed = WO->isSigned();
5101 // TODO: Should add nuw/nsw flags for mul as well.
5102 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5103 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5105 // Now that we know that all uses of the arithmetic-result component of
5106 // WO are guarded by the overflow check, we can go ahead and pretend
5107 // that the arithmetic is non-overflowing.
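// For illustration (hypothetical IR): for
//   %s = extractvalue { i32, i1 } %sadd, 0
// where %sadd = sadd.with.overflow(%a, %b) and all uses of %s are guarded
// by the overflow check, %s is treated as "add nsw i32 %a, %b".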
5108 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5109 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5116 // Recognise the intrinsic loop.decrement.reg; as this has exactly the same
5117 // semantics as a Sub, return a binary sub expression.
5118 if (auto *II = dyn_cast<IntrinsicInst>(V))
5119 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5120 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5125 /// Helper function to createAddRecFromPHIWithCasts. We have a phi
5126 /// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5127 /// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5128 /// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5129 /// follows one of the following patterns:
5130 /// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5131 /// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5132 /// If the SCEV expression of \p Op conforms with one of the expected patterns
5133 /// we return the type of the truncation operation, and indicate whether the
5134 /// truncated type should be treated as signed/unsigned by setting
5135 /// \p Signed to true/false, respectively.
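/// For example (hypothetical values): with an i64 %SymbolicPHI and
/// Op == (sext i32 (trunc i64 %SymbolicPHI to i32) to i64), this returns
/// the type i32 and sets \p Signed to true.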
5136 static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5137 bool &Signed, ScalarEvolution &SE) {
5138 // The case where Op == SymbolicPHI (that is, with no type conversions on
5139 // the way) is handled by the regular add recurrence creating logic and
5140 // would have already been triggered in createAddRecFromPHI. Reaching it here
5141 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5142 // because one of the other operands of the SCEVAddExpr updating this PHI is
5143 // not invariant).
5145 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5146 // this case predicates that allow us to prove that Op == SymbolicPHI will
5147 // be added.
5148 if (Op == SymbolicPHI)
5151 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5152 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5153 if (SourceBits != NewBits)
5156 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5157 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5160 const SCEVTruncateExpr *Trunc =
5161 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5162 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5165 const SCEV *X = Trunc->getOperand();
5166 if (X != SymbolicPHI)
5168 Signed = SExt != nullptr;
5169 return Trunc->getType();
5172 static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5173 if (!PN->getType()->isIntegerTy())
5175 const Loop *L = LI.getLoopFor(PN->getParent());
5176 if (!L || L->getHeader() != PN->getParent())
5181 // Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5182 // computation that updates the phi follows the following pattern:
5183 // (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5184 // which corresponds to a phi->trunc->sext/zext->add->phi update chain.
5185 // If so, try to see if it can be rewritten as an AddRecExpr under some
5186 // Predicates. If successful, return them as a pair. Also cache the results
5187 // of the analysis.
5189 // Example usage scenario:
5190 // Say the Rewriter is called for the following SCEV:
5191 // 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5193 // %X = phi i64 (%Start, %BEValue)
5194 // It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5195 // and call this function with %SymbolicPHI = %X.
5197 // The analysis will find that the value coming around the backedge has
5198 // the following SCEV:
5199 // BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5200 // Upon concluding that this matches the desired pattern, the function
5201 // will return the pair {NewAddRec, SmallPredsVec} where:
5202 // NewAddRec = {%Start,+,%Step}
5203 // SmallPredsVec = {P1, P2, P3} as follows:
5204 // P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5205 // P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5206 // P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5207 // The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5208 // under the predicates {P1,P2,P3}.
5209 // This predicated rewrite will be cached in PredicatedSCEVRewrites:
5210 // PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
5214 // 1) Extend the Induction descriptor to also support inductions that involve
5215 // casts: When needed (namely, when we are called in the context of the
5216 // vectorizer induction analysis), a Set of cast instructions will be
5217 // populated by this method, and provided back to isInductionPHI. This is
5218 // needed to allow the vectorizer to properly record them to be ignored by
5219 // the cost model and to avoid vectorizing them (otherwise these casts,
5220 // which are redundant under the runtime overflow checks, will be
5221 // vectorized, which can be costly).
5223 // 2) Support additional induction/PHISCEV patterns: We also want to support
5224 // inductions where the sext-trunc / zext-trunc operations (partly) occur
5225 // after the induction update operation (the induction increment):
5227 // (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5228 // which corresponds to a phi->add->trunc->sext/zext->phi update chain.
5230 // (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5231 // which corresponds to a phi->trunc->add->sext/zext->phi update chain.
5233 // 3) Outline common code with createAddRecFromPHI to avoid duplication.
5234 Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5235 ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5236 SmallVector<const SCEVPredicate *, 3> Predicates;
5238 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5239 // return an AddRec expression under some predicate.
5241 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5242 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5243 assert(L && "Expecting an integer loop header phi");
5245 // The loop may have multiple entrances or multiple exits; we can analyze
5246 // this phi as an addrec if it has a unique entry value and a unique
5247 // backedge value.
5248 Value *BEValueV = nullptr, *StartValueV = nullptr;
5249 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5250 Value *V = PN->getIncomingValue(i);
5251 if (L->contains(PN->getIncomingBlock(i))) {
5254 } else if (BEValueV != V) {
5258 } else if (!StartValueV) {
5260 } else if (StartValueV != V) {
5261 StartValueV = nullptr;
5265 if (!BEValueV || !StartValueV)
5268 const SCEV *BEValue = getSCEV(BEValueV);
5270 // If the value coming around the backedge is an add with the symbolic
5271 // value we just inserted, possibly with casts that we can ignore under
5272 // an appropriate runtime guard, then we found a simple induction variable!
5273 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5277 // If there is a single occurrence of the symbolic value, possibly
5278 // casted, replace it with a recurrence.
5279 unsigned FoundIndex = Add->getNumOperands();
5280 Type *TruncTy = nullptr;
5282 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5283 if ((TruncTy =
5284 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5285 if (FoundIndex == e) {
5290 if (FoundIndex == Add->getNumOperands())
5293 // Create an add with everything but the specified operand.
5294 SmallVector<const SCEV *, 8> Ops;
5295 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5296 if (i != FoundIndex)
5297 Ops.push_back(Add->getOperand(i));
5298 const SCEV *Accum = getAddExpr(Ops);
5300 // The runtime checks will not be valid if the step amount is
5301 // varying inside the loop.
5302 if (!isLoopInvariant(Accum, L))
5305 // *** Part2: Create the predicates
5307 // Analysis was successful: we have a phi-with-cast pattern for which we
5308 // can return an AddRec expression under the following predicates:
5310 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5311 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5312 // P2: An Equal predicate that guarantees that
5313 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5314 // P3: An Equal predicate that guarantees that
5315 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5317 // As we next prove, the above predicates guarantee that:
5318 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5321 // More formally, we want to prove that:
5322 // Expr(i+1) = Start + (i+1) * Accum
5323 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5326 // 1) Expr(0) = Start
5327 // 2) Expr(1) = Start + Accum
5328 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5329 // 3) Induction hypothesis (step i):
5330 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5334 // = Start + (i+1)*Accum
5335 // = (Start + i*Accum) + Accum
5336 // = Expr(i) + Accum
5337 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5340 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5342 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5343 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5344 // + Accum :: from P3
5346 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5347 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5349 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5350 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5352 // By induction, the same applies to all iterations 1<=i<n:
5355 // Create a truncated addrec for which we will add a no overflow check (P1).
5356 const SCEV *StartVal = getSCEV(StartValueV);
5357 const SCEV *PHISCEV =
5358 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5359 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5361 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5362 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5363 // will be constant.
5365 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5366 // need P1.
5367 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5368 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5369 Signed ? SCEVWrapPredicate::IncrementNSSW
5370 : SCEVWrapPredicate::IncrementNUSW;
5371 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5372 Predicates.push_back(AddRecPred);
5375 // Create the Equal Predicates P2,P3:
5377 // It is possible that the predicates P2 and/or P3 are computable at
5378 // compile time due to StartVal and/or Accum being constants.
5379 // If either one is, then we can check that now and escape if either P2
5380 // or P3 is false.
5382 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5383 // for each of StartVal and Accum
5384 auto getExtendedExpr = [&](const SCEV *Expr,
5385 bool CreateSignExtend) -> const SCEV * {
5386 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5387 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5388 const SCEV *ExtendedExpr =
5389 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5390 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5391 return ExtendedExpr;
5395 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy)
5396 //              = getExtendedExpr(Expr)
5397 // Determine whether the predicate P: Expr == ExtendedExpr
5398 // is known to be false at compile time.
5399 auto PredIsKnownFalse = [&](const SCEV *Expr,
5400 const SCEV *ExtendedExpr) -> bool {
5401 return Expr != ExtendedExpr &&
5402 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5405 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5406 if (PredIsKnownFalse(StartVal, StartExtended)) {
5407 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5411 // The Step is always Signed (because the overflow checks are either
5412 // NSSW or NUSW).
5413 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5414 if (PredIsKnownFalse(Accum, AccumExtended)) {
5415 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5419 auto AppendPredicate = [&](const SCEV *Expr,
5420 const SCEV *ExtendedExpr) -> void {
5421 if (Expr != ExtendedExpr &&
5422 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5423 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5424 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5425 Predicates.push_back(Pred);
5429 AppendPredicate(StartVal, StartExtended);
5430 AppendPredicate(Accum, AccumExtended);
5432 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5433 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5434 // into NewAR if it will also add the runtime overflow checks specified in
5435 // Predicates.
5436 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5438 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5439 std::make_pair(NewAR, Predicates);
5440 // Remember the result of the analysis for this SCEV at this location.
5441 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5445 Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5446 ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5447 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5448 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5452 // Check to see if we already analyzed this PHI.
5453 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5454 if (I != PredicatedSCEVRewrites.end()) {
5455 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5457 // Analysis was done before and failed to create an AddRec:
5458 if (Rewrite.first == SymbolicPHI)
5460 // Analysis was done before and succeeded in creating an AddRec under
5461 // the predicates.
5462 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5463 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5467 Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5468 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5470 // Record in the cache that the analysis failed.
5471 if (!Rewrite) {
5472 SmallVector<const SCEVPredicate *, 3> Predicates;
5473 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5480 // FIXME: This utility is currently required because the Rewriter currently
5481 // does not rewrite this expression:
5482 // {0, +, (sext ix (trunc iy to ix) to iy)}
5483 // into {0, +, %step},
5484 // even when the following Equal predicate exists:
5485 // "%step == (sext ix (trunc iy to ix) to iy)".
5486 bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5487 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5491 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5492 if (Expr1 != Expr2 && !Preds.implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5493 !Preds.implies(SE.getEqualPredicate(Expr2, Expr1)))
5498 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5499 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5504 /// A helper function for createAddRecFromPHI to handle simple cases.
5506 /// This function tries to find an AddRec expression for the simplest (yet most
5507 /// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5508 /// If it fails, createAddRecFromPHI will use a more general, but slow,
5509 /// technique for finding the AddRec expression.
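/// For illustration (hypothetical IR): for
///   %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
///   %iv.next = add nuw nsw i32 %iv, 1
/// this returns the addrec {0,+,1}<nuw><nsw><%loop>.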
5510 const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5512 Value *StartValueV) {
5513 const Loop *L = LI.getLoopFor(PN->getParent());
5514 assert(L && L->getHeader() == PN->getParent());
5515 assert(BEValueV && StartValueV);
5517 auto BO = MatchBinaryOp(BEValueV, DT);
5521 if (BO->Opcode != Instruction::Add)
5524 const SCEV *Accum = nullptr;
5525 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5526 Accum = getSCEV(BO->RHS);
5527 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5528 Accum = getSCEV(BO->LHS);
5533 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5534 if (BO->IsNUW)
5535 Flags = setFlags(Flags, SCEV::FlagNUW);
5536 if (BO->IsNSW)
5537 Flags = setFlags(Flags, SCEV::FlagNSW);
5539 const SCEV *StartVal = getSCEV(StartValueV);
5540 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5541 insertValueToMap(PN, PHISCEV);
5543 // We can add Flags to the post-inc expression only if we
5544 // know that it is *undefined behavior* for BEValueV to
5545 // overflow.
5546 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5547 assert(isLoopInvariant(Accum, L) &&
5548 "Accum is defined outside L, but is not invariant?");
5549 if (isAddRecNeverPoison(BEInst, L))
5550 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5556 const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5557 const Loop *L = LI.getLoopFor(PN->getParent());
5558 if (!L || L->getHeader() != PN->getParent())
5561 // The loop may have multiple entrances or multiple exits; we can analyze
5562 // this phi as an addrec if it has a unique entry value and a unique
5564 Value *BEValueV = nullptr, *StartValueV = nullptr;
5565 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5566 Value *V = PN->getIncomingValue(i);
5567 if (L->contains(PN->getIncomingBlock(i))) {
5570 } else if (BEValueV != V) {
5574 } else if (!StartValueV) {
5576 } else if (StartValueV != V) {
5577 StartValueV = nullptr;
5581 if (!BEValueV || !StartValueV)
5584 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5585 "PHI node already processed?");
5587 // First, try to find an AddRec expression without creating a fictitious
5588 // symbolic value.
5589 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5592 // Handle PHI node value symbolically.
5593 const SCEV *SymbolicName = getUnknown(PN);
5594 insertValueToMap(PN, SymbolicName);
5596 // Using this symbolic name for the PHI, analyze the value coming around
5597 // the backedge.
5598 const SCEV *BEValue = getSCEV(BEValueV);
5600 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5601 // has a special value for the first iteration of the loop.
5603 // If the value coming around the backedge is an add with the symbolic
5604 // value we just inserted, then we found a simple induction variable!
5605 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5606 // If there is a single occurrence of the symbolic value, replace it
5607 // with a recurrence.
5608 unsigned FoundIndex = Add->getNumOperands();
5609 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5610 if (Add->getOperand(i) == SymbolicName)
5611 if (FoundIndex == e) {
5616 if (FoundIndex != Add->getNumOperands()) {
5617 // Create an add with everything but the specified operand.
5618 SmallVector<const SCEV *, 8> Ops;
5619 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5620 if (i != FoundIndex)
5621 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5623 const SCEV *Accum = getAddExpr(Ops);
5625 // This is not a valid addrec if the step amount is varying each
5626 // loop iteration, but is not itself an addrec in this loop.
5627 if (isLoopInvariant(Accum, L) ||
5628 (isa<SCEVAddRecExpr>(Accum) &&
5629 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5630 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5632 if (auto BO = MatchBinaryOp(BEValueV, DT)) {
5633 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5634 if (BO->IsNUW)
5635 Flags = setFlags(Flags, SCEV::FlagNUW);
5636 if (BO->IsNSW)
5637 Flags = setFlags(Flags, SCEV::FlagNSW);
5639 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5640 // If the increment is an inbounds GEP, then we know the address
5641 // space cannot be wrapped around. We cannot make any guarantee
5642 // about signed or unsigned overflow because pointers are
5643 // unsigned but we may have a negative index from the base
5644 // pointer. We can guarantee that no unsigned wrap occurs if the
5645 // indices form a positive value.
5646 if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
5647 Flags = setFlags(Flags, SCEV::FlagNW);
5649 const SCEV *Ptr = getSCEV(GEP->getPointerOperand());
5650 if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr)))
5651 Flags = setFlags(Flags, SCEV::FlagNUW);
5654 // We cannot transfer nuw and nsw flags from subtraction
5655 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5659 const SCEV *StartVal = getSCEV(StartValueV);
5660 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5662 // Okay, for the entire analysis of this edge we assumed the PHI
5663 // to be symbolic. We now need to go back and purge all of the
5664 // entries for the scalars that use the symbolic expression.
5665 forgetMemoizedResults(SymbolicName);
5666 insertValueToMap(PN, PHISCEV);
5668 // We can add Flags to the post-inc expression only if we
5669 // know that it is *undefined behavior* for BEValueV to
5670 // overflow.
5671 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5672 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5673 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5679 // Otherwise, this could be a loop like this:
5680 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5681 // In this case, j = {1,+,1} and BEValue is j.
5682 // Because the other in-value of i (0) fits the evolution of BEValue,
5683 // i really is an addrec evolution.
5685 // We can generalize this by saying that i is the shifted value of BEValue
5686 // by one iteration:
5687 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5688 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5689 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5690 if (Shifted != getCouldNotCompute() &&
5691 Start != getCouldNotCompute()) {
5692 const SCEV *StartVal = getSCEV(StartValueV);
5693 if (Start == StartVal) {
5694 // Okay, for the entire analysis of this edge we assumed the PHI
5695 // to be symbolic. We now need to go back and purge all of the
5696 // entries for the scalars that use the symbolic expression.
5697 forgetMemoizedResults(SymbolicName);
5698 insertValueToMap(PN, Shifted);
5704 // Remove the temporary PHI node SCEV that has been inserted while intending
5705 // to create an AddRecExpr for this PHI node. We cannot keep this temporary
5706 // as it will prevent later (possibly simpler) SCEV expressions from being
5707 // added to the ValueExprMap.
5708 eraseValueFromMap(PN);
5713 // Checks if the SCEV S is available at BB. S is considered available at BB
5714 // if S can be materialized at BB without introducing a fault.
5715 static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
5717 struct CheckAvailable {
5718 bool TraversalDone = false;
5719 bool Available = true;
5721 const Loop *L = nullptr; // The loop BB is in (can be nullptr)
5722 BasicBlock *BB = nullptr;
5725 CheckAvailable(const Loop *L, BasicBlock *BB, DominatorTree &DT)
5726 : L(L), BB(BB), DT(DT) {}
5728 bool setUnavailable() {
5729 TraversalDone = true;
5734 bool follow(const SCEV *S) {
5735 switch (S->getSCEVType()) {
5747 case scSequentialUMinExpr:
5748 // These expressions are available if their operand(s) is/are.
5751 case scAddRecExpr: {
5752 // We allow add recurrences that are on the loop that BB is in, or some
5753 // outer loop. This guarantees availability because the value of the
5754 // add recurrence at BB is simply the "current" value of the induction
5755 // variable. We can relax this in the future; for instance an add
5756 // recurrence on a sibling dominating loop is also available at BB.
5757 const auto *ARLoop = cast<SCEVAddRecExpr>(S)->getLoop();
5758 if (L && (ARLoop == L || ARLoop->contains(L)))
5761 return setUnavailable();
5765 // For SCEVUnknown, we check for simple dominance.
5766 const auto *SU = cast<SCEVUnknown>(S);
5767 Value *V = SU->getValue();
5769 if (isa<Argument>(V))
5772 if (isa<Instruction>(V) && DT.dominates(cast<Instruction>(V), BB))
5775 return setUnavailable();
5779 case scCouldNotCompute:
5780 // We do not try to be smart about these at all.
5781 return setUnavailable();
5783 llvm_unreachable("Unknown SCEV kind!");
5786 bool isDone() { return TraversalDone; }
5789 CheckAvailable CA(L, BB, DT);
5790 SCEVTraversal<CheckAvailable> ST(CA);
5793 return CA.Available;
5796 // Try to match a control flow sequence that branches out at BI and merges back
5797 // at Merge into a "C ? LHS : RHS" select pattern. Return true on a
5798 // successful match.
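// For illustration (hypothetical IR):
//   entry: br i1 %c, label %left, label %right
//   left:  br label %merge
//   right: br label %merge
//   merge: %phi = phi i32 [ %x, %left ], [ %y, %right ]
// matches with C == %c, LHS == %x, RHS == %y, i.e. "select %c, %x, %y".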
5799 static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5800 Value *&C, Value *&LHS, Value *&RHS) {
5801 C = BI->getCondition();
5803 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5804 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5806 if (!LeftEdge.isSingleEdge())
5809 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5811 Use &LeftUse = Merge->getOperandUse(0);
5812 Use &RightUse = Merge->getOperandUse(1);
5814 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5820 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5829 const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
5830 auto IsReachable =
5831 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
5832 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
5833 const Loop *L = LI.getLoopFor(PN->getParent());
5835 // We don't want to break LCSSA, even in a SCEV expression tree.
5836 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
5837 if (LI.getLoopFor(PN->getIncomingBlock(i)) != L)
5842 // br %cond, label %left, label %right
5848 // V = phi [ %x, %left ], [ %y, %right ]
5850 // as "select %cond, %x, %y"
5852 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
5853 assert(IDom && "At least the entry block should dominate PN");
5855 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
5856 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
5858 if (BI && BI->isConditional() &&
5859 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
5860 IsAvailableOnEntry(L, DT, getSCEV(LHS), PN->getParent()) &&
5861 IsAvailableOnEntry(L, DT, getSCEV(RHS), PN->getParent()))
5862 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
5868 const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
5869 if (const SCEV *S = createAddRecFromPHI(PN))
5872 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
5875 // If the PHI has a single incoming value, follow that value, unless the
5876 // PHI's incoming blocks are in a different loop, in which case doing so
5877 // risks breaking LCSSA form. Instcombine would normally zap these, but
5878 // it doesn't have DominatorTree information, so it may miss cases.
5879 if (Value *V = SimplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC}))
5880 if (LI.replacementPreservesLCSSAForm(PN, V))
5883 // If it's not a loop phi, we can't handle it yet.
5884 return getUnknown(PN);
5887 const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I,
5891 // Handle "constant" branch or select. This can occur for instance when a
5892 // loop pass transforms an inner loop and moves on to process the outer loop.
5893 if (auto *CI = dyn_cast<ConstantInt>(Cond))
5894 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
5896 // Try to match some simple smax or umax patterns.
5897 auto *ICI = dyn_cast<ICmpInst>(Cond);
5899 return getUnknown(I);
5901 Value *LHS = ICI->getOperand(0);
5902 Value *RHS = ICI->getOperand(1);
5904 switch (ICI->getPredicate()) {
5905 case ICmpInst::ICMP_SLT:
5906 case ICmpInst::ICMP_SLE:
5907 case ICmpInst::ICMP_ULT:
5908 case ICmpInst::ICMP_ULE:
5909 std::swap(LHS, RHS);
5911 case ICmpInst::ICMP_SGT:
5912 case ICmpInst::ICMP_SGE:
5913 case ICmpInst::ICMP_UGT:
5914 case ICmpInst::ICMP_UGE:
5915 // a > b ? a+x : b+x -> max(a, b)+x
5916 // a > b ? b+x : a+x -> min(a, b)+x
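// For example (hypothetical operands): "%a >s %b ? %a+1 : %b+1" yields
// LDiff == RDiff == 1 below and becomes smax(%a, %b) + 1.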
5917 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
5918 bool Signed = ICI->isSigned();
5919 const SCEV *LA = getSCEV(TrueVal);
5920 const SCEV *RA = getSCEV(FalseVal);
5921 const SCEV *LS = getSCEV(LHS);
5922 const SCEV *RS = getSCEV(RHS);
5923 if (LA->getType()->isPointerTy()) {
5924 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
5925 // Need to make sure we can't produce weird expressions involving
5926 // negated pointers.
5927 if (LA == LS && RA == RS)
5928 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
5929 if (LA == RS && RA == LS)
5930 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
5932 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
5933 if (Op->getType()->isPointerTy()) {
5934 Op = getLosslessPtrToIntExpr(Op);
5935 if (isa<SCEVCouldNotCompute>(Op))
5939 Op = getNoopOrSignExtend(Op, I->getType());
5941 Op = getNoopOrZeroExtend(Op, I->getType());
5944 LS = CoerceOperand(LS);
5945 RS = CoerceOperand(RS);
5946 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
5948 const SCEV *LDiff = getMinusSCEV(LA, LS);
5949 const SCEV *RDiff = getMinusSCEV(RA, RS);
5950 if (LDiff == RDiff)
5951 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
5953 LDiff = getMinusSCEV(LA, RS);
5954 RDiff = getMinusSCEV(RA, LS);
5955 if (LDiff == RDiff)
5956 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
5960 case ICmpInst::ICMP_NE:
5961 // n != 0 ? n+x : 1+x -> umax(n, 1)+x
5962 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
5963 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
5964 const SCEV *One = getOne(I->getType());
5965 const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
5966 const SCEV *LA = getSCEV(TrueVal);
5967 const SCEV *RA = getSCEV(FalseVal);
5968 const SCEV *LDiff = getMinusSCEV(LA, LS);
5969 const SCEV *RDiff = getMinusSCEV(RA, One);
5970 if (LDiff == RDiff)
5971 return getAddExpr(getUMaxExpr(One, LS), LDiff);
5974 case ICmpInst::ICMP_EQ:
5975 // n == 0 ? 1+x : n+x -> umax(n, 1)+x
5976 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
5977 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
5978 const SCEV *One = getOne(I->getType());
5979 const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
5980 const SCEV *LA = getSCEV(TrueVal);
5981 const SCEV *RA = getSCEV(FalseVal);
5982 const SCEV *LDiff = getMinusSCEV(LA, One);
5983 const SCEV *RDiff = getMinusSCEV(RA, LS);
5984 if (LDiff == RDiff)
5985 return getAddExpr(getUMaxExpr(One, LS), LDiff);
5992 return getUnknown(I);
5995 /// Expand GEP instructions into add and multiply operations. This allows them
5996 /// to be analyzed by regular SCEV code.
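/// For illustration (hypothetical IR, assuming a 4-byte i32): the SCEV for
///   getelementptr i32, ptr %p, i64 %i
/// is built roughly as (%p + 4 * %i), i.e. from an add and a multiply.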
5997 const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
5998 // Don't attempt to analyze GEPs over unsized objects.
5999 if (!GEP->getSourceElementType()->isSized())
6000 return getUnknown(GEP);
6002 SmallVector<const SCEV *, 4> IndexExprs;
6003 for (Value *Index : GEP->indices())
6004 IndexExprs.push_back(getSCEV(Index));
6005 return getGEPExpr(GEP, IndexExprs);
6008 uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
6009 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6010 return C->getAPInt().countTrailingZeros();
6012 if (const SCEVPtrToIntExpr *I = dyn_cast<SCEVPtrToIntExpr>(S))
6013 return GetMinTrailingZeros(I->getOperand());
6015 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
6016 return std::min(GetMinTrailingZeros(T->getOperand()),
6017 (uint32_t)getTypeSizeInBits(T->getType()));
6019 if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
6020 uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
6021 return OpRes == getTypeSizeInBits(E->getOperand()->getType())
6022 ? getTypeSizeInBits(E->getType())
6026 if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
6027 uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
6028 return OpRes == getTypeSizeInBits(E->getOperand()->getType())
6029 ? getTypeSizeInBits(E->getType())
6033 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
6034 // The result is the min of all operands' results.
6035 uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
6036 for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
6037 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
6041 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
6042 // The result is the sum of all operands' results.
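// e.g., (4 * %a) * (2 * %b): the factors have at least 2 and 1 trailing
// zero bits, so the product has at least 3 (capped at the bit width).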
6043 uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
6044 uint32_t BitWidth = getTypeSizeInBits(M->getType());
6045 for (unsigned i = 1, e = M->getNumOperands();
6046 SumOpRes != BitWidth && i != e; ++i)
6048 std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth);
6052 if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
6053 // The result is the min of all operands' results.
6054 uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
6055 for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
6056 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
6060 if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
6061 // The result is the min of all operands' results.
6062 uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
6063 for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
6064 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
6068 if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
6069 // The result is the min of all operands' results.
6070 uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
6071 for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
6072 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
6076 if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
6077 // For a SCEVUnknown, ask ValueTracking.
6078 KnownBits Known = computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT);
6079 return Known.countMinTrailingZeros();
6086 uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
6087 auto I = MinTrailingZerosCache.find(S);
6088 if (I != MinTrailingZerosCache.end())
6089 return I->second;
6091 uint32_t Result = GetMinTrailingZerosImpl(S);
6092 auto InsertPair = MinTrailingZerosCache.insert({S, Result});
6093 assert(InsertPair.second && "Should insert a new key");
6094 return InsertPair.first->second;
6097 /// Helper method to assign a range to V from metadata present in the IR.
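/// For illustration (hypothetical IR): a load such as
///   %v = load i32, ptr %p, !range !0   with   !0 = !{i32 0, i32 10}
/// yields the ConstantRange [0, 10) here.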
6098 static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6099 if (Instruction *I = dyn_cast<Instruction>(V))
6100 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6101 return getConstantRangeFromMetadata(*MD);
6106 void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6107 SCEV::NoWrapFlags Flags) {
6108 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6109 AddRec->setNoWrapFlags(Flags);
6110 UnsignedRanges.erase(AddRec);
6111 SignedRanges.erase(AddRec);
6115 ConstantRange ScalarEvolution::
6116 getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6117 const DataLayout &DL = getDataLayout();
6119 unsigned BitWidth = getTypeSizeInBits(U->getType());
6120 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6122 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6123 // use information about the trip count to improve our available range. Note
6124 // that the trip count independent cases are already handled by known bits.
6125 // WARNING: The definition of recurrence used here is subtly different from
6126 // the one used by AddRec (and thus most of this file). Step is allowed to
6127 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6128 // and other addrecs in the same loop (for non-affine addrecs). The code
6129 // below intentionally handles the case where step is not loop invariant.
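// For illustration (hypothetical IR): this handles a shift recurrence like
//   %iv = phi i32 [ %start, %preheader ], [ %iv.next, %loop ]
//   %iv.next = lshr i32 %iv, 1
// which is not an AddRec, so the AddRec-based range logic never sees it.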
6130 auto *P = dyn_cast<PHINode>(U->getValue());
6134 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6135 // even the values that are not available in these blocks may come from them,
6136 // and this leads to a false-positive recurrence test.
6137 for (auto *Pred : predecessors(P->getParent()))
6138 if (!DT.isReachableFromEntry(Pred))
6141 BinaryOperator *BO;
6142 Value *Start, *Step;
6143 if (!matchSimpleRecurrence(P, BO, Start, Step))
6146 // If we found a recurrence in reachable code, we must be in a loop. Note
6147 // that BO might be in some subloop of L, and that's completely okay.
6148 auto *L = LI.getLoopFor(P->getParent());
6149 assert(L && L->getHeader() == P->getParent());
6150 if (!L->contains(BO->getParent()))
6151 // NOTE: This bailout should be an assert instead. However, asserting
6152 // the condition here exposes a case where LoopFusion is querying SCEV
6153 // with malformed loop information in the midst of the transform.
6154 // There doesn't appear to be an obvious fix, so for the moment bail out
6155 // until the issue in the caller can be fixed. PR49566 tracks the bug.
6158 // TODO: Extend to other opcodes such as mul and div
6159 switch (BO->getOpcode()) {
6162 case Instruction::AShr:
6163 case Instruction::LShr:
6164 case Instruction::Shl:
6168 if (BO->getOperand(0) != P)
6169 // TODO: Handle the power function forms some day.
6172 unsigned TC = getSmallConstantMaxTripCount(L);
6173 if (!TC || TC >= BitWidth)
6176 auto KnownStart = computeKnownBits(Start, DL, 0, &AC, nullptr, &DT);
6177 auto KnownStep = computeKnownBits(Step, DL, 0, &AC, nullptr, &DT);
6178 assert(KnownStart.getBitWidth() == BitWidth &&
6179 KnownStep.getBitWidth() == BitWidth);
6181 // Compute total shift amount, being careful of overflow and bitwidths.
6182 auto MaxShiftAmt = KnownStep.getMaxValue();
6183 APInt TCAP(BitWidth, TC-1);
6184 bool Overflow = false;
6185 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6189 switch (BO->getOpcode()) {
6191 llvm_unreachable("filtered out above");
6192 case Instruction::AShr: {
6193 // For each ashr, three cases:
6194 // shift = 0 => unchanged value
6195 // saturation => 0 or -1
6196 // other => a value closer to zero (of the same sign)
6197 // Thus, the end value is closer to zero than the start.
6198 auto KnownEnd = KnownBits::ashr(KnownStart,
6199 KnownBits::makeConstant(TotalShift));
6200 if (KnownStart.isNonNegative())
6201 // Analogous to lshr (simply not yet canonicalized)
6202 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6203 KnownStart.getMaxValue() + 1);
6204 if (KnownStart.isNegative())
6205 // End >=u Start && End <=s Start
6206 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6207 KnownEnd.getMaxValue() + 1);
6210 case Instruction::LShr: {
6211 // For each lshr, three cases:
6212 // shift = 0 => unchanged value
6213 // saturation => 0
6214 // other => a smaller positive number
6215 // Thus, the low end of the unsigned range is the last value produced.
6216 auto KnownEnd = KnownBits::lshr(KnownStart,
6217 KnownBits::makeConstant(TotalShift));
6218 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6219 KnownStart.getMaxValue() + 1);
6221 case Instruction::Shl: {
6222 // Iff no bits are shifted out, value increases on every shift.
6223 auto KnownEnd = KnownBits::shl(KnownStart,
6224 KnownBits::makeConstant(TotalShift));
6225 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6226 return ConstantRange(KnownStart.getMinValue(),
6227 KnownEnd.getMaxValue() + 1);
6234 /// Determine the range for a particular SCEV. If SignHint is
6235 /// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6236 /// with a "cleaner" unsigned (resp. signed) representation.
6237 const ConstantRange &
6238 ScalarEvolution::getRangeRef(const SCEV *S,
6239 ScalarEvolution::RangeSignHint SignHint) {
6240 DenseMap<const SCEV *, ConstantRange> &Cache =
6241 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6243 ConstantRange::PreferredRangeType RangeType =
6244 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED
6245 ? ConstantRange::Unsigned : ConstantRange::Signed;
6247 // See if we've computed this range already.
6248 DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6249 if (I != Cache.end())
6252 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6253 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6255 unsigned BitWidth = getTypeSizeInBits(S->getType());
6256 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6257 using OBO = OverflowingBinaryOperator;
6259 // If the value has known zeros, the maximum value will have those known
6260 // zeros as well.
6261 uint32_t TZ = GetMinTrailingZeros(S);
6263 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED)
6264 ConservativeResult =
6265 ConstantRange(APInt::getMinValue(BitWidth),
6266 APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
6268 ConservativeResult = ConstantRange(
6269 APInt::getSignedMinValue(BitWidth),
6270 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
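// e.g., with TZ == 2 in i8 (S is a multiple of 4), the unsigned hint gives
// [0, 253) (max 252 == 0b11111100) and the signed hint gives [-128, 125).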
6273 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
6274 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint);
6275 unsigned WrapType = OBO::AnyWrap;
6276 if (Add->hasNoSignedWrap())
6277 WrapType |= OBO::NoSignedWrap;
6278 if (Add->hasNoUnsignedWrap())
6279 WrapType |= OBO::NoUnsignedWrap;
6280 for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
6281 X = X.addWithNoWrap(getRangeRef(Add->getOperand(i), SignHint),
6282 WrapType, RangeType);
6283 return setRange(Add, SignHint,
6284 ConservativeResult.intersectWith(X, RangeType));
6287 if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
6288 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint);
6289 for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
6290 X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint));
6291 return setRange(Mul, SignHint,
6292 ConservativeResult.intersectWith(X, RangeType));
6295 if (isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) {
6297 switch (S->getSCEVType()) {
6299 ID = Intrinsic::umax;
6302 ID = Intrinsic::smax;
6305 case scSequentialUMinExpr:
6306 ID = Intrinsic::umin;
6309 ID = Intrinsic::smin;
6312 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6315 const auto *NAry = cast<SCEVNAryExpr>(S);
6316 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint);
6317 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6318 X = X.intrinsic(ID, {X, getRangeRef(NAry->getOperand(i), SignHint)});
6319 return setRange(S, SignHint,
6320 ConservativeResult.intersectWith(X, RangeType));
6323 if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
6324 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint);
6325 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint);
6326 return setRange(UDiv, SignHint,
6327 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6330 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
6331 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint);
6332 return setRange(ZExt, SignHint,
6333 ConservativeResult.intersectWith(X.zeroExtend(BitWidth),
6337 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
6338 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint);
6339 return setRange(SExt, SignHint,
6340 ConservativeResult.intersectWith(X.signExtend(BitWidth),
6344 if (const SCEVPtrToIntExpr *PtrToInt = dyn_cast<SCEVPtrToIntExpr>(S)) {
6345 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint);
6346 return setRange(PtrToInt, SignHint, X);
6349 if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
6350 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint);
6351 return setRange(Trunc, SignHint,
6352 ConservativeResult.intersectWith(X.truncate(BitWidth),
6356 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
6357 // If there's no unsigned wrap, the value will never be less than its
6358 // initial value.
6359 if (AddRec->hasNoUnsignedWrap()) {
6360 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6361 if (!UnsignedMinValue.isZero())
6362 ConservativeResult = ConservativeResult.intersectWith(
6363 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6366 // If there's no signed wrap, and all the operands except initial value have
6367 // the same sign or zero, the value won't ever be:
6368 // 1: smaller than the initial value if the operands are non-negative,
6369 // 2: bigger than the initial value if the operands are non-positive.
6370 // In both cases, the value cannot cross the signed min/max boundary.
6371 if (AddRec->hasNoSignedWrap()) {
6372 bool AllNonNeg = true;
6373 bool AllNonPos = true;
6374 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6375 if (!isKnownNonNegative(AddRec->getOperand(i)))
6377 if (!isKnownNonPositive(AddRec->getOperand(i)))
6381 ConservativeResult = ConservativeResult.intersectWith(
6382 ConstantRange::getNonEmpty(getSignedRangeMin(AddRec->getStart()),
6383 APInt::getSignedMinValue(BitWidth)),
6386 ConservativeResult = ConservativeResult.intersectWith(
6387 ConstantRange::getNonEmpty(
6388 APInt::getSignedMinValue(BitWidth),
6389 getSignedRangeMax(AddRec->getStart()) + 1),
6393 // TODO: non-affine addrec
6394 if (AddRec->isAffine()) {
6395 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6396 if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
6397 getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
6398 auto RangeFromAffine = getRangeForAffineAR(
6399 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
6401 ConservativeResult =
6402 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6404 auto RangeFromFactoring = getRangeViaFactoring(
6405 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
6407 ConservativeResult =
6408 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6411 // Now try symbolic BE count and more powerful methods.
6412 if (UseExpensiveRangeSharpening) {
6413 const SCEV *SymbolicMaxBECount =
6414 getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6415 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6416 getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
6417 AddRec->hasNoSelfWrap()) {
6418 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6419 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6420 ConservativeResult =
6421 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6426 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6429 if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
6431 // Check if the IR explicitly contains !range metadata.
6432 Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue());
6433 if (MDRange.hasValue())
6434 ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue(),
6437 // Use facts about recurrences in the underlying IR. Note that add
6438 // recurrences are AddRecExprs and thus don't hit this path. This
6439 // primarily handles shift recurrences.
6440 auto CR = getRangeForUnknownRecurrence(U);
6441 ConservativeResult = ConservativeResult.intersectWith(CR);
6443 // See if ValueTracking can give us a useful range.
6444 const DataLayout &DL = getDataLayout();
6445 KnownBits Known = computeKnownBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
6446 if (Known.getBitWidth() != BitWidth)
6447 Known = Known.zextOrTrunc(BitWidth);
6449 // ValueTracking may be able to compute a tighter result for the number of
6450 // sign bits than for the value of those sign bits.
6451 unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
6452 if (U->getType()->isPointerTy()) {
6453 // If the pointer size is larger than the index type size, this can cause
6454 // NS to be larger than BitWidth. So compensate for this.
6455 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6456 int ptrIdxDiff = ptrSize - BitWidth;
6457 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6462 // If we know any of the sign bits, we know all of the sign bits.
6463 if (!Known.Zero.getHiBits(NS).isZero())
6464 Known.Zero.setHighBits(NS);
6465 if (!Known.One.getHiBits(NS).isZero())
6466 Known.One.setHighBits(NS);
6469 if (Known.getMinValue() != Known.getMaxValue() + 1)
6470 ConservativeResult = ConservativeResult.intersectWith(
6471 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6474 ConservativeResult = ConservativeResult.intersectWith(
6475 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6476 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6479 // The range of a Phi is a subset of the union of the ranges of its inputs.
6480 if (const PHINode *Phi = dyn_cast<PHINode>(U->getValue())) {
6481 // Make sure that we do not run over cyclic Phis.
6482 if (PendingPhiRanges.insert(Phi).second) {
6483 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6484 for (auto &Op : Phi->operands()) {
6485 auto OpRange = getRangeRef(getSCEV(Op), SignHint);
6486 RangeFromOps = RangeFromOps.unionWith(OpRange);
6487 // No point in continuing if we already have a full set.
6488 if (RangeFromOps.isFullSet())
6489 break;
6490 }
6491 ConservativeResult =
6492 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6493 bool Erased = PendingPhiRanges.erase(Phi);
6494 assert(Erased && "Failed to erase Phi properly?");
6495 (void)Erased;
6496 }
6497 }
6499 return setRange(U, SignHint, std::move(ConservativeResult));
6502 return setRange(S, SignHint, std::move(ConservativeResult));
6505 // Given a StartRange, Step and MaxBECount for an expression compute a range of
6506 // values that the expression can take. Initially, the expression has a value
6507 // from StartRange and then is changed by Step up to MaxBECount times. Signed
6508 // argument defines if we treat Step as signed or unsigned.
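// For example, with StartRange = [0, 2), Step = 3 and MaxBECount = 4, the
// largest reachable value is 1 + 3 * 4 = 13, so the result is [0, 14).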
6509 static ConstantRange getRangeForAffineARHelper(APInt Step,
6510 const ConstantRange &StartRange,
6511 const APInt &MaxBECount,
6512 unsigned BitWidth, bool Signed) {
6513 // If either Step or MaxBECount is 0, then the expression won't change, and we
6514 // just need to return the initial range.
6515 if (Step == 0 || MaxBECount == 0)
6516 return StartRange;
6518 // If we don't know anything about the initial value (i.e. StartRange is
6519 // FullRange), then we don't know anything about the final range either.
6520 // Return FullRange.
6521 if (StartRange.isFullSet())
6522 return ConstantRange::getFull(BitWidth);
6524 // If Step is signed and negative, then we use its absolute value, but we also
6525 // note that we're moving in the opposite direction.
6526 bool Descending = Signed && Step.isNegative();
6528 if (Signed)
6529 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
6530 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
6531 // This equation holds due to the well-defined wrap-around behavior of
6532 // unsigned two's complement (APInt) arithmetic.
6533 Step = Step.abs();
6535 // Check if Offset is more than full span of BitWidth. If it is, the
6536 // expression is guaranteed to overflow.
6537 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
6538 return ConstantRange::getFull(BitWidth);
6540 // Offset is by how much the expression can change. Checks above guarantee no
6541 // overflow here.
6542 APInt Offset = Step * MaxBECount;
6544 // Minimum value of the final range will match the minimal value of StartRange
6545 // if the expression is increasing and will be decreased by Offset otherwise.
6546 // Maximum value of the final range will match the maximal value of StartRange
6547 // if the expression is decreasing and will be increased by Offset otherwise.
6548 APInt StartLower = StartRange.getLower();
6549 APInt StartUpper = StartRange.getUpper() - 1;
6550 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
6551 : (StartUpper + std::move(Offset));
6553 // It's possible that the new minimum/maximum value will fall into the initial
6554 // range (due to wrap around). This means that the expression can take any
6555 // value in this bitwidth, and we have to return full range.
6556 if (StartRange.contains(MovedBoundary))
6557 return ConstantRange::getFull(BitWidth);
6559 APInt NewLower =
6560 Descending ? std::move(MovedBoundary) : std::move(StartLower);
6561 APInt NewUpper =
6562 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
6563 NewUpper += 1;
6565 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
6566 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
6569 ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
6570 const SCEV *Step,
6571 const SCEV *MaxBECount,
6572 unsigned BitWidth) {
6573 assert(!isa<SCEVCouldNotCompute>(MaxBECount) &&
6574 getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
6575 "MaxBECount's bit width must not exceed BitWidth!");
6577 MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType());
6578 APInt MaxBECountValue = getUnsignedRangeMax(MaxBECount);
6580 // First, consider step signed.
6581 ConstantRange StartSRange = getSignedRange(Start);
6582 ConstantRange StepSRange = getSignedRange(Step);
6584 // If Step can be both positive and negative, we need to find ranges for the
6585 // maximum absolute step values in both directions and union them.
6586 ConstantRange SR =
6587 getRangeForAffineARHelper(StepSRange.getSignedMin(), StartSRange,
6588 MaxBECountValue, BitWidth, /* Signed = */ true);
6589 SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
6590 StartSRange, MaxBECountValue,
6591 BitWidth, /* Signed = */ true));
6593 // Next, consider step unsigned.
6594 ConstantRange UR = getRangeForAffineARHelper(
6595 getUnsignedRangeMax(Step), getUnsignedRange(Start),
6596 MaxBECountValue, BitWidth, /* Signed = */ false);
6598 // Finally, intersect signed and unsigned ranges.
6599 return SR.intersectWith(UR, ConstantRange::Smallest);
6600 }
6602 ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
6603 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
6604 ScalarEvolution::RangeSignHint SignHint) {
6605 assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
6606 assert(AddRec->hasNoSelfWrap() &&
6607 "This only works for non-self-wrapping AddRecs!");
6608 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
6609 const SCEV *Step = AddRec->getStepRecurrence(*this);
6610 // Only deal with constant step to save compile time.
6611 if (!isa<SCEVConstant>(Step))
6612 return ConstantRange::getFull(BitWidth);
6613 // Let's make sure that we can prove that we do not self-wrap during
6614 // MaxBECount iterations. We need this because MaxBECount is a maximum
6615 // iteration count estimate, and we might infer nw from some exit for which we
6616 // do not know max exit count (or any other side reasoning).
6617 // TODO: Turn into assert at some point.
6618 if (getTypeSizeInBits(MaxBECount->getType()) >
6619 getTypeSizeInBits(AddRec->getType()))
6620 return ConstantRange::getFull(BitWidth);
6621 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
6622 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
6623 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
6624 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
6625 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
6626 MaxItersWithoutWrap))
6627 return ConstantRange::getFull(BitWidth);
6629 ICmpInst::Predicate LEPred =
6630 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
6631 ICmpInst::Predicate GEPred =
6632 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
6633 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
6635 // We know that there is no self-wrap. Let's take Start and End values and
6636 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
6637 // the iteration. They either lie inside the range [Min(Start, End),
6638 // Max(Start, End)] or outside it:
6640 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
6641 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
6643 // No self wrap flag guarantees that the intermediate values cannot be BOTH
6644 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
6645 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
6646 // Start <= End and step is positive, or Start >= End and step is negative.
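// For example, with Start = 0, End = 10 and a known-positive step, Case 1
// applies: every intermediate value lies in [0, 10], so the union of the
// start and end ranges is a correct bound.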
6647 const SCEV *Start = AddRec->getStart();
6648 ConstantRange StartRange = getRangeRef(Start, SignHint);
6649 ConstantRange EndRange = getRangeRef(End, SignHint);
6650 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
6651 // If they already cover full iteration space, we will know nothing useful
6652 // even if we prove what we want to prove.
6653 if (RangeBetween.isFullSet())
6654 return RangeBetween;
6655 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
6656 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
6657 : RangeBetween.isWrappedSet();
6658 if (IsWrappedSet)
6659 return ConstantRange::getFull(BitWidth);
6661 if (isKnownPositive(Step) &&
6662 isKnownPredicateViaConstantRanges(LEPred, Start, End))
6663 return RangeBetween;
6664 else if (isKnownNegative(Step) &&
6665 isKnownPredicateViaConstantRanges(GEPred, Start, End))
6666 return RangeBetween;
6667 return ConstantRange::getFull(BitWidth);
6668 }
6670 ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
6671 const SCEV *Step,
6672 const SCEV *MaxBECount,
6673 unsigned BitWidth) {
6674 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
6675 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
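// For example, with Start = (%c ? 0 : 8) and Step = (%c ? 1 : 2), the
// result is the union of the ranges of {0,+,1} and {8,+,2}.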
6677 struct SelectPattern {
6678 Value *Condition = nullptr;
6679 APInt TrueValue;
6680 APInt FalseValue;
6682 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
6683 const SCEV *S) {
6684 Optional<unsigned> CastOp;
6685 APInt Offset(BitWidth, 0);
6687 assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
6688 "Should be!");
6690 // Peel off a constant offset:
6691 if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
6692 // In the future we could consider being smarter here and handle
6693 // {Start+Step,+,Step} too.
6694 if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
6695 return;
6697 Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
6698 S = SA->getOperand(1);
6701 // Peel off a cast operation
6702 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
6703 CastOp = SCast->getSCEVType();
6704 S = SCast->getOperand();
6705 }
6707 using namespace llvm::PatternMatch;
6709 auto *SU = dyn_cast<SCEVUnknown>(S);
6710 const APInt *TrueVal, *FalseVal;
6711 if (!SU ||
6712 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
6713 m_APInt(FalseVal)))) {
6714 Condition = nullptr;
6715 return;
6716 }
6718 TrueValue = *TrueVal;
6719 FalseValue = *FalseVal;
6721 // Re-apply the cast we peeled off earlier
6722 if (CastOp.hasValue())
6723 switch (*CastOp) {
6724 default:
6725 llvm_unreachable("Unknown SCEV cast type!");
6727 case scTruncate:
6728 TrueValue = TrueValue.trunc(BitWidth);
6729 FalseValue = FalseValue.trunc(BitWidth);
6730 break;
6731 case scZeroExtend:
6732 TrueValue = TrueValue.zext(BitWidth);
6733 FalseValue = FalseValue.zext(BitWidth);
6734 break;
6735 case scSignExtend:
6736 TrueValue = TrueValue.sext(BitWidth);
6737 FalseValue = FalseValue.sext(BitWidth);
6738 break;
6739 }
6741 // Re-apply the constant offset we peeled off earlier
6742 TrueValue += Offset;
6743 FalseValue += Offset;
6744 }
6746 bool isRecognized() { return Condition != nullptr; }
6747 };
6749 SelectPattern StartPattern(*this, BitWidth, Start);
6750 if (!StartPattern.isRecognized())
6751 return ConstantRange::getFull(BitWidth);
6753 SelectPattern StepPattern(*this, BitWidth, Step);
6754 if (!StepPattern.isRecognized())
6755 return ConstantRange::getFull(BitWidth);
6757 if (StartPattern.Condition != StepPattern.Condition) {
6758 // We don't handle this case today; but we could, by considering four
6759 // possibilities below instead of two. I'm not sure if there are cases where
6760 // that will help over what getRange already does, though.
6761 return ConstantRange::getFull(BitWidth);
6764 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
6765 // construct arbitrary general SCEV expressions here. This function is called
6766 // from deep in the call stack, and calling getSCEV (on a sext instruction,
6767 // say) can end up caching a suboptimal value.
6769 // FIXME: without the explicit `this` receiver below, MSVC errors out with
6770 // C2352 and C2512 (otherwise it isn't needed).
6772 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
6773 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
6774 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
6775 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
6777 ConstantRange TrueRange =
6778 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth);
6779 ConstantRange FalseRange =
6780 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth);
6782 return TrueRange.unionWith(FalseRange);
6783 }
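// Illustrative example: for "%a = add nuw nsw i32 %x, %y" the function below
// returns NUW | NSW, but only if isSCEVExprNeverPoison proves that %a is
// executed whenever the SCEV it maps to is evaluated.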
6785 SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
6786 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
6787 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
6789 // Return early if there are no flags to propagate to the SCEV.
6790 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
6791 if (BinOp->hasNoUnsignedWrap())
6792 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
6793 if (BinOp->hasNoSignedWrap())
6794 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
6795 if (Flags == SCEV::FlagAnyWrap)
6796 return SCEV::FlagAnyWrap;
6798 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
6799 }
6801 const Instruction *
6802 ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
6803 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
6804 return &*AddRec->getLoop()->getHeader()->begin();
6805 if (auto *U = dyn_cast<SCEVUnknown>(S))
6806 if (auto *I = dyn_cast<Instruction>(U->getValue()))
6807 return I;
6809 return nullptr;
6810 }
6811 /// Fills \p Ops with unique operands of \p S, if it has operands. If not,
6812 /// \p Ops remains unmodified.
6813 static void collectUniqueOps(const SCEV *S,
6814 SmallVectorImpl<const SCEV *> &Ops) {
6815 SmallPtrSet<const SCEV *, 4> Unique;
6816 auto InsertUnique = [&](const SCEV *S) {
6817 if (Unique.insert(S).second)
6818 Ops.push_back(S);
6819 };
6820 if (auto *S2 = dyn_cast<SCEVCastExpr>(S))
6821 for (auto *Op : S2->operands())
6822 InsertUnique(Op);
6823 else if (auto *S2 = dyn_cast<SCEVNAryExpr>(S))
6824 for (auto *Op : S2->operands())
6825 InsertUnique(Op);
6826 else if (auto *S2 = dyn_cast<SCEVUDivExpr>(S))
6827 for (auto *Op : S2->operands())
6828 InsertUnique(Op);
6829 }
6831 const Instruction *
6832 ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
6833 bool &Precise) {
6834 Precise = true;
6835 // Do a bounded search of the def relation of the requested SCEVs.
6836 SmallSet<const SCEV *, 16> Visited;
6837 SmallVector<const SCEV *> Worklist;
6838 auto pushOp = [&](const SCEV *S) {
6839 if (!Visited.insert(S).second)
6840 return;
6841 // Threshold of 30 here is arbitrary.
6842 if (Visited.size() > 30) {
6843 Precise = false;
6844 return;
6845 }
6846 Worklist.push_back(S);
6847 };
6849 for (auto *S : Ops)
6850 pushOp(S);
6852 const Instruction *Bound = nullptr;
6853 while (!Worklist.empty()) {
6854 auto *S = Worklist.pop_back_val();
6855 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
6856 if (!Bound || DT.dominates(Bound, DefI))
6857 Bound = DefI;
6858 } else {
6859 SmallVector<const SCEV *, 4> Ops;
6860 collectUniqueOps(S, Ops);
6861 for (auto *Op : Ops)
6862 pushOp(Op);
6863 }
6864 }
6865 return Bound ? Bound : &*F.getEntryBlock().begin();
6866 }
6868 const Instruction *
6869 ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
6870 bool Discard;
6871 return getDefiningScopeBound(Ops, Discard);
6872 }
6874 bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
6875 const Instruction *B) {
6876 if (A->getParent() == B->getParent() &&
6877 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
6878 B->getIterator()))
6879 return true;
6881 auto *BLoop = LI.getLoopFor(B->getParent());
6882 if (BLoop && BLoop->getHeader() == B->getParent() &&
6883 BLoop->getLoopPreheader() == A->getParent() &&
6884 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
6885 A->getParent()->end()) &&
6886 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
6887 B->getIterator()))
6888 return true;
6890 return false;
6891 }
6893 bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
6894 // Only proceed if we can prove that I does not yield poison.
6895 if (!programUndefinedIfPoison(I))
6896 return false;
6898 // At this point we know that if I is executed, then it does not wrap
6899 // according to at least one of NSW or NUW. If I is not executed, then we do
6900 // not know if the calculation that I represents would wrap. Multiple
6901 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
6902 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
6903 // derived from other instructions that map to the same SCEV. We cannot make
6904 // that guarantee for cases where I is not executed. So we need to find an
6905 // upper bound on the defining scope for the SCEV, and prove that I is
6906 // executed every time we enter that scope. When the bounding scope is a
6907 // loop (the common case), this is equivalent to proving I executes on every
6908 // iteration of that loop.
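// For illustration: an "%a = add nsw i32 %i, 1" that always executes and a
// flag-less "%b = add i32 %i, 1" in a conditional block map to the same
// SCEV; we may use %a's nsw only because %a executes on every entry to the
// operands' defining scope.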
6909 SmallVector<const SCEV *> SCEVOps;
6910 for (const Use &Op : I->operands()) {
6911 // I could be an extractvalue from a call to an overflow intrinsic.
6912 // TODO: We can do better here in some cases.
6913 if (isSCEVable(Op->getType()))
6914 SCEVOps.push_back(getSCEV(Op));
6916 auto *DefI = getDefiningScopeBound(SCEVOps);
6917 return isGuaranteedToTransferExecutionTo(DefI, I);
6920 bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
6921 // If we know that \c I can never be poison period, then that's enough.
6922 if (isSCEVExprNeverPoison(I))
6923 return true;
6925 // For an add recurrence specifically, we assume that infinite loops without
6926 // side effects are undefined behavior, and then reason as follows:
6928 // If the add recurrence is poison in any iteration, it is poison on all
6929 // future iterations (since incrementing poison yields poison). If the result
6930 // of the add recurrence is fed into the loop latch condition and the loop
6931 // does not contain any throws or exiting blocks other than the latch, we now
6932 // have the ability to "choose" whether the backedge is taken or not (by
6933 // choosing a sufficiently evil value for the poison feeding into the branch)
6934 // for every iteration including and after the one in which \p I first became
6935 // poison. There are two possibilities (let's call the iteration in which \p
6936 // I first became poison as K):
6938 // 1. In the set of iterations including and after K, the loop body executes
6939 // no side effects. In this case executing the backedge an infinite number
6940 // of times will yield undefined behavior.
6942 // 2. In the set of iterations including and after K, the loop body executes
6943 // at least one side effect. In this case, that specific instance of side
6944 // effect is control dependent on poison, which also yields undefined
6945 // behavior.
6947 auto *ExitingBB = L->getExitingBlock();
6948 auto *LatchBB = L->getLoopLatch();
6949 if (!ExitingBB || !LatchBB || ExitingBB != LatchBB)
6950 return false;
6952 SmallPtrSet<const Instruction *, 16> Pushed;
6953 SmallVector<const Instruction *, 8> PoisonStack;
6955 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
6956 // things that are known to be poison under that assumption go on the
6957 // PoisonStack.
6958 Pushed.insert(I);
6959 PoisonStack.push_back(I);
6961 bool LatchControlDependentOnPoison = false;
6962 while (!PoisonStack.empty() && !LatchControlDependentOnPoison) {
6963 const Instruction *Poison = PoisonStack.pop_back_val();
6965 for (auto *PoisonUser : Poison->users()) {
6966 if (propagatesPoison(cast<Operator>(PoisonUser))) {
6967 if (Pushed.insert(cast<Instruction>(PoisonUser)).second)
6968 PoisonStack.push_back(cast<Instruction>(PoisonUser));
6969 } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) {
6970 assert(BI->isConditional() && "Only possibility!");
6971 if (BI->getParent() == LatchBB) {
6972 LatchControlDependentOnPoison = true;
6973 break;
6974 }
6975 }
6976 }
6977 }
6979 return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L);
6982 ScalarEvolution::LoopProperties
6983 ScalarEvolution::getLoopProperties(const Loop *L) {
6984 using LoopProperties = ScalarEvolution::LoopProperties;
6986 auto Itr = LoopPropertiesCache.find(L);
6987 if (Itr == LoopPropertiesCache.end()) {
6988 auto HasSideEffects = [](Instruction *I) {
6989 if (auto *SI = dyn_cast<StoreInst>(I))
6990 return !SI->isSimple();
6992 return I->mayThrow() || I->mayWriteToMemory();
6993 };
6995 LoopProperties LP = {/* HasNoAbnormalExits */ true,
6996 /*HasNoSideEffects*/ true};
6998 for (auto *BB : L->getBlocks())
6999 for (auto &I : *BB) {
7000 if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7001 LP.HasNoAbnormalExits = false;
7002 if (HasSideEffects(&I))
7003 LP.HasNoSideEffects = false;
7004 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7005 break; // We're already as pessimistic as we can get.
7008 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7009 assert(InsertPair.second && "We just checked!");
7010 Itr = InsertPair.first;
7011 }
7013 return Itr->second;
7014 }
7016 bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7017 // A mustprogress loop without side effects must be finite.
7018 // TODO: The check used here is very conservative. It's only *specific*
7019 // side effects which are well defined in infinite loops.
7020 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7023 const SCEV *ScalarEvolution::createSCEV(Value *V) {
7024 if (!isSCEVable(V->getType()))
7025 return getUnknown(V);
7027 if (Instruction *I = dyn_cast<Instruction>(V)) {
7028 // Don't attempt to analyze instructions in blocks that aren't
7029 // reachable. Such instructions don't matter, and they aren't required
7030 // to obey basic rules for definitions dominating uses which this
7031 // analysis depends on.
7032 if (!DT.isReachableFromEntry(I->getParent()))
7033 return getUnknown(UndefValue::get(V->getType()));
7034 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7035 return getConstant(CI);
7036 else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
7037 return GA->isInterposable() ? getUnknown(V) : getSCEV(GA->getAliasee());
7038 else if (!isa<ConstantExpr>(V))
7039 return getUnknown(V);
7041 Operator *U = cast<Operator>(V);
7042 if (auto BO = MatchBinaryOp(U, DT)) {
7043 switch (BO->Opcode) {
7044 case Instruction::Add: {
7045 // The simple thing to do would be to just call getSCEV on both operands
7046 // and call getAddExpr with the result. However if we're looking at a
7047 // bunch of things all added together, this can be quite inefficient,
7048 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7049 // Instead, gather up all the operands and make a single getAddExpr call.
7050 // LLVM IR canonical form means we need only traverse the left operands.
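// For example, for (((%a + %b) + %c) + %d) we gather {%a, %b, %c, %d} and
// emit one four-operand getAddExpr instead of three nested two-operand ones.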
7051 SmallVector<const SCEV *, 4> AddOps;
7052 do {
7053 if (BO->Op) {
7054 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7055 AddOps.push_back(OpSCEV);
7056 break;
7057 }
7059 // If a NUW or NSW flag can be applied to the SCEV for this
7060 // addition, then compute the SCEV for this addition by itself
7061 // with a separate call to getAddExpr. We need to do that
7062 // instead of pushing the operands of the addition onto AddOps,
7063 // since the flags are only known to apply to this particular
7064 // addition - they may not apply to other additions that can be
7065 // formed with operands from AddOps.
7066 const SCEV *RHS = getSCEV(BO->RHS);
7067 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7068 if (Flags != SCEV::FlagAnyWrap) {
7069 const SCEV *LHS = getSCEV(BO->LHS);
7070 if (BO->Opcode == Instruction::Sub)
7071 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7072 else
7073 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7074 break;
7075 }
7076 }
7078 if (BO->Opcode == Instruction::Sub)
7079 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7080 else
7081 AddOps.push_back(getSCEV(BO->RHS));
7083 auto NewBO = MatchBinaryOp(BO->LHS, DT);
7084 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7085 NewBO->Opcode != Instruction::Sub)) {
7086 AddOps.push_back(getSCEV(BO->LHS));
7087 break;
7088 }
7089 BO = NewBO;
7090 } while (true);
7092 return getAddExpr(AddOps);
7093 }
7095 case Instruction::Mul: {
7096 SmallVector<const SCEV *, 4> MulOps;
7097 do {
7098 if (BO->Op) {
7099 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7100 MulOps.push_back(OpSCEV);
7101 break;
7102 }
7104 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7105 if (Flags != SCEV::FlagAnyWrap) {
7106 MulOps.push_back(
7107 getMulExpr(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags));
7108 break;
7109 }
7110 }
7112 MulOps.push_back(getSCEV(BO->RHS));
7113 auto NewBO = MatchBinaryOp(BO->LHS, DT);
7114 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7115 MulOps.push_back(getSCEV(BO->LHS));
7116 break;
7117 }
7118 BO = NewBO;
7119 } while (true);
7121 return getMulExpr(MulOps);
7122 }
7123 case Instruction::UDiv:
7124 return getUDivExpr(getSCEV(BO->LHS), getSCEV(BO->RHS));
7125 case Instruction::URem:
7126 return getURemExpr(getSCEV(BO->LHS), getSCEV(BO->RHS));
7127 case Instruction::Sub: {
7128 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7129 if (BO->Op)
7130 Flags = getNoWrapFlagsFromUB(BO->Op);
7131 return getMinusSCEV(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags);
7132 }
7133 case Instruction::And:
7134 // For an expression like x&255 that merely masks off the high bits,
7135 // use zext(trunc(x)) as the SCEV expression.
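// For example, "and i32 %x, 255" is modeled as
// (zext i8 (trunc i32 %x to i8) to i32).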
7136 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7137 if (CI->isZero())
7138 return getSCEV(BO->RHS);
7139 if (CI->isMinusOne())
7140 return getSCEV(BO->LHS);
7141 const APInt &A = CI->getValue();
7143 // Instcombine's ShrinkDemandedConstant may strip bits out of
7144 // constants, obscuring what would otherwise be a low-bits mask.
7145 // Use computeKnownBits to compute what ShrinkDemandedConstant
7146 // knew about to reconstruct a low-bits mask value.
7147 unsigned LZ = A.countLeadingZeros();
7148 unsigned TZ = A.countTrailingZeros();
7149 unsigned BitWidth = A.getBitWidth();
7150 KnownBits Known(BitWidth);
7151 computeKnownBits(BO->LHS, Known, getDataLayout(),
7152 0, &AC, nullptr, &DT);
7154 APInt EffectiveMask =
7155 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7156 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7157 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7158 const SCEV *LHS = getSCEV(BO->LHS);
7159 const SCEV *ShiftedLHS = nullptr;
7160 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7161 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7162 // For an expression like (x * 8) & 8, simplify the multiply.
7163 unsigned MulZeros = OpC->getAPInt().countTrailingZeros();
7164 unsigned GCD = std::min(MulZeros, TZ);
7165 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7166 SmallVector<const SCEV*, 4> MulOps;
7167 MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD)));
7168 MulOps.append(LHSMul->op_begin() + 1, LHSMul->op_end());
7169 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7170 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7171 }
7172 }
7173 if (!ShiftedLHS)
7174 ShiftedLHS = getUDivExpr(LHS, MulCount);
7175 return getMulExpr(
7176 getZeroExtendExpr(
7177 getTruncateExpr(ShiftedLHS,
7178 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7179 BO->LHS->getType()),
7180 MulCount);
7181 }
7182 }
7183 break;
7185 case Instruction::Or:
7186 // If the RHS of the Or is a constant, we may have something like:
7187 // X*4+1 which got turned into X*4|1. Handle this as an Add so loop
7188 // optimizations will transparently handle this case.
7190 // In order for this transformation to be safe, the LHS must be of the
7191 // form X*(2^n) and the Or constant must be less than 2^n.
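// For example, "or (mul i32 %x, 4), 1" becomes (1 + (4 * %x))<nuw><nsw>,
// since the multiply leaves the two low bits zero and 1 < 4.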
7192 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7193 const SCEV *LHS = getSCEV(BO->LHS);
7194 const APInt &CIVal = CI->getValue();
7195 if (GetMinTrailingZeros(LHS) >=
7196 (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
7197 // Build a plain add SCEV.
7198 return getAddExpr(LHS, getSCEV(CI),
7199 (SCEV::NoWrapFlags)(SCEV::FlagNUW | SCEV::FlagNSW));
7200 }
7201 }
7202 break;
7204 case Instruction::Xor:
7205 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7206 // If the RHS of xor is -1, then this is a not operation.
7207 if (CI->isMinusOne())
7208 return getNotSCEV(getSCEV(BO->LHS));
7210 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7211 // This is a variant of the check for xor with -1, and it handles
7212 // the case where instcombine has trimmed non-demanded bits out
7213 // of an xor with -1.
7214 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7215 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7216 if (LBO->getOpcode() == Instruction::And &&
7217 LCI->getValue() == CI->getValue())
7218 if (const SCEVZeroExtendExpr *Z =
7219 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7220 Type *UTy = BO->LHS->getType();
7221 const SCEV *Z0 = Z->getOperand();
7222 Type *Z0Ty = Z0->getType();
7223 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7225 // If C is a low-bits mask, the zero extend is serving to
7226 // mask off the high bits. Complement the operand and
7227 // re-apply the zext.
7228 if (CI->getValue().isMask(Z0TySize))
7229 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7231 // If C is a single bit, it may be in the sign-bit position
7232 // before the zero-extend. In this case, represent the xor
7233 // using an add, which is equivalent, and re-apply the zext.
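// For example, in i8 arithmetic x ^ 0x80 == x + 0x80: flipping the sign
// bit is the same as adding 128 modulo 256.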
7234 APInt Trunc = CI->getValue().trunc(Z0TySize);
7235 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7236 Trunc.isSignBitSet())
7237 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7238 UTy);
7239 }
7240 }
7242 break;
7243 case Instruction::Shl:
7244 // Turn shift left of a constant amount into a multiply.
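// For example, "shl i32 %x, 3" is modeled as (8 * %x).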
7245 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7246 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7248 // If the shift count is not less than the bitwidth, the result of
7249 // the shift is undefined. Don't try to analyze it, because the
7250 // resolution chosen here may differ from the resolution chosen in
7251 // other parts of the compiler.
7252 if (SA->getValue().uge(BitWidth))
7253 break;
7255 // We can safely preserve the nuw flag in all cases. It's also safe to
7256 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7257 // requires special handling. It can be preserved as long as we're not
7258 // left shifting by bitwidth - 1.
7259 auto Flags = SCEV::FlagAnyWrap;
7260 if (BO->Op) {
7261 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7262 if ((MulFlags & SCEV::FlagNSW) &&
7263 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7264 Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNSW);
7265 if (MulFlags & SCEV::FlagNUW)
7266 Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNUW);
7267 }
7269 Constant *X = ConstantInt::get(
7270 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
7271 return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags);
7272 }
7273 break;
7275 case Instruction::AShr: {
7276 // AShr X, C, where C is a constant.
7277 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
7278 if (!CI)
7279 break;
7281 Type *OuterTy = BO->LHS->getType();
7282 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
7283 // If the shift count is not less than the bitwidth, the result of
7284 // the shift is undefined. Don't try to analyze it, because the
7285 // resolution chosen here may differ from the resolution chosen in
7286 // other parts of the compiler.
7287 if (CI->getValue().uge(BitWidth))
7288 break;
7290 if (CI->isZero())
7291 return getSCEV(BO->LHS); // shift by zero --> noop
7293 uint64_t AShrAmt = CI->getZExtValue();
7294 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
7296 Operator *L = dyn_cast<Operator>(BO->LHS);
7297 if (L && L->getOpcode() == Instruction::Shl) {
7298 // X = Shl A, n
7299 // Y = AShr X, m
7300 // Both n and m are constant.
7302 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
7303 if (L->getOperand(1) == BO->RHS)
7304 // For a two-shift sext-inreg, i.e. n = m,
7305 // use sext(trunc(x)) as the SCEV expression.
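// For example, "(shl i32 %x, 24) ashr 24" becomes
// (sext i8 (trunc i32 %x to i8) to i32).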
7306 return getSignExtendExpr(
7307 getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy);
7309 ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
7310 if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) {
7311 uint64_t ShlAmt = ShlAmtCI->getZExtValue();
7312 if (ShlAmt > AShrAmt) {
7313 // When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
7314 // expression. We already checked that ShlAmt < BitWidth, so
7315 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
7316 // ShlAmt - AShrAmt < Amt.
7317 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
7318 ShlAmt - AShrAmt);
7319 return getSignExtendExpr(
7320 getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy),
7321 getConstant(Mul)), OuterTy);
7322 }
7323 }
7324 }
7325 break;
7326 }
7327 }
7328 }
7330 switch (U->getOpcode()) {
7331 case Instruction::Trunc:
7332 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
7334 case Instruction::ZExt:
7335 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
7337 case Instruction::SExt:
7338 if (auto BO = MatchBinaryOp(U->getOperand(0), DT)) {
7339 // The NSW flag of a subtract does not always survive the conversion to
7340 // A + (-1)*B. By pushing sign extension onto its operands we are much
7341 // more likely to preserve NSW and allow later AddRec optimisations.
7343 // NOTE: This is effectively duplicating this logic from getSignExtend:
7344 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
7345 // but by that point the NSW information has potentially been lost.
7346 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
7347 Type *Ty = U->getType();
7348 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
7349 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
7350 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
7351 }
7352 }
7353 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
7355 case Instruction::BitCast:
7356 // BitCasts are no-op casts so we just eliminate the cast.
7357 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
7358 return getSCEV(U->getOperand(0));
7359 break;
7361 case Instruction::PtrToInt: {
7362 // Pointer to integer cast is straight-forward, so do model it.
7363 const SCEV *Op = getSCEV(U->getOperand(0));
7364 Type *DstIntTy = U->getType();
7365 // But only if effective SCEV (integer) type is wide enough to represent
7366 // all possible pointer values.
7367 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
7368 if (isa<SCEVCouldNotCompute>(IntOp))
7369 return getUnknown(V);
7370 return IntOp;
7371 }
7372 case Instruction::IntToPtr:
7373 // Just don't deal with inttoptr casts.
7374 return getUnknown(V);
7376 case Instruction::SDiv:
7377 // If both operands are non-negative, this is just an udiv.
7378 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
7379 isKnownNonNegative(getSCEV(U->getOperand(1))))
7380 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
7381 break;
7383 case Instruction::SRem:
7384 // If both operands are non-negative, this is just an urem.
7385 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
7386 isKnownNonNegative(getSCEV(U->getOperand(1))))
7387 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
7388 break;
7390 case Instruction::GetElementPtr:
7391 return createNodeForGEP(cast<GEPOperator>(U));
7393 case Instruction::PHI:
7394 return createNodeForPHI(cast<PHINode>(U));
7396 case Instruction::Select:
7397 // U can also be a select constant expr, which we let fall through. Since
7398 // createNodeForSelect only works for a condition that is an `ICmpInst`, and
7399 // constant expressions cannot have instructions as operands, we'd have
7400 // returned getUnknown for a select constant expressions anyway.
7401 if (isa<Instruction>(U))
7402 return createNodeForSelectOrPHI(cast<Instruction>(U), U->getOperand(0),
7403 U->getOperand(1), U->getOperand(2));
7404 break;
7406 case Instruction::Call:
7407 case Instruction::Invoke:
7408 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
7409 return getSCEV(RV);
7411 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7412 switch (II->getIntrinsicID()) {
7413 case Intrinsic::abs:
7414 return getAbsExpr(
7415 getSCEV(II->getArgOperand(0)),
7416 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
7417 case Intrinsic::umax:
7418 return getUMaxExpr(getSCEV(II->getArgOperand(0)),
7419 getSCEV(II->getArgOperand(1)));
7420 case Intrinsic::umin:
7421 return getUMinExpr(getSCEV(II->getArgOperand(0)),
7422 getSCEV(II->getArgOperand(1)));
7423 case Intrinsic::smax:
7424 return getSMaxExpr(getSCEV(II->getArgOperand(0)),
7425 getSCEV(II->getArgOperand(1)));
7426 case Intrinsic::smin:
7427 return getSMinExpr(getSCEV(II->getArgOperand(0)),
7428 getSCEV(II->getArgOperand(1)));
7429 case Intrinsic::usub_sat: {
7430 const SCEV *X = getSCEV(II->getArgOperand(0));
7431 const SCEV *Y = getSCEV(II->getArgOperand(1));
7432 const SCEV *ClampedY = getUMinExpr(X, Y);
7433 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
7435 case Intrinsic::uadd_sat: {
7436 const SCEV *X = getSCEV(II->getArgOperand(0));
7437 const SCEV *Y = getSCEV(II->getArgOperand(1));
7438 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
7439 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
7441 case Intrinsic::start_loop_iterations:
7442 // A start_loop_iterations is just equivalent to the first operand for
7443 // SCEV purposes.
7444 return getSCEV(II->getArgOperand(0));
7445 default:
7446 break;
7447 }
7448 }
7449 break;
7450 }
7452 return getUnknown(V);
7453 }
7455 //===----------------------------------------------------------------------===//
7456 // Iteration Count Computation Code
7457 //
7458 //===----------------------------------------------------------------------===//
7459 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
7460 bool Extend) {
7461 if (isa<SCEVCouldNotCompute>(ExitCount))
7462 return getCouldNotCompute();
7464 auto *ExitCountType = ExitCount->getType();
7465 assert(ExitCountType->isIntegerTy());
7467 if (!Extend)
7468 return getAddExpr(ExitCount, getOne(ExitCountType));
7470 auto *WiderType = Type::getIntNTy(ExitCountType->getContext(),
7471 1 + ExitCountType->getScalarSizeInBits());
7472 return getAddExpr(getNoopOrZeroExtend(ExitCount, WiderType),
7473 getOne(WiderType));
7474 }
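// For illustration: getTripCountFromExitCount turns an exit count of 7 into
// a trip count of 8; with Extend set, an i8 exit count of 255 yields 256 in
// i9 rather than wrapping to 0.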
7476 static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
7477 if (!ExitCount)
7478 return 0;
7480 ConstantInt *ExitConst = ExitCount->getValue();
7482 // Guard against huge trip counts.
7483 if (ExitConst->getValue().getActiveBits() > 32)
7484 return 0;
7486 // In case of integer overflow, this returns 0, which is correct.
7487 return ((unsigned)ExitConst->getZExtValue()) + 1;
7490 unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
7491 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
7492 return getConstantTripCount(ExitCount);
7493 }
7495 unsigned
7496 ScalarEvolution::getSmallConstantTripCount(const Loop *L,
7497 const BasicBlock *ExitingBlock) {
7498 assert(ExitingBlock && "Must pass a non-null exiting block!");
7499 assert(L->isLoopExiting(ExitingBlock) &&
7500 "Exiting block must actually branch out of the loop!");
7501 const SCEVConstant *ExitCount =
7502 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
7503 return getConstantTripCount(ExitCount);
7506 unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
7507 const auto *MaxExitCount =
7508 dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
7509 return getConstantTripCount(MaxExitCount);
7512 const SCEV *ScalarEvolution::getConstantMaxTripCountFromArray(const Loop *L) {
7513 // We can't infer from Array in Irregular Loop.
7514 // FIXME: It's hard to infer loop bound from array operated in Nested Loop.
7515 if (!L->isLoopSimplifyForm() || !L->isInnermost())
7516 return getCouldNotCompute();
7518 // FIXME: For now we only analyze loops that have a single exiting block,
7519 // and that block must be the latch. This makes it easier to capture loops
7520 // whose memory accesses are executed on every iteration.
7522 const BasicBlock *LoopLatch = L->getLoopLatch();
7523 assert(LoopLatch && "See the definition of loop simplify form.");
7524 if (L->getExitingBlock() != LoopLatch)
7525 return getCouldNotCompute();
7527 const DataLayout &DL = getDataLayout();
7528 SmallVector<const SCEV *> InferCountColl;
7529 for (auto *BB : L->getBlocks()) {
7530 // At this point we know the loop is in simplified form and has a single
7531 // exiting block. We may only infer from memory operations in blocks that
7532 // are guaranteed to execute on every iteration, i.e. blocks dominating
7533 // the latch; an upper bound on the number of executions of MemAccessBB
7534 // is then also an upper bound on the number of latch executions.
7535 // If MemAccessBB does not dominate the latch, skip it.
7538 //        Loop Header ◄──────────┐
7539 //         /         \           │
7542 //   MemAccessBB    OtherBB      │
7543 //         \         /           │
7546 //        Loop Latch ────────────┘
7550 if (!DT.dominates(BB, LoopLatch))
7551 continue;
7553 for (Instruction &Inst : *BB) {
7554 // Find Memory Operation Instruction.
7555 auto *GEP = getLoadStorePointerOperand(&Inst);
7556 if (!GEP)
7557 continue;
7559 auto *ElemSize = dyn_cast<SCEVConstant>(getElementSize(&Inst));
7560 // Do not infer from a scalar type, e.g. "ElemSize = sizeof()".
7561 if (!ElemSize)
7562 continue;
7564 // Use an existing polynomial recurrence on the trip count.
7565 auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(GEP));
7566 if (!AddRec)
7567 continue;
7568 auto *ArrBase = dyn_cast<SCEVUnknown>(getPointerBase(AddRec));
7569 auto *Step = dyn_cast<SCEVConstant>(AddRec->getStepRecurrence(*this));
7570 if (!ArrBase || !Step)
7571 continue;
7572 assert(isLoopInvariant(ArrBase, L) && "See addrec definition");
7574 // Only handle { %array + step },
7575 // FIXME: {(SCEVAddRecExpr) + step } could not be analysed here.
7576 if (AddRec->getStart() != ArrBase)
7577 continue;
7579 // Bail out if the memory operation pattern has gaps, repeatedly
7580 // accesses the same location, or the GEP index wraps around.
7582 if (Step->getAPInt().getActiveBits() > 32 ||
7583 Step->getAPInt().getZExtValue() !=
7584 ElemSize->getAPInt().getZExtValue() ||
7585 Step->isZero() || Step->getAPInt().isNegative())
7586 continue;
7588 // Only infer from a stack array whose size is known and fixed.
7589 // Make sure the alloca instruction is not executed inside the loop.
7590 AllocaInst *AllocateInst = dyn_cast<AllocaInst>(ArrBase->getValue());
7591 if (!AllocateInst || L->contains(AllocateInst->getParent()))
7592 continue;
7594 // Make sure we only handle an ordinary array allocation.
7595 auto *Ty = dyn_cast<ArrayType>(AllocateInst->getAllocatedType());
7596 auto *ArrSize = dyn_cast<ConstantInt>(AllocateInst->getArraySize());
7597 if (!Ty || !ArrSize || !ArrSize->isOne())
7598 continue;
7600 // FIXME: Since gep indices are silently zext to the indexing type,
7601 // we will have a narrow gep index which wraps around rather than
7602 // increasing strictly; we should ensure that the step is strictly
7603 // increasing with each loop iteration.
7604 // Now we can infer a max execution count from MemSize / Step.
7605 const SCEV *MemSize =
7606 getConstant(Step->getType(), DL.getTypeAllocSize(Ty));
7607 auto *MaxExeCount =
7608 dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step));
7609 if (!MaxExeCount || MaxExeCount->getAPInt().getActiveBits() > 32)
7610 continue;
7612 // If the loop reaches the maximum number of executions, we can not
7613 // access bytes starting outside the statically allocated size without
7614 // immediate UB. But it is allowed to enter the loop header one more
7615 // time.
7616 auto *InferCount = dyn_cast<SCEVConstant>(
7617 getAddExpr(MaxExeCount, getOne(MaxExeCount->getType())));
7618 // Discard the inferred count if it does not fit into 32 bits.
7619 if (!InferCount || InferCount->getAPInt().getActiveBits() > 32)
7620 continue;
7622 InferCountColl.push_back(InferCount);
7623 }
7624 }
7626 if (InferCountColl.size() == 0)
7627 return getCouldNotCompute();
7629 return getUMinFromMismatchedTypes(InferCountColl);
7632 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
7633 SmallVector<BasicBlock *, 8> ExitingBlocks;
7634 L->getExitingBlocks(ExitingBlocks);
7636 Optional<unsigned> Res = None;
7637 for (auto *ExitingBB : ExitingBlocks) {
7638 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
7639 if (!Res)
7640 Res = Multiple;
7641 Res = (unsigned)GreatestCommonDivisor64(*Res, Multiple);
7642 }
7643 return Res.getValueOr(1);
7646 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
7647 const SCEV *ExitCount) {
7648 if (ExitCount == getCouldNotCompute())
7649 return 1;
7651 // Get the trip count
7652 const SCEV *TCExpr = getTripCountFromExitCount(ExitCount);
7654 const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr);
7655 if (!TC)
7656 // Attempt to factor more general cases. Returns the greatest power of
7657 // two divisor. If overflow happens, the trip count expression is still
7658 // divisible by the greatest power of 2 divisor returned.
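// For example, if the guarded trip count expression is (8 * %n), it has at
// least three trailing zero bits, so 8 is returned as a known multiple.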
7659 return 1U << std::min((uint32_t)31,
7660 GetMinTrailingZeros(applyLoopGuards(TCExpr, L)));
7662 ConstantInt *Result = TC->getValue();
7664 // Guard against huge trip counts (this requires checking
7665 // for zero to handle the case where the trip count == -1 and the
7666 // addition wraps).
7667 if (!Result || Result->getValue().getActiveBits() > 32 ||
7668 Result->getValue().getActiveBits() == 0)
7669 return 1;
7671 return (unsigned)Result->getZExtValue();
7674 /// Returns the largest constant divisor of the trip count of this loop as a
7675 /// normal unsigned value, if possible. This means that the actual trip count is
7676 /// always a multiple of the returned value (don't forget the trip count could
7677 /// very well be zero as well!).
7679 /// Returns 1 if the trip count is unknown or not guaranteed to be the
7680 /// multiple of a constant (which is also the case if the trip count is simply
7681 /// constant; use getSmallConstantTripCount for that case). It will also return 1
7682 /// if the trip count is very large (>= 2^32).
7684 /// As explained in the comments for getSmallConstantTripCount, this assumes
7685 /// that control exits the loop via ExitingBlock.
7686 unsigned
7687 ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
7688 const BasicBlock *ExitingBlock) {
7689 assert(ExitingBlock && "Must pass a non-null exiting block!");
7690 assert(L->isLoopExiting(ExitingBlock) &&
7691 "Exiting block must actually branch out of the loop!");
7692 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
7693 return getSmallConstantTripMultiple(L, ExitCount);
7696 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
7697 const BasicBlock *ExitingBlock,
7698 ExitCountKind Kind) {
7699 switch (Kind) {
7700 case Exact:
7701 case SymbolicMaximum:
7702 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
7703 case ConstantMaximum:
7704 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
7705 };
7706 llvm_unreachable("Invalid ExitCountKind!");
7707 }
7709 const SCEV *
7710 ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
7711 SCEVUnionPredicate &Preds) {
7712 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
7715 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
7716 ExitCountKind Kind) {
7717 switch (Kind) {
7718 case Exact:
7719 return getBackedgeTakenInfo(L).getExact(L, this);
7720 case ConstantMaximum:
7721 return getBackedgeTakenInfo(L).getConstantMax(this);
7722 case SymbolicMaximum:
7723 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
7724 };
7725 llvm_unreachable("Invalid ExitCountKind!");
7726 }
7728 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
7729 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
7732 /// Push PHI nodes in the header of the given loop onto the given Worklist.
7733 static void PushLoopPHIs(const Loop *L,
7734 SmallVectorImpl<Instruction *> &Worklist,
7735 SmallPtrSetImpl<Instruction *> &Visited) {
7736 BasicBlock *Header = L->getHeader();
7738 // Push all Loop-header PHIs onto the Worklist stack.
7739 for (PHINode &PN : Header->phis())
7740 if (Visited.insert(&PN).second)
7741 Worklist.push_back(&PN);
7744 const ScalarEvolution::BackedgeTakenInfo &
7745 ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
7746 auto &BTI = getBackedgeTakenInfo(L);
7747 if (BTI.hasFullInfo())
7748 return BTI;
7750 auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
7752 if (!Pair.second)
7753 return Pair.first->second;
7755 BackedgeTakenInfo Result =
7756 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
7758 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
7761 ScalarEvolution::BackedgeTakenInfo &
7762 ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
7763 // Initially insert an invalid entry for this loop. If the insertion
7764 // succeeds, proceed to actually compute a backedge-taken count and
7765 // update the value. The temporary CouldNotCompute value tells SCEV
7766 // code elsewhere that it shouldn't attempt to request a new
7767 // backedge-taken count, which could result in infinite recursion.
7768 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
7769 BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
7770 if (!Pair.second)
7771 return Pair.first->second;
7773 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
7774 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
7775 // must be cleared in this scope.
7776 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
7778 // In a product (release) build, the statistics below are unused.
7779 (void)NumTripCountsComputed;
7780 (void)NumTripCountsNotComputed;
7781 #if LLVM_ENABLE_STATS || !defined(NDEBUG)
7782 const SCEV *BEExact = Result.getExact(L, this);
7783 if (BEExact != getCouldNotCompute()) {
7784 assert(isLoopInvariant(BEExact, L) &&
7785 isLoopInvariant(Result.getConstantMax(this), L) &&
7786 "Computed backedge-taken count isn't loop invariant for loop!");
7787 ++NumTripCountsComputed;
7788 } else if (Result.getConstantMax(this) == getCouldNotCompute() &&
7789 isa<PHINode>(L->getHeader()->begin())) {
7790 // Only count loops that have phi nodes as not being computable.
7791 ++NumTripCountsNotComputed;
7793 #endif // LLVM_ENABLE_STATS || !defined(NDEBUG)
7795 // Now that we know more about the trip count for this loop, forget any
7796 // existing SCEV values for PHI nodes in this loop since they are only
7797 // conservative estimates made without the benefit of trip count
7798 // information. This invalidation is not necessary for correctness, and is
7799 // only done to produce more precise results.
7800 if (Result.hasAnyInfo()) {
7801 // Invalidate any expression using an addrec in this loop.
7802 SmallVector<const SCEV *, 8> ToForget;
7803 auto LoopUsersIt = LoopUsers.find(L);
7804 if (LoopUsersIt != LoopUsers.end())
7805 append_range(ToForget, LoopUsersIt->second);
7806 forgetMemoizedResults(ToForget);
7808 // Invalidate constant-evolved loop header phis.
7809 for (PHINode &PN : L->getHeader()->phis())
7810 ConstantEvolutionLoopExitValue.erase(&PN);
7813 // Re-lookup the insert position, since the call to
7814 // computeBackedgeTakenCount above could result in a
7815 // recursive call to getBackedgeTakenInfo (on a different
7816 // loop), which would invalidate the iterator computed
7817 // earlier.
7818 return BackedgeTakenCounts.find(L)->second = std::move(Result);
7821 void ScalarEvolution::forgetAllLoops() {
7822 // This method is intended to forget all info about loops. It should
7823 // invalidate caches as if the following happened:
7824 // - The trip counts of all loops have changed arbitrarily
7825 // - Every llvm::Value has been updated in place to produce a different
7826 //   result.
7827 BackedgeTakenCounts.clear();
7828 PredicatedBackedgeTakenCounts.clear();
7829 BECountUsers.clear();
7830 LoopPropertiesCache.clear();
7831 ConstantEvolutionLoopExitValue.clear();
7832 ValueExprMap.clear();
7833 ValuesAtScopes.clear();
7834 ValuesAtScopesUsers.clear();
7835 LoopDispositions.clear();
7836 BlockDispositions.clear();
7837 UnsignedRanges.clear();
7838 SignedRanges.clear();
7839 ExprValueMap.clear();
7841 MinTrailingZerosCache.clear();
7842 PredicatedSCEVRewrites.clear();
7845 void ScalarEvolution::forgetLoop(const Loop *L) {
7846 SmallVector<const Loop *, 16> LoopWorklist(1, L);
7847 SmallVector<Instruction *, 32> Worklist;
7848 SmallPtrSet<Instruction *, 16> Visited;
7849 SmallVector<const SCEV *, 16> ToForget;
7851 // Iterate over all the loops and sub-loops to drop SCEV information.
7852 while (!LoopWorklist.empty()) {
7853 auto *CurrL = LoopWorklist.pop_back_val();
7855 // Drop any stored trip count value.
7856 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
7857 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
7859 // Drop information about predicated SCEV rewrites for this loop.
7860 for (auto I = PredicatedSCEVRewrites.begin();
7861 I != PredicatedSCEVRewrites.end();) {
7862 std::pair<const SCEV *, const Loop *> Entry = I->first;
7863 if (Entry.second == CurrL)
7864 PredicatedSCEVRewrites.erase(I++);
7865 else
7866 ++I;
7867 }
7869 auto LoopUsersItr = LoopUsers.find(CurrL);
7870 if (LoopUsersItr != LoopUsers.end()) {
7871 ToForget.insert(ToForget.end(), LoopUsersItr->second.begin(),
7872 LoopUsersItr->second.end());
7873 LoopUsers.erase(LoopUsersItr);
7876 // Drop information about expressions based on loop-header PHIs.
7877 PushLoopPHIs(CurrL, Worklist, Visited);
7879 while (!Worklist.empty()) {
7880 Instruction *I = Worklist.pop_back_val();
7882 ValueExprMapType::iterator It =
7883 ValueExprMap.find_as(static_cast<Value *>(I));
7884 if (It != ValueExprMap.end()) {
7885 eraseValueFromMap(It->first);
7886 ToForget.push_back(It->second);
7887 if (PHINode *PN = dyn_cast<PHINode>(I))
7888 ConstantEvolutionLoopExitValue.erase(PN);
7891 PushDefUseChildren(I, Worklist, Visited);
7894 LoopPropertiesCache.erase(CurrL);
7895 // Forget all contained loops too, to avoid dangling entries in the
7896 // ValuesAtScopes map.
7897 LoopWorklist.append(CurrL->begin(), CurrL->end());
7899 forgetMemoizedResults(ToForget);
7902 void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
7903 while (Loop *Parent = L->getParentLoop())
7904 L = Parent;
7905 forgetLoop(L);
7906 }
7908 void ScalarEvolution::forgetValue(Value *V) {
7909 Instruction *I = dyn_cast<Instruction>(V);
7910 if (!I) return;
7912 // Drop information about expressions based on loop-header PHIs.
7913 SmallVector<Instruction *, 16> Worklist;
7914 SmallPtrSet<Instruction *, 8> Visited;
7915 SmallVector<const SCEV *, 8> ToForget;
7916 Worklist.push_back(I);
7917 Visited.insert(I);
7919 while (!Worklist.empty()) {
7920 I = Worklist.pop_back_val();
7921 ValueExprMapType::iterator It =
7922 ValueExprMap.find_as(static_cast<Value *>(I));
7923 if (It != ValueExprMap.end()) {
7924 eraseValueFromMap(It->first);
7925 ToForget.push_back(It->second);
7926 if (PHINode *PN = dyn_cast<PHINode>(I))
7927 ConstantEvolutionLoopExitValue.erase(PN);
7930 PushDefUseChildren(I, Worklist, Visited);
7932 forgetMemoizedResults(ToForget);
7935 void ScalarEvolution::forgetLoopDispositions(const Loop *L) {
7936 LoopDispositions.clear();
7939 /// Get the exact loop backedge taken count considering all loop exits. A
7940 /// computable result can only be returned for loops with all exiting blocks
7941 /// dominating the latch. howFarToZero assumes that the limit of each loop test
7942 /// is never skipped. This is a valid assumption as long as the loop exits via
7943 /// that test. For precise results, it is the caller's responsibility to specify
7944 /// the relevant loop exiting block using getExact(ExitingBlock, SE).
7945 const SCEV *
7946 ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
7947 SCEVUnionPredicate *Preds) const {
7948 // If any exits were not computable, the loop is not computable.
7949 if (!isComplete() || ExitNotTaken.empty())
7950 return SE->getCouldNotCompute();
7952 const BasicBlock *Latch = L->getLoopLatch();
7953 // All exiting blocks we have collected must dominate the only backedge.
7954 if (!Latch)
7955 return SE->getCouldNotCompute();
7957 // All exiting blocks we have gathered dominate loop's latch, so exact trip
7958 // count is simply a minimum out of all these calculated exit counts.
7959 SmallVector<const SCEV *, 2> Ops;
7960 for (auto &ENT : ExitNotTaken) {
7961 const SCEV *BECount = ENT.ExactNotTaken;
7962 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
7963 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
7964 "We should only have known counts for exiting blocks that dominate "
7965 "the latch!");
7967 Ops.push_back(BECount);
7969 if (Preds && !ENT.hasAlwaysTruePredicate())
7970 Preds->add(ENT.Predicate.get());
7972 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
7973 "Predicate should be always true!");
7976 return SE->getUMinFromMismatchedTypes(Ops);
7979 /// Get the exact not taken count for this loop exit.
7980 const SCEV *
7981 ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
7982 ScalarEvolution *SE) const {
7983 for (auto &ENT : ExitNotTaken)
7984 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
7985 return ENT.ExactNotTaken;
7987 return SE->getCouldNotCompute();
7990 const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
7991 const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
7992 for (auto &ENT : ExitNotTaken)
7993 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
7994 return ENT.MaxNotTaken;
7996 return SE->getCouldNotCompute();
7999 /// getConstantMax - Get the constant max backedge taken count for the loop.
8000 const SCEV *
8001 ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const {
8002 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8003 return !ENT.hasAlwaysTruePredicate();
8006 if (!getConstantMax() || any_of(ExitNotTaken, PredicateNotAlwaysTrue))
8007 return SE->getCouldNotCompute();
8009 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8010 isa<SCEVConstant>(getConstantMax())) &&
8011 "No point in having a non-constant max backedge taken count!");
8012 return getConstantMax();
8015 const SCEV *
8016 ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L,
8017 ScalarEvolution *SE) {
8018 if (!SymbolicMax)
8019 SymbolicMax = SE->computeSymbolicMaxBackedgeTakenCount(L);
8020 return SymbolicMax;
8021 }
8023 bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8024 ScalarEvolution *SE) const {
8025 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8026 return !ENT.hasAlwaysTruePredicate();
8028 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8031 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
8032 : ExitLimit(E, E, false, None) {
8035 ScalarEvolution::ExitLimit::ExitLimit(
8036 const SCEV *E, const SCEV *M, bool MaxOrZero,
8037 ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
8038 : ExactNotTaken(E), MaxNotTaken(M), MaxOrZero(MaxOrZero) {
8039 // If we prove the max count is zero, so is the symbolic bound. This happens
8040 // in practice due to differences in a) how context sensitive we've chosen
8041 // to be and b) how we reason about bounds implied by UB.
8042 if (MaxNotTaken->isZero())
8043 ExactNotTaken = MaxNotTaken;
8045 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8046 !isa<SCEVCouldNotCompute>(MaxNotTaken)) &&
8047 "Exact is not allowed to be less precise than Max");
8048 assert((isa<SCEVCouldNotCompute>(MaxNotTaken) ||
8049 isa<SCEVConstant>(MaxNotTaken)) &&
8050 "No point in having a non-constant max backedge taken count!");
8051 for (auto *PredSet : PredSetList)
8052 for (auto *P : *PredSet)
8053 Predicates.insert(P);
8054 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8055 "Backedge count should be int");
8056 assert((isa<SCEVCouldNotCompute>(M) || !M->getType()->isPointerTy()) &&
8057 "Max backedge count should be int");
8060 ScalarEvolution::ExitLimit::ExitLimit(
8061 const SCEV *E, const SCEV *M, bool MaxOrZero,
8062 const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
8063 : ExitLimit(E, M, MaxOrZero, {&PredSet}) {
8064 }
8066 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, const SCEV *M,
8067 bool MaxOrZero)
8068 : ExitLimit(E, M, MaxOrZero, None) {
8069 }
8071 /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8072 /// computable exit into a persistent ExitNotTakenInfo array.
8073 ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8074 ArrayRef<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> ExitCounts,
8075 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8076 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8077 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8079 ExitNotTaken.reserve(ExitCounts.size());
8080 std::transform(
8081 ExitCounts.begin(), ExitCounts.end(), std::back_inserter(ExitNotTaken),
8082 [&](const EdgeExitInfo &EEI) {
8083 BasicBlock *ExitBB = EEI.first;
8084 const ExitLimit &EL = EEI.second;
8085 if (EL.Predicates.empty())
8086 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, EL.MaxNotTaken,
8087 nullptr);
8089 std::unique_ptr<SCEVUnionPredicate> Predicate(new SCEVUnionPredicate);
8090 for (auto *Pred : EL.Predicates)
8091 Predicate->add(Pred);
8093 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, EL.MaxNotTaken,
8094 std::move(Predicate));
8095 });
8096 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8097 isa<SCEVConstant>(ConstantMax)) &&
8098 "No point in having a non-constant max backedge taken count!");
8101 /// Compute the number of times the backedge of the specified loop will execute.
8102 ScalarEvolution::BackedgeTakenInfo
8103 ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8104 bool AllowPredicates) {
8105 SmallVector<BasicBlock *, 8> ExitingBlocks;
8106 L->getExitingBlocks(ExitingBlocks);
8108 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8110 SmallVector<EdgeExitInfo, 4> ExitCounts;
8111 bool CouldComputeBECount = true;
8112 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8113 const SCEV *MustExitMaxBECount = nullptr;
8114 const SCEV *MayExitMaxBECount = nullptr;
8115 bool MustExitMaxOrZero = false;
8117 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8118 // and compute maxBECount.
8119 // Do a union of all the predicates here.
8120 for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
8121 BasicBlock *ExitBB = ExitingBlocks[i];
8123 // We canonicalize untaken exits to br (constant), ignore them so that
8124 // proving an exit untaken doesn't negatively impact our ability to reason
8125 // about the loop as a whole.
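// For illustration (hypothetical IR, not from the original source): an
// exiting branch such as
//   br i1 false, label %exit, label %continue
// can never leave the loop, so it is skipped here instead of being allowed
// to degrade the analysis of the remaining exits.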
8126 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8127 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8128 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8129 if (ExitIfTrue == CI->isZero())
8130 continue;
8131 }
8133 ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
8135 assert((AllowPredicates || EL.Predicates.empty()) &&
8136 "Predicated exit limit when predicates are not allowed!");
8138 // 1. For each exit that can be computed, add an entry to ExitCounts.
8139 // CouldComputeBECount is true only if all exits can be computed.
8140 if (EL.ExactNotTaken == getCouldNotCompute())
8141 // We couldn't compute an exact value for this exit, so
8142 // we won't be able to compute an exact value for the loop.
8143 CouldComputeBECount = false;
8144 else
8145 ExitCounts.emplace_back(ExitBB, EL);
8147 // 2. Derive the loop's MaxBECount from each exit's max number of
8148 // non-exiting iterations. Partition the loop exits into two kinds:
8149 // LoopMustExits and LoopMayExits.
8151 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8152 // is a LoopMayExit. If any computable LoopMustExit is found, then
8153 // MaxBECount is the minimum EL.MaxNotTaken of computable
8154 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8155 // EL.MaxNotTaken, where CouldNotCompute is considered greater than any
8156 // computable EL.MaxNotTaken.
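// For illustration (numbers are hypothetical): with two computable
// LoopMustExits whose EL.MaxNotTaken are 10 and 20, MaxBECount is
// umin(10, 20) = 10. If only LoopMayExits are computable, we fall back to
// their umax, and a single incomputable LoopMayExit pins the result at
// the conservative CouldNotCompute.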
8157 if (EL.MaxNotTaken != getCouldNotCompute() && Latch &&
8158 DT.dominates(ExitBB, Latch)) {
8159 if (!MustExitMaxBECount) {
8160 MustExitMaxBECount = EL.MaxNotTaken;
8161 MustExitMaxOrZero = EL.MaxOrZero;
8162 } else {
8163 MustExitMaxBECount =
8164 getUMinFromMismatchedTypes(MustExitMaxBECount, EL.MaxNotTaken);
8165 }
8166 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8167 if (!MayExitMaxBECount || EL.MaxNotTaken == getCouldNotCompute())
8168 MayExitMaxBECount = EL.MaxNotTaken;
8169 else {
8170 MayExitMaxBECount =
8171 getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.MaxNotTaken);
8172 }
8173 }
8174 }
8175 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8176 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8177 // The loop backedge will be taken the maximum or zero times if there's
8178 // a single exit that must be taken the maximum or zero times.
8179 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8181 // Remember which SCEVs are used in exit limits for invalidation purposes.
8182 // We only care about non-constant SCEVs here, so we can ignore EL.MaxNotTaken
8183 // and MaxBECount, which must be SCEVConstant.
8184 for (const auto &Pair : ExitCounts)
8185 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8186 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8187 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8188 MaxBECount, MaxOrZero);
8189 }
8191 ScalarEvolution::ExitLimit
8192 ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8193 bool AllowPredicates) {
8194 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8195 // If our exiting block does not dominate the latch, then its connection with
8196 // loop's exit limit may be far from trivial.
8197 const BasicBlock *Latch = L->getLoopLatch();
8198 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8199 return getCouldNotCompute();
8201 bool IsOnlyExit = (L->getExitingBlock() != nullptr);
8202 Instruction *Term = ExitingBlock->getTerminator();
8203 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8204 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8205 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8206 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8207 "It should have one successor in loop and one exit block!");
8208 // Proceed to the next level to examine the exit condition expression.
8209 return computeExitLimitFromCond(
8210 L, BI->getCondition(), ExitIfTrue,
8211 /*ControlsExit=*/IsOnlyExit, AllowPredicates);
8212 }
8214 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8215 // For switch, make sure that there is a single exit from the loop.
8216 BasicBlock *Exit = nullptr;
8217 for (auto *SBB : successors(ExitingBlock))
8218 if (!L->contains(SBB)) {
8219 if (Exit) // Multiple exit successors.
8220 return getCouldNotCompute();
8221 Exit = SBB;
8222 }
8223 assert(Exit && "Exiting block must have at least one exit");
8224 return computeExitLimitFromSingleExitSwitch(L, SI, Exit,
8225 /*ControlsExit=*/IsOnlyExit);
8226 }
8228 return getCouldNotCompute();
8229 }
8231 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
8232 const Loop *L, Value *ExitCond, bool ExitIfTrue,
8233 bool ControlsExit, bool AllowPredicates) {
8234 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
8235 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
8236 ControlsExit, AllowPredicates);
8237 }
8239 Optional<ScalarEvolution::ExitLimit>
8240 ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
8241 bool ExitIfTrue, bool ControlsExit,
8242 bool AllowPredicates) {
8243 (void)this->L;
8244 (void)this->ExitIfTrue;
8245 (void)this->AllowPredicates;
8247 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8248 this->AllowPredicates == AllowPredicates &&
8249 "Variance in assumed invariant key components!");
8250 auto Itr = TripCountMap.find({ExitCond, ControlsExit});
8251 if (Itr == TripCountMap.end())
8252 return None;
8253 return Itr->second;
8254 }
8256 void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
8257 bool ExitIfTrue,
8258 bool ControlsExit,
8259 bool AllowPredicates,
8260 const ExitLimit &EL) {
8261 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8262 this->AllowPredicates == AllowPredicates &&
8263 "Variance in assumed invariant key components!");
8265 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsExit}, EL});
8266 assert(InsertResult.second && "Expected successful insertion!");
8267 (void)InsertResult;
8268 (void)EL;
8269 }
8271 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
8272 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8273 bool ControlsExit, bool AllowPredicates) {
8275 if (auto MaybeEL =
8276 Cache.find(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates))
8277 return *MaybeEL;
8279 ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, ExitIfTrue,
8280 ControlsExit, AllowPredicates);
8281 Cache.insert(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates, EL);
8282 return EL;
8283 }
8285 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
8286 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8287 bool ControlsExit, bool AllowPredicates) {
8288 // Handle BinOp conditions (And, Or).
8289 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
8290 Cache, L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates))
8291 return *LimitFromBinOp;
8293 // With an icmp, it may be feasible to compute an exact backedge-taken count.
8294 // Proceed to the next level to examine the icmp.
8295 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
8296 ExitLimit EL =
8297 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit);
8298 if (EL.hasFullInfo() || !AllowPredicates)
8299 return EL;
8301 // Try again, but use SCEV predicates this time.
8302 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit,
8303 /*AllowPredicates=*/true);
8304 }
8306 // Check for a constant condition. These are normally stripped out by
8307 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
8308 // preserve the CFG and is temporarily leaving constant conditions
8309 // in place.
8310 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
8311 if (ExitIfTrue == !CI->getZExtValue())
8312 // The backedge is always taken.
8313 return getCouldNotCompute();
8314 else
8315 // The backedge is never taken.
8316 return getZero(CI->getType());
8317 }
8319 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
8320 // with a constant step, we can form an equivalent icmp predicate and figure
8321 // out how many iterations will be taken before we exit.
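// For illustration (hypothetical example, not from the original source):
// exiting on the overflow bit of `uadd.with.overflow(i8 %x, 42)` is the
// same as exiting once %x leaves the no-wrap region, i.e. once
// %x u>= 214 (on i8, %x + 42 wraps exactly when %x >= 256 - 42).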
8322 const WithOverflowInst *WO;
8323 const APInt *C;
8324 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
8325 match(WO->getRHS(), m_APInt(C))) {
8326 ConstantRange NWR =
8327 ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
8328 WO->getNoWrapKind());
8329 CmpInst::Predicate Pred;
8330 APInt NewRHSC, Offset;
8331 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
8332 if (!ExitIfTrue)
8333 Pred = ICmpInst::getInversePredicate(Pred);
8334 auto *LHS = getSCEV(WO->getLHS());
8335 if (Offset != 0)
8336 LHS = getAddExpr(LHS, getConstant(Offset));
8337 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
8338 ControlsExit, AllowPredicates);
8339 if (EL.hasAnyInfo()) return EL;
8340 }
8342 // If it's not an integer or pointer comparison then compute it the hard way.
8343 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
8344 }
8346 Optional<ScalarEvolution::ExitLimit>
8347 ScalarEvolution::computeExitLimitFromCondFromBinOp(
8348 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8349 bool ControlsExit, bool AllowPredicates) {
8350 // Check if the controlling expression for this loop is an And or Or.
8351 Value *Op0, *Op1;
8352 bool IsAnd = false;
8353 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
8354 IsAnd = true;
8355 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
8356 IsAnd = false;
8357 else
8358 return None;
8360 // EitherMayExit is true in these two cases:
8361 // br (and Op0 Op1), loop, exit
8362 // br (or Op0 Op1), exit, loop
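// For illustration: `br (and Op0 Op1), loop, exit` has IsAnd == true and
// ExitIfTrue == false, so EitherMayExit == true ^ false == true: either
// operand becoming false takes the exit. Conversely, `br (and Op0 Op1),
// exit, loop` gives EitherMayExit == false, since both operands must be
// true at once to exit.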
8363 bool EitherMayExit = IsAnd ^ ExitIfTrue;
8364 ExitLimit EL0 = computeExitLimitFromCondCached(Cache, L, Op0, ExitIfTrue,
8365 ControlsExit && !EitherMayExit,
8366 AllowPredicates);
8367 ExitLimit EL1 = computeExitLimitFromCondCached(Cache, L, Op1, ExitIfTrue,
8368 ControlsExit && !EitherMayExit,
8369 AllowPredicates);
8371 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
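// For illustration: given unsimplified IR like `and i1 %X, true`, the
// constant operand is the neutral element, so the exit limit computed for
// %X alone (EL0) is the answer; a non-neutral constant operand instead
// makes the limit computed for the constant (EL1) the answer.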
8372 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
8373 if (isa<ConstantInt>(Op1))
8374 return Op1 == NeutralElement ? EL0 : EL1;
8375 if (isa<ConstantInt>(Op0))
8376 return Op0 == NeutralElement ? EL1 : EL0;
8378 const SCEV *BECount = getCouldNotCompute();
8379 const SCEV *MaxBECount = getCouldNotCompute();
8380 if (EitherMayExit) {
8381 // Both conditions must be non-exiting for the loop to continue executing.
8382 // Choose the less conservative count.
8383 if (EL0.ExactNotTaken != getCouldNotCompute() &&
8384 EL1.ExactNotTaken != getCouldNotCompute()) {
8385 BECount = getUMinFromMismatchedTypes(
8386 EL0.ExactNotTaken, EL1.ExactNotTaken,
8387 /*Sequential=*/!isa<BinaryOperator>(ExitCond));
8389 // If EL0.ExactNotTaken was zero and ExitCond was a short-circuit form,
8390 // it should have been simplified to zero (see the condition (3) above)
8391 assert(!isa<BinaryOperator>(ExitCond) || !EL0.ExactNotTaken->isZero() ||
8392 BECount->isZero());
8393 }
8394 if (EL0.MaxNotTaken == getCouldNotCompute())
8395 MaxBECount = EL1.MaxNotTaken;
8396 else if (EL1.MaxNotTaken == getCouldNotCompute())
8397 MaxBECount = EL0.MaxNotTaken;
8398 else
8399 MaxBECount = getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);
8400 } else {
8401 // Both conditions must be met at the same time for the loop to exit.
8402 // For now, be conservative.
8403 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
8404 BECount = EL0.ExactNotTaken;
8405 }
8407 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
8408 // to be more aggressive when computing BECount than when computing
8409 // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and
8410 // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken
8411 // to not.
8412 if (isa<SCEVCouldNotCompute>(MaxBECount) &&
8413 !isa<SCEVCouldNotCompute>(BECount))
8414 MaxBECount = getConstant(getUnsignedRangeMax(BECount));
8416 return ExitLimit(BECount, MaxBECount, false,
8417 { &EL0.Predicates, &EL1.Predicates });
8418 }
8420 ScalarEvolution::ExitLimit
8421 ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
8422 ICmpInst *ExitCond,
8423 bool ExitIfTrue,
8424 bool ControlsExit,
8425 bool AllowPredicates) {
8426 // If the condition was exit on true, convert the condition to exit on false.
8427 ICmpInst::Predicate Pred;
8428 if (!ExitIfTrue)
8429 Pred = ExitCond->getPredicate();
8430 else
8431 Pred = ExitCond->getInversePredicate();
8432 const ICmpInst::Predicate OriginalPred = Pred;
8434 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
8435 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
8437 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsExit,
8438 AllowPredicates);
8439 if (EL.hasAnyInfo()) return EL;
8441 auto *ExhaustiveCount =
8442 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
8444 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
8445 return ExhaustiveCount;
8447 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
8448 ExitCond->getOperand(1), L, OriginalPred);
8449 }
8450 ScalarEvolution::ExitLimit
8451 ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
8452 ICmpInst::Predicate Pred,
8453 const SCEV *LHS, const SCEV *RHS,
8454 bool ControlsExit,
8455 bool AllowPredicates) {
8457 // Try to evaluate any dependencies out of the loop.
8458 LHS = getSCEVAtScope(LHS, L);
8459 RHS = getSCEVAtScope(RHS, L);
8461 // At this point, we would like to compute how many iterations of the
8462 // loop the predicate will return true for these inputs.
8463 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
8464 // If there is a loop-invariant, force it into the RHS.
8465 std::swap(LHS, RHS);
8466 Pred = ICmpInst::getSwappedPredicate(Pred);
8467 }
8469 bool ControllingFiniteLoop =
8470 ControlsExit && loopHasNoAbnormalExits(L) && loopIsFiniteByAssumption(L);
8471 // Simplify the operands before analyzing them.
8472 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0,
8473 ControllingFiniteLoop);
8475 // If we have a comparison of a chrec against a constant, try to use value
8476 // ranges to answer this query.
8477 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
8478 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
8479 if (AddRec->getLoop() == L) {
8480 // Form the constant range.
8481 ConstantRange CompRange =
8482 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
8484 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
8485 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
8486 }
8488 // If this loop must exit based on this condition (or execute undefined
8489 // behaviour), and we can prove the test sequence produced must repeat
8490 // the same values on self-wrap of the IV, then we can infer that IV
8491 // doesn't self wrap because if it did, we'd have an infinite (undefined)
8492 // loop.
8493 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
8494 // TODO: We can peel off any functions which are invertible *in L*. Loop
8495 // invariant terms are effectively constants for our purposes here.
8496 auto *InnerLHS = LHS;
8497 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
8498 InnerLHS = ZExt->getOperand();
8499 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS)) {
8500 auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
8501 if (!AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
8502 StrideC && StrideC->getAPInt().isPowerOf2()) {
8503 auto Flags = AR->getNoWrapFlags();
8504 Flags = setFlags(Flags, SCEV::FlagNW);
8505 SmallVector<const SCEV*> Operands{AR->operands()};
8506 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
8507 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
8508 }
8509 }
8510 }
8512 switch (Pred) {
8513 case ICmpInst::ICMP_NE: { // while (X != Y)
8514 // Convert to: while (X-Y != 0)
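// For illustration (hypothetical values): with X = {0,+,1} and a
// loop-invariant Y = 10, X-Y is {-10,+,1}, which howFarToZero solves to an
// exact backedge-taken count of 10.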
8515 if (LHS->getType()->isPointerTy()) {
8516 LHS = getLosslessPtrToIntExpr(LHS);
8517 if (isa<SCEVCouldNotCompute>(LHS))
8518 return getCouldNotCompute();
8519 }
8520 if (RHS->getType()->isPointerTy()) {
8521 RHS = getLosslessPtrToIntExpr(RHS);
8522 if (isa<SCEVCouldNotCompute>(RHS))
8523 return getCouldNotCompute();
8524 }
8525 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit,
8526 AllowPredicates);
8527 if (EL.hasAnyInfo()) return EL;
8528 break;
8529 }
8530 case ICmpInst::ICMP_EQ: { // while (X == Y)
8531 // Convert to: while (X-Y == 0)
8532 if (LHS->getType()->isPointerTy()) {
8533 LHS = getLosslessPtrToIntExpr(LHS);
8534 if (isa<SCEVCouldNotCompute>(LHS))
8535 return getCouldNotCompute();
8536 }
8537 if (RHS->getType()->isPointerTy()) {
8538 RHS = getLosslessPtrToIntExpr(RHS);
8539 if (isa<SCEVCouldNotCompute>(RHS))
8540 return getCouldNotCompute();
8541 }
8542 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
8543 if (EL.hasAnyInfo()) return EL;
8544 break;
8545 }
8546 case ICmpInst::ICMP_SLT:
8547 case ICmpInst::ICMP_ULT: { // while (X < Y)
8548 bool IsSigned = Pred == ICmpInst::ICMP_SLT;
8549 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit,
8550 AllowPredicates);
8551 if (EL.hasAnyInfo()) return EL;
8552 break;
8553 }
8554 case ICmpInst::ICMP_SGT:
8555 case ICmpInst::ICMP_UGT: { // while (X > Y)
8556 bool IsSigned = Pred == ICmpInst::ICMP_SGT;
8557 ExitLimit EL =
8558 howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit,
8559 AllowPredicates);
8560 if (EL.hasAnyInfo()) return EL;
8561 break;
8562 }
8563 default:
8564 break;
8565 }
8567 return getCouldNotCompute();
8568 }
8570 ScalarEvolution::ExitLimit
8571 ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
8572 SwitchInst *Switch,
8573 BasicBlock *ExitingBlock,
8574 bool ControlsExit) {
8575 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
8577 // Give up if the exit is the default dest of a switch.
8578 if (Switch->getDefaultDest() == ExitingBlock)
8579 return getCouldNotCompute();
8581 assert(L->contains(Switch->getDefaultDest()) &&
8582 "Default case must not exit the loop!");
8583 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
8584 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
8586 // while (X != Y) --> while (X-Y != 0)
8587 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit);
8588 if (EL.hasAnyInfo())
8589 return EL;
8591 return getCouldNotCompute();
8592 }
8594 static ConstantInt *
8595 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
8596 ScalarEvolution &SE) {
8597 const SCEV *InVal = SE.getConstant(C);
8598 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
8599 assert(isa<SCEVConstant>(Val) &&
8600 "Evaluation of SCEV at constant didn't fold correctly?");
8601 return cast<SCEVConstant>(Val)->getValue();
8602 }
8604 ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
8605 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
8606 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
8607 if (!RHS)
8608 return getCouldNotCompute();
8610 const BasicBlock *Latch = L->getLoopLatch();
8611 if (!Latch)
8612 return getCouldNotCompute();
8614 const BasicBlock *Predecessor = L->getLoopPredecessor();
8615 if (!Predecessor)
8616 return getCouldNotCompute();
8618 // Return true if V is of the form "LHS `shift_op` <positive constant>".
8619 // Return LHS in OutLHS and shift_op in OutOpCode.
8620 auto MatchPositiveShift =
8621 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
8623 using namespace PatternMatch;
8625 ConstantInt *ShiftAmt;
8626 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
8627 OutOpCode = Instruction::LShr;
8628 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
8629 OutOpCode = Instruction::AShr;
8630 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
8631 OutOpCode = Instruction::Shl;
8632 else
8633 return false;
8635 return ShiftAmt->getValue().isStrictlyPositive();
8636 };
8638 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
8639 // the following cases:
8641 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
8642 // %iv.shifted = lshr i32 %iv, <positive constant>
8644 // Return true on a successful match. Return the corresponding PHI node (%iv
8645 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
8646 auto MatchShiftRecurrence =
8647 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
8648 Optional<Instruction::BinaryOps> PostShiftOpCode;
8650 {
8651 Instruction::BinaryOps OpC;
8652 Value *V;
8654 // If we encounter a shift instruction, "peel off" the shift operation,
8655 // and remember that we did so. Later when we inspect %iv's backedge
8656 // value, we will make sure that the backedge value uses the same
8657 // kind of shift.
8659 // Note: the peeled shift operation does not have to be the same
8660 // instruction as the one feeding into the PHI's backedge value. We only
8661 // really care about it being the same *kind* of shift instruction --
8662 // that's all that is required for our later inferences to hold.
8663 if (MatchPositiveShift(LHS, V, OpC)) {
8664 PostShiftOpCode = OpC;
8665 LHS = V;
8666 }
8667 }
8669 PNOut = dyn_cast<PHINode>(LHS);
8670 if (!PNOut || PNOut->getParent() != L->getHeader())
8671 return false;
8673 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
8674 Value *OpLHS;
8676 return
8677 // The backedge value for the PHI node must be a shift by a positive
8678 // amount
8679 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
8681 // of the PHI node itself
8682 OpLHS == PNOut &&
8684 // and the kind of shift should match the kind of shift we peeled
8685 // off, if there was one.
8686 (!PostShiftOpCode.hasValue() || *PostShiftOpCode == OpCodeOut);
8687 };
8689 PHINode *PN;
8690 Instruction::BinaryOps OpCode;
8691 if (!MatchShiftRecurrence(LHS, PN, OpCode))
8692 return getCouldNotCompute();
8694 const DataLayout &DL = getDataLayout();
8696 // The key rationale for this optimization is that for some kinds of shift
8697 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
8698 // within a finite number of iterations. If the condition guarding the
8699 // backedge (in the sense that the backedge is taken if the condition is true)
8700 // is false for the value the shift recurrence stabilizes to, then we know
8701 // that the backedge is taken only a finite number of times.
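// For illustration (values are hypothetical): the recurrence {8,lshr,1}
// produces 8, 4, 2, 1, 0, 0, ... -- once the value reaches 0 it never
// changes again, so a backedge guarded by a condition that is false at 0
// can only be taken a bounded number of times.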
8703 ConstantInt *StableValue = nullptr;
8704 switch (OpCode) {
8705 default:
8706 llvm_unreachable("Impossible case!");
8708 case Instruction::AShr: {
8709 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
8710 // bitwidth(K) iterations.
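// For illustration: {-8,ashr,1} yields -8, -4, -2, -1, -1, ... since sign
// bits are shifted in; a negative start therefore stabilizes to -1, while
// a non-negative start stabilizes to 0.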
8711 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
8712 KnownBits Known = computeKnownBits(FirstValue, DL, 0, &AC,
8713 Predecessor->getTerminator(), &DT);
8714 auto *Ty = cast<IntegerType>(RHS->getType());
8715 if (Known.isNonNegative())
8716 StableValue = ConstantInt::get(Ty, 0);
8717 else if (Known.isNegative())
8718 StableValue = ConstantInt::get(Ty, -1, true);
8719 else
8720 return getCouldNotCompute();
8722 break;
8723 }
8724 case Instruction::LShr:
8725 case Instruction::Shl:
8726 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
8727 // stabilize to 0 in at most bitwidth(K) iterations.
8728 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
8729 break;
8730 }
8732 auto *Result =
8733 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
8734 assert(Result->getType()->isIntegerTy(1) &&
8735 "Otherwise cannot be an operand to a branch instruction");
8737 if (Result->isZeroValue()) {
8738 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
8739 const SCEV *UpperBound =
8740 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
8741 return ExitLimit(getCouldNotCompute(), UpperBound, false);
8742 }
8744 return getCouldNotCompute();
8745 }
8747 /// Return true if we can constant fold an instruction of the specified type,
8748 /// assuming that all operands were constants.
8749 static bool CanConstantFold(const Instruction *I) {
8750 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
8751 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
8752 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
8753 return true;
8755 if (const CallInst *CI = dyn_cast<CallInst>(I))
8756 if (const Function *F = CI->getCalledFunction())
8757 return canConstantFoldCallTo(CI, F);
8758 return false;
8759 }
8761 /// Determine whether this instruction can constant evolve within this loop
8762 /// assuming its operands can all constant evolve.
8763 static bool canConstantEvolve(Instruction *I, const Loop *L) {
8764 // An instruction outside of the loop can't be derived from a loop PHI.
8765 if (!L->contains(I)) return false;
8767 if (isa<PHINode>(I)) {
8768 // We don't currently keep track of the control flow needed to evaluate
8769 // PHIs, so we cannot handle PHIs inside of loops.
8770 return L->getHeader() == I->getParent();
8771 }
8773 // If we won't be able to constant fold this expression even if the operands
8774 // are constants, bail early.
8775 return CanConstantFold(I);
8776 }
8778 /// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
8779 /// recursing through each instruction operand until reaching a loop header phi.
8780 static PHINode *
8781 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
8782 DenseMap<Instruction *, PHINode *> &PHIMap,
8783 unsigned Depth) {
8784 if (Depth > MaxConstantEvolvingDepth)
8785 return nullptr;
8787 // Otherwise, we can evaluate this instruction if all of its operands are
8788 // constant or derived from a PHI node themselves.
8789 PHINode *PHI = nullptr;
8790 for (Value *Op : UseInst->operands()) {
8791 if (isa<Constant>(Op)) continue;
8793 Instruction *OpInst = dyn_cast<Instruction>(Op);
8794 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
8796 PHINode *P = dyn_cast<PHINode>(OpInst);
8797 if (!P)
8798 // If this operand is already visited, reuse the prior result.
8799 // We may have P != PHI if this is the deepest point at which the
8800 // inconsistent paths meet.
8801 P = PHIMap.lookup(OpInst);
8802 if (!P) {
8803 // Recurse and memoize the results, whether a phi is found or not.
8804 // This recursive call invalidates pointers into PHIMap.
8805 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
8806 PHIMap[OpInst] = P;
8807 }
8808 if (!P)
8809 return nullptr; // Not evolving from PHI
8810 if (PHI && PHI != P)
8811 return nullptr; // Evolving from multiple different PHIs.
8812 PHI = P;
8813 }
8814 // This is an expression evolving from a constant PHI!
8815 return PHI;
8816 }
8818 /// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
8819 /// in the loop that V is derived from. We allow arbitrary operations along the
8820 /// way, but the operands of an operation must either be constants or a value
8821 /// derived from a constant PHI. If this expression does not fit with these
8822 /// constraints, return null.
8823 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
8824 Instruction *I = dyn_cast<Instruction>(V);
8825 if (!I || !canConstantEvolve(I, L)) return nullptr;
8827 if (PHINode *PN = dyn_cast<PHINode>(I))
8828 return PN;
8830 // Record non-constant instructions contained by the loop.
8831 DenseMap<Instruction *, PHINode *> PHIMap;
8832 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
8833 }
8835 /// EvaluateExpression - Given an expression that passes the
8836 /// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
8837 /// in the loop has the value PHIVal. If we can't fold this expression for some
8838 /// reason, return null.
8839 static Constant *EvaluateExpression(Value *V, const Loop *L,
8840 DenseMap<Instruction *, Constant *> &Vals,
8841 const DataLayout &DL,
8842 const TargetLibraryInfo *TLI) {
8843 // Convenient constant check, but redundant for recursive calls.
8844 if (Constant *C = dyn_cast<Constant>(V)) return C;
8845 Instruction *I = dyn_cast<Instruction>(V);
8846 if (!I) return nullptr;
8848 if (Constant *C = Vals.lookup(I)) return C;
8850 // An instruction inside the loop depends on a value outside the loop that we
8851 // weren't given a mapping for, or a value such as a call inside the loop.
8852 if (!canConstantEvolve(I, L)) return nullptr;
8854 // An unmapped PHI can be due to a branch or another loop inside this loop,
8855 // or due to this not being the initial iteration through a loop where we
8856 // couldn't compute the evolution of this particular PHI last time.
8857 if (isa<PHINode>(I)) return nullptr;
8859 std::vector<Constant*> Operands(I->getNumOperands());
8861 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
8862 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
8863 if (!Operand) {
8864 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
8865 if (!Operands[i]) return nullptr;
8866 continue;
8867 }
8868 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
8869 Vals[Operand] = C;
8870 if (!C) return nullptr;
8871 Operands[i] = C;
8872 }
8874 if (CmpInst *CI = dyn_cast<CmpInst>(I))
8875 return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
8876 Operands[1], DL, TLI);
8877 if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
8878 if (!LI->isVolatile())
8879 return ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL);
8880 }
8881 return ConstantFoldInstOperands(I, Operands, DL, TLI);
8882 }
8885 // If every incoming value to PN except the one for BB is a specific Constant,
8886 // return that, else return nullptr.
8887 static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
8888 Constant *IncomingVal = nullptr;
8890 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
8891 if (PN->getIncomingBlock(i) == BB)
8892 continue;
8894 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
8895 if (!CurrentVal)
8896 return nullptr;
8898 if (IncomingVal != CurrentVal) {
8899 if (IncomingVal)
8900 return nullptr;
8901 IncomingVal = CurrentVal;
8902 }
8903 }
8905 return IncomingVal;
8906 }
8908 /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
8909 /// in the header of its containing loop, we know the loop executes a
8910 /// constant number of times, and the PHI node is just a recurrence
8911 /// involving constants, fold it.
8912 Constant *
8913 ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
8914 const APInt &BEs,
8915 const Loop *L) {
8916 auto I = ConstantEvolutionLoopExitValue.find(PN);
8917 if (I != ConstantEvolutionLoopExitValue.end())
8918 return I->second;
8920 if (BEs.ugt(MaxBruteForceIterations))
8921 return ConstantEvolutionLoopExitValue[PN] = nullptr; // Not going to evaluate it.
8923 Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
8925 DenseMap<Instruction *, Constant *> CurrentIterVals;
8926 BasicBlock *Header = L->getHeader();
8927 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
8929 BasicBlock *Latch = L->getLoopLatch();
8930 if (!Latch)
8931 return nullptr;
8933 for (PHINode &PHI : Header->phis()) {
8934 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
8935 CurrentIterVals[&PHI] = StartCST;
8936 }
8937 if (!CurrentIterVals.count(PN))
8938 return RetVal = nullptr;
8940 Value *BEValue = PN->getIncomingValueForBlock(Latch);
8942 // Execute the loop symbolically to determine the exit value.
8943 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
8944 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
8946 unsigned NumIterations = BEs.getZExtValue(); // must be in range
8947 unsigned IterationNum = 0;
8948 const DataLayout &DL = getDataLayout();
8949 for (; ; ++IterationNum) {
8950 if (IterationNum == NumIterations)
8951 return RetVal = CurrentIterVals[PN]; // Got exit value!
8953 // Compute the value of the PHIs for the next iteration.
8954 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
8955 DenseMap<Instruction *, Constant *> NextIterVals;
8956 Constant *NextPHI =
8957 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
8958 if (!NextPHI)
8959 return nullptr; // Couldn't evaluate!
8960 NextIterVals[PN] = NextPHI;
8962 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
8964 // Also evaluate the other PHI nodes. However, we don't get to stop if we
8965 // cease to be able to evaluate one of them or if they stop evolving,
8966 // because that doesn't necessarily prevent us from computing PN.
8967 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
8968 for (const auto &I : CurrentIterVals) {
8969 PHINode *PHI = dyn_cast<PHINode>(I.first);
8970 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
8971 PHIsToCompute.emplace_back(PHI, I.second);
8972 }
8973 // We use two distinct loops because EvaluateExpression may invalidate any
8974 // iterators into CurrentIterVals.
8975 for (const auto &I : PHIsToCompute) {
8976 PHINode *PHI = I.first;
8977 Constant *&NextPHI = NextIterVals[PHI];
8978 if (!NextPHI) { // Not already computed.
8979 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
8980 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
8981 }
8982 if (NextPHI != I.second)
8983 StoppedEvolving = false;
8984 }
8986 // If all entries in CurrentIterVals == NextIterVals then we can stop
8987 // iterating, the loop can't continue to change.
8988 if (StoppedEvolving)
8989 return RetVal = CurrentIterVals[PN];
8991 CurrentIterVals.swap(NextIterVals);
8992 }
8993 }
8995 const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
8996 Value *Cond,
8997 bool ExitWhen) {
8998 PHINode *PN = getConstantEvolvingPHI(Cond, L);
8999 if (!PN) return getCouldNotCompute();
9001 // If the loop is canonicalized, the PHI will have exactly two entries.
9002 // That's the only form we support here.
9003 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9005 DenseMap<Instruction *, Constant *> CurrentIterVals;
9006 BasicBlock *Header = L->getHeader();
9007 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9009 BasicBlock *Latch = L->getLoopLatch();
9010 assert(Latch && "Should follow from NumIncomingValues == 2!");
9012 for (PHINode &PHI : Header->phis()) {
9013 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9014 CurrentIterVals[&PHI] = StartCST;
9015 }
9016 if (!CurrentIterVals.count(PN))
9017 return getCouldNotCompute();
9019 // Okay, we find a PHI node that defines the trip count of this loop. Execute
9020 // the loop symbolically to determine when the condition gets a value of
9021 // "ExitWhen".
9022 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
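// For illustration (hypothetical): for a loop like `while (i != 3) ++i;`
// starting at i = 0, the symbolic execution below sees the exit condition
// take the exiting value on iteration 3 and returns the i32 constant 3.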
9023 const DataLayout &DL = getDataLayout();
9024 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9025 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9026 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9028 // Couldn't symbolically evaluate.
9029 if (!CondVal) return getCouldNotCompute();
9031 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9032 ++NumBruteForceTripCountsComputed;
9033 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9034 }
9036 // Update all the PHI nodes for the next iteration.
9037 DenseMap<Instruction *, Constant *> NextIterVals;
9039 // Create a list of which PHIs we need to compute. We want to do this before
9040 // calling EvaluateExpression on them because that may invalidate iterators
9041 // into CurrentIterVals.
9042 SmallVector<PHINode *, 8> PHIsToCompute;
9043 for (const auto &I : CurrentIterVals) {
9044 PHINode *PHI = dyn_cast<PHINode>(I.first);
9045 if (!PHI || PHI->getParent() != Header) continue;
9046 PHIsToCompute.push_back(PHI);
9047 }
9048 for (PHINode *PHI : PHIsToCompute) {
9049 Constant *&NextPHI = NextIterVals[PHI];
9050 if (NextPHI) continue; // Already computed!
9052 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9053 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9054 }
9055 CurrentIterVals.swap(NextIterVals);
9056 }
9058 // Too many iterations were needed to evaluate.
9059 return getCouldNotCompute();
9060 }
9062 const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9063 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9064 ValuesAtScopes[V];
9065 // Check to see if we've folded this expression at this loop before.
9066 for (auto &LS : Values)
9067 if (LS.first == L)
9068 return LS.second ? LS.second : V;
9070 Values.emplace_back(L, nullptr);
9072 // Otherwise compute it.
9073 const SCEV *C = computeSCEVAtScope(V, L);
9074 for (auto &LS : reverse(ValuesAtScopes[V]))
9075 if (LS.first == L) {
9076 LS.second = C;
9077 if (!isa<SCEVConstant>(C))
9078 ValuesAtScopesUsers[C].push_back({L, V});
9079 break;
9080 }
9081 return C;
9082 }
9084 /// This builds up a Constant using the ConstantExpr interface. That way, we
9085 /// will return Constants for objects which aren't represented by a
9086 /// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9087 /// Returns NULL if the SCEV isn't representable as a Constant.
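/// For illustration (hypothetical): a SCEVAddExpr such as (8 + @g), where
/// @g is a global, is built below as a GEP of i8 over @g with byte offset
/// 8, while a purely integral sum is built with ConstantExpr::getAdd.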
9088 static Constant *BuildConstantFromSCEV(const SCEV *V) {
9089 switch (V->getSCEVType()) {
9090 case scCouldNotCompute:
9091 case scAddRecExpr:
9092 return nullptr;
9093 case scConstant:
9094 return cast<SCEVConstant>(V)->getValue();
9095 case scUnknown:
9096 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9097 case scSignExtend: {
9098 const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
9099 if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand()))
9100 return ConstantExpr::getSExt(CastOp, SS->getType());
9101 return nullptr;
9102 }
9103 case scZeroExtend: {
9104 const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
9105 if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand()))
9106 return ConstantExpr::getZExt(CastOp, SZ->getType());
9107 return nullptr;
9108 }
9109 case scPtrToInt: {
9110 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9111 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9112 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9114 return nullptr;
9115 }
9116 case scTruncate: {
9117 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9118 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9119 return ConstantExpr::getTrunc(CastOp, ST->getType());
9120 return nullptr;
9121 }
9122 case scAddExpr: {
9123 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9124 if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) {
9125 if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
9126 unsigned AS = PTy->getAddressSpace();
9127 Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
9128 C = ConstantExpr::getBitCast(C, DestPtrTy);
9129 }
9130 for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) {
9131 Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i));
9132 if (!C2)
9133 return nullptr;
9135 // First pointer!
9136 if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) {
9137 unsigned AS = C2->getType()->getPointerAddressSpace();
9138 std::swap(C, C2);
9139 Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
9140 // The offsets have been converted to bytes. We can add bytes to an
9141 // i8* by GEP with the byte count in the first index.
9142 C = ConstantExpr::getBitCast(C, DestPtrTy);
9143 }
9145 // Don't bother trying to sum two pointers. We probably can't
9146 // statically compute a load that results from it anyway.
9147 if (C2->getType()->isPointerTy())
9148 return nullptr;
9150 if (C->getType()->isPointerTy()) {
9151 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9152 C, C2);
9153 } else {
9154 C = ConstantExpr::getAdd(C, C2);
9155 }
9156 }
9157 return C;
9158 }
9159 return nullptr;
9160 }
9161 case scMulExpr: {
9162 const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
9163 if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) {
9164 // Don't bother with pointers at all.
9165 if (C->getType()->isPointerTy())
9166 return nullptr;
9167 for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) {
9168 Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i));
9169 if (!C2 || C2->getType()->isPointerTy())
9170 return nullptr;
9171 C = ConstantExpr::getMul(C, C2);
9172 }
9173 return C;
9174 }
9175 return nullptr;
9176 }
9177 case scUDivExpr: {
9178 const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V);
9179 if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS()))
9180 if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS()))
9181 if (LHS->getType() == RHS->getType())
9182 return ConstantExpr::getUDiv(LHS, RHS);
9183 return nullptr;
9184 }
9185 case scSMaxExpr:
9186 case scUMaxExpr:
9187 case scSMinExpr:
9188 case scUMinExpr:
9189 case scSequentialUMinExpr:
9190 return nullptr; // TODO: smax, umax, smin, umin, umin_seq.
9191 }
9192 llvm_unreachable("Unknown SCEV kind!");
9193 }
9195 const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
9196 if (isa<SCEVConstant>(V)) return V;
9198 // If this instruction is evolved from a constant-evolving PHI, compute the
9199 // exit value from the loop without using SCEVs.
9200 if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
9201 if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
9202 if (PHINode *PN = dyn_cast<PHINode>(I)) {
9203 const Loop *CurrLoop = this->LI[I->getParent()];
9204 // Looking for loop exit value.
9205 if (CurrLoop && CurrLoop->getParentLoop() == L &&
9206 PN->getParent() == CurrLoop->getHeader()) {
9207 // Okay, there is no closed form solution for the PHI node. Check
9208 // to see if the loop that contains it has a known backedge-taken
9209 // count. If so, we may be able to force computation of the exit
9210 // value.
9211 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
9212 // This trivial case can show up in some degenerate cases where
9213 // the incoming IR has not yet been fully simplified.
9214 if (BackedgeTakenCount->isZero()) {
9215 Value *InitValue = nullptr;
9216 bool MultipleInitValues = false;
9217 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
9218 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
9219 if (!InitValue)
9220 InitValue = PN->getIncomingValue(i);
9221 else if (InitValue != PN->getIncomingValue(i)) {
9222 MultipleInitValues = true;
9223 break;
9224 }
9225 }
9226 }
9227 if (!MultipleInitValues && InitValue)
9228 return getSCEV(InitValue);
9229 }
9230 // Do we have a loop invariant value flowing around the backedge
9231 // for a loop which must execute the backedge?
9232 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
9233 isKnownPositive(BackedgeTakenCount) &&
9234 PN->getNumIncomingValues() == 2) {
9236 unsigned InLoopPred =
9237 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
9238 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
9239 if (CurrLoop->isLoopInvariant(BackedgeVal))
9240 return getSCEV(BackedgeVal);
9241 }
9242 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
9243 // Okay, we know how many times the containing loop executes. If
9244 // this is a constant evolving PHI node, get the final value at
9245 // the specified iteration number.
9246 Constant *RV = getConstantEvolutionLoopExitValue(
9247 PN, BTCC->getAPInt(), CurrLoop);
9248 if (RV) return getSCEV(RV);
9249 }
9250 }
9252 // If there is a single-input Phi, evaluate it at our scope. If we can
9253 // prove that this replacement does not break LCSSA form, use new value.
9254 if (PN->getNumOperands() == 1) {
9255 const SCEV *Input = getSCEV(PN->getOperand(0));
9256 const SCEV *InputAtScope = getSCEVAtScope(Input, L);
9257 // TODO: We can generalize it using LI.replacementPreservesLCSSAForm,
9258 // for the simplest case just support constants.
9259 if (isa<SCEVConstant>(InputAtScope)) return InputAtScope;
9260 }
9261 }
9263 // Okay, this is an expression that we cannot symbolically evaluate
9264 // into a SCEV. Check to see if it's possible to symbolically evaluate
9265 // the arguments into constants, and if so, try to constant propagate the
9266 // result. This is particularly useful for computing loop exit values.
9267 if (CanConstantFold(I)) {
9268 SmallVector<Constant *, 4> Operands;
9269 bool MadeImprovement = false;
9270 for (Value *Op : I->operands()) {
9271 if (Constant *C = dyn_cast<Constant>(Op)) {
9272 Operands.push_back(C);
9273 continue;
9274 }
9276 // If any of the operands is non-constant and if they are
9277 // non-integer and non-pointer, don't even try to analyze them
9278 // with scev techniques.
9279 if (!isSCEVable(Op->getType()))
9280 return V;
9282 const SCEV *OrigV = getSCEV(Op);
9283 const SCEV *OpV = getSCEVAtScope(OrigV, L);
9284 MadeImprovement |= OrigV != OpV;
9286 Constant *C = BuildConstantFromSCEV(OpV);
9287 if (!C) return V;
9288 if (C->getType() != Op->getType())
9289 C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
9290 Op->getType(),
9291 false),
9292 C, Op->getType());
9293 Operands.push_back(C);
9294 }
9296 // Check to see if getSCEVAtScope actually made an improvement.
9297 if (MadeImprovement) {
9298 Constant *C = nullptr;
9299 const DataLayout &DL = getDataLayout();
9300 if (const CmpInst *CI = dyn_cast<CmpInst>(I))
9301 C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
9302 Operands[1], DL, &TLI);
9303 else if (const LoadInst *Load = dyn_cast<LoadInst>(I)) {
9304 if (!Load->isVolatile())
9305 C = ConstantFoldLoadFromConstPtr(Operands[0], Load->getType(),
9306 DL);
9307 } else
9308 C = ConstantFoldInstOperands(I, Operands, DL, &TLI);
9309 if (!C) return V;
9310 return getSCEV(C);
9311 }
9312 }
9313 }
9315 // This is some other type of SCEVUnknown, just return it.
9316 return V;
9317 }
9319 if (isa<SCEVCommutativeExpr>(V) || isa<SCEVSequentialMinMaxExpr>(V)) {
9320 const auto *Comm = cast<SCEVNAryExpr>(V);
9321 // Avoid performing the look-up in the common case where the specified
9322 // expression has no loop-variant portions.
9323 for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
9324 const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
9325 if (OpAtScope != Comm->getOperand(i)) {
9326 // Okay, at least one of these operands is loop variant but might be
9327 // foldable. Build a new instance of the folded commutative expression.
9328 SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(),
9329 Comm->op_begin()+i);
9330 NewOps.push_back(OpAtScope);
9332 for (++i; i != e; ++i) {
9333 OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
9334 NewOps.push_back(OpAtScope);
9335 }
9336 if (isa<SCEVAddExpr>(Comm))
9337 return getAddExpr(NewOps, Comm->getNoWrapFlags());
9338 if (isa<SCEVMulExpr>(Comm))
9339 return getMulExpr(NewOps, Comm->getNoWrapFlags());
9340 if (isa<SCEVMinMaxExpr>(Comm))
9341 return getMinMaxExpr(Comm->getSCEVType(), NewOps);
9342 if (isa<SCEVSequentialMinMaxExpr>(Comm))
9343 return getSequentialMinMaxExpr(Comm->getSCEVType(), NewOps);
9344 llvm_unreachable("Unknown commutative / sequential min/max SCEV type!");
9347 // If we got here, all operands are loop invariant.
9351 if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
9352 const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
9353 const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
9354 if (LHS == Div->getLHS() && RHS == Div->getRHS())
9355 return Div; // must be loop invariant
9356 return getUDivExpr(LHS, RHS);
9357 }
9359 // If this is a loop recurrence for a loop that does not contain L, then we
9360 // are dealing with the final value computed by the loop.
9361 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
9362 // First, attempt to evaluate each operand.
9363 // Avoid performing the look-up in the common case where the specified
9364 // expression has no loop-variant portions.
9365 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
9366 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
9367 if (OpAtScope == AddRec->getOperand(i))
9368 continue;
9370 // Okay, at least one of these operands is loop variant but might be
9371 // foldable. Build a new instance of the folded commutative expression.
9372 SmallVector<const SCEV *, 8> NewOps(AddRec->op_begin(),
9373 AddRec->op_begin()+i);
9374 NewOps.push_back(OpAtScope);
9375 for (++i; i != e; ++i)
9376 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
9378 const SCEV *FoldedRec =
9379 getAddRecExpr(NewOps, AddRec->getLoop(),
9380 AddRec->getNoWrapFlags(SCEV::FlagNW));
9381 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
9382 // The addrec may be folded to a nonrecurrence, for example, if the
9383 // induction variable is multiplied by zero after constant folding. Go
9384 // ahead and return the folded value.
9385 if (!AddRec)
9386 return FoldedRec;
9387 break;
9388 }
9390 // If the scope is outside the addrec's loop, evaluate it by using the
9391 // loop exit value of the addrec.
9392 if (!AddRec->getLoop()->contains(L)) {
9393 // To evaluate this recurrence, we need to know how many times the AddRec
9394 // loop iterates. Compute this now.
9395 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
9396 if (BackedgeTakenCount == getCouldNotCompute()) return AddRec;
9398 // Then, evaluate the AddRec.
9399 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
9400 }
9402 return AddRec;
9403 }
9405 if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
9406 const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
9407 if (Op == Cast->getOperand())
9408 return Cast; // must be loop invariant
9409 return getCastExpr(Cast->getSCEVType(), Op, Cast->getType());
9410 }
9412 llvm_unreachable("Unknown SCEV type!");
9413 }
9415 const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
9416 return getSCEVAtScope(getSCEV(V), L);
9417 }
9419 const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
9420 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
9421 return stripInjectiveFunctions(ZExt->getOperand());
9422 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
9423 return stripInjectiveFunctions(SExt->getOperand());
9424 return S;
9425 }
9427 /// Finds the minimum unsigned root of the following equation:
9429 /// A * X = B (mod N)
9431 /// where N = 2^BW and BW is the common bit width of A and B. The signedness of
9432 /// A and B isn't important.
9434 /// If the equation does not have a solution, SCEVCouldNotCompute is returned.
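/// For illustration (a worked example, not from the original source): with
/// BW = 8, A = 4 and B = 8, we get D = gcd(4, 256) = 4; B is divisible by
/// D, the inverse of A/D = 1 modulo N/D = 64 is I = 1, and the minimum
/// root is I * (B/D) mod (N/D) = 2. Indeed, 4 * 2 == 8 (mod 256).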
9435 static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
9436 ScalarEvolution &SE) {
9437 uint32_t BW = A.getBitWidth();
9438 assert(BW == SE.getTypeSizeInBits(B->getType()));
9439 assert(A != 0 && "A must be non-zero.");
9443 // The gcd of A and N may have only one prime factor: 2. The number of
9444 // trailing zeros in A is its multiplicity
9445 uint32_t Mult2 = A.countTrailingZeros();
9448 // 2. Check if B is divisible by D.
9450 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
9451 // is not less than multiplicity of this prime factor for D.
9452 if (SE.GetMinTrailingZeros(B) < Mult2)
9453 return SE.getCouldNotCompute();
9455 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
9456 // modulo (N / D).
9458 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
9459 // (N / D) in general. The inverse itself always fits into BW bits, though,
9460 // so we immediately truncate it.
9461 APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D
9462 APInt Mod(BW + 1, 0);
9463 Mod.setBit(BW - Mult2); // Mod = N / D
9464 APInt I = AD.multiplicativeInverse(Mod).trunc(BW);
9466 // 4. Compute the minimum unsigned root of the equation:
9467 // I * (B / D) mod (N / D)
9468 // To simplify the computation, we factor out the divide by D:
9469 // (I * B mod N) / D
9470 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
9471 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
9472 }
9474 /// For a given quadratic addrec, generate coefficients of the corresponding
9475 /// quadratic equation, multiplied by a common value to ensure that they are
9476 /// signed-overflow-free.
9477 /// The returned value is a tuple { A, B, C, M, BitWidth }, where
9478 /// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
9479 /// were multiplied by, and BitWidth is the bit width of the original addrec
9480 /// coefficients.
9481 /// This function returns None if the addrec coefficients are not compile-
9482 /// time constants.
9483 static Optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
9484 GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
9485 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
9486 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
9487 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
9488 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
9489 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
9490 << *AddRec << '\n');
9492 // We currently can only solve this if the coefficients are constants.
9493 if (!LC || !MC || !NC) {
9494 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
9498 APInt L = LC->getAPInt();
9499 APInt M = MC->getAPInt();
9500 APInt N = NC->getAPInt();
9501 assert(!N.isZero() && "This is not a quadratic addrec");
9503 unsigned BitWidth = LC->getAPInt().getBitWidth();
9504 unsigned NewWidth = BitWidth + 1;
9505 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
9506 << BitWidth << '\n');
9507 // The sign-extension (as opposed to a zero-extension) here matches the
9508 // extension used in SolveQuadraticEquationWrap (with the same motivation).
9509 N = N.sext(NewWidth);
9510 M = M.sext(NewWidth);
9511 L = L.sext(NewWidth);
9513 // The increments are M, M+N, M+2N, ..., so the accumulated values are
9514 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
9515 // L+M, L+2M+N, L+3M+3N, ...
9516 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
9518 // The equation Acc = 0 is then
9519 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
9520 // In a quadratic form it becomes:
9521 // N n^2 + (2M-N) n + 2L = 0.
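// For illustration (a worked example): for {0,+,1,+,2}, i.e. L=0, M=1,
// N=2, the accumulated value after n iterations is n + n(n-1) = n^2, and
// the equation above becomes 2n^2 + 0*n + 0 = 0 -- exactly Acc scaled by
// the common multiplier 2.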
9523 APInt A = N;
9524 APInt B = 2 * M - A;
9525 APInt C = 2 * L;
9526 APInt T = APInt(NewWidth, 2);
9527 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
9528 << "x + " << C << ", coeff bw: " << NewWidth
9529 << ", multiplied by " << T << '\n');
9530 return std::make_tuple(A, B, C, T, BitWidth);
9531 }
9533 /// Helper function to compare optional APInts:
9534 /// (a) if X and Y both exist, return min(X, Y),
9535 /// (b) if neither X nor Y exist, return None,
9536 /// (c) if exactly one of X and Y exists, return that value.
9537 static Optional<APInt> MinOptional(Optional<APInt> X, Optional<APInt> Y) {
9538 if (X.hasValue() && Y.hasValue()) {
9539 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
9540 APInt XW = X->sextOrSelf(W);
9541 APInt YW = Y->sextOrSelf(W);
9542 return XW.slt(YW) ? *X : *Y;
9543 }
9544 if (!X.hasValue() && !Y.hasValue())
9545 return None;
9546 return X.hasValue() ? *X : *Y;
9547 }
9549 /// Helper function to truncate an optional APInt to a given BitWidth.
9550 /// When solving addrec-related equations, it is preferable to return a value
9551 /// that has the same bit width as the original addrec's coefficients. If the
9552 /// solution fits in the original bit width, truncate it (except for i1).
9553 /// Returning a value of a different bit width may inhibit some optimizations.
9555 /// In general, a solution to a quadratic equation generated from an addrec
9556 /// may require BW+1 bits, where BW is the bit width of the addrec's
9557 /// coefficients. The reason is that the coefficients of the quadratic
9558 /// equation are BW+1 bits wide (to avoid truncation when converting from
9559 /// the addrec to the equation).
9560 static Optional<APInt> TruncIfPossible(Optional<APInt> X, unsigned BitWidth) {
9561 if (!X.hasValue())
9562 return None;
9563 unsigned W = X->getBitWidth();
9564 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
9565 return X->trunc(BitWidth);
9566 return X;
9567 }
9569 /// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
9570 /// iterations. The values L, M, N are assumed to be signed, and they
9571 /// should all have the same bit widths.
9572 /// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
9573 /// where BW is the bit width of the addrec's coefficients.
9574 /// If the calculated value is a BW-bit integer (for BW > 1), it will be
9575 /// returned as such, otherwise the bit width of the returned value may
9576 /// be greater than BW.
9578 /// This function returns None if
9579 /// (a) the addrec coefficients are not constant, or
9580 /// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
9581 /// like x^2 = 5, no integer solutions exist, in other cases an integer
9582 /// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
9583 static Optional<APInt>
9584 SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
9585 APInt A, B, C, M;
9586 unsigned BitWidth;
9587 auto T = GetQuadraticEquation(AddRec);
9588 if (!T.hasValue())
9589 return None;
9591 std::tie(A, B, C, M, BitWidth) = *T;
9592 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
9593 Optional<APInt> X = APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth+1);
9594 if (!X.hasValue())
9595 return None;
9597 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
9598 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
9599 if (!V->isZero())
9600 return None;
9602 return TruncIfPossible(X, BitWidth);
9603 }
9605 /// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
9606 /// iterations. The values M, N are assumed to be signed, and they
9607 /// should all have the same bit widths.
9608 /// Find the least n such that c(n) does not belong to the given range,
9609 /// while c(n-1) does.
9611 /// This function returns None if
9612 /// (a) the addrec coefficients are not constant, or
9613 /// (b) SolveQuadraticEquationWrap was unable to find a solution for the
9614 /// bounds of the range.
9615 static Optional<APInt>
9616 SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
9617 const ConstantRange &Range, ScalarEvolution &SE) {
9618 assert(AddRec->getOperand(0)->isZero() &&
9619 "Starting value of addrec should be 0");
9620 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
9621 << Range << ", addrec " << *AddRec << '\n');
9622 // This case is handled in getNumIterationsInRange. Here we can assume that
9623 // we start in the range.
9624 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
9625 "Addrec's initial value should be in range");
9627 APInt A, B, C, M;
9628 unsigned BitWidth;
9629 auto T = GetQuadraticEquation(AddRec);
9630 if (!T.hasValue())
9631 return None;
9633 // Be careful about the return value: there can be two reasons for not
9634 // returning an actual number. First, if no solutions to the equations
9635 // were found, and second, if the solutions don't leave the given range.
9636 // The first case means that the actual solution is "unknown", the second
9637 // means that it's known, but not valid. If the solution is unknown, we
9638 // cannot make any conclusions.
9639 // Return a pair: the optional solution and a flag indicating if the
9640 // solution was found.
9641 auto SolveForBoundary = [&](APInt Bound) -> std::pair<Optional<APInt>,bool> {
9642 // Solve for signed overflow and unsigned overflow, pick the lower
9643 // solution.
9644 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
9645 << Bound << " (before multiplying by " << M << ")\n");
9646 Bound *= M; // The quadratic equation multiplier.
9648 Optional<APInt> SO = None;
9649 if (BitWidth > 1) {
9650 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
9651 "signed overflow\n");
9652 SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
9653 }
9654 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
9655 "unsigned overflow\n");
9656 Optional<APInt> UO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound,
9657 BitWidth+1);
9659 auto LeavesRange = [&] (const APInt &X) {
9660 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
9661 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
9662 if (Range.contains(V0->getValue()))
9663 return false;
9664 // X should be at least 1, so X-1 is non-negative.
9665 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
9666 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
9667 if (Range.contains(V1->getValue()))
9668 return true;
9670 return false;
9671 };
9672 // If SolveQuadraticEquationWrap returns None, it means that there can
9673 // be a solution, but the function failed to find it. We cannot treat it
9674 // as "no solution".
9675 if (!SO.hasValue() || !UO.hasValue())
9676 return { None, false };
9678 // Check the smaller value first to see if it leaves the range.
9679 // At this point, both SO and UO must have values.
9680 Optional<APInt> Min = MinOptional(SO, UO);
9681 if (LeavesRange(*Min))
9682 return { Min, true };
9683 Optional<APInt> Max = Min == SO ? UO : SO;
9684 if (LeavesRange(*Max))
9685 return { Max, true };
9687 // Solutions were found, but were eliminated, hence the "true".
9688 return { None, true };
9691 std::tie(A, B, C, M, BitWidth) = *T;
9692 // The lower bound is inclusive; subtract 1 to represent the exiting value.
9693 APInt Lower = Range.getLower().sextOrSelf(A.getBitWidth()) - 1;
9694 APInt Upper = Range.getUpper().sextOrSelf(A.getBitWidth());
9695 auto SL = SolveForBoundary(Lower);
9696 auto SU = SolveForBoundary(Upper);
9697 // If any of the solutions was unknown, no meaningful conclusions can
9698 // be made.
9699 if (!SL.second || !SU.second)
9700 return None;
9702 // Claim: The correct solution is not some value between Min and Max.
9704 // Justification: Assuming that Min and Max are different values, one of
9705 // them is when the first signed overflow happens, the other is when the
9706 // first unsigned overflow happens. Crossing the range boundary is only
9707 // possible via an overflow (treating 0 as a special case of it, modeling
9708 // an overflow as crossing k*2^W for some k).
9710 // The interesting case here is when Min was eliminated as an invalid
9711 // solution, but Max was not. The argument is that if there was another
9712 // overflow between Min and Max, it would also have been eliminated if
9713 // it was considered.
9715 // For a given boundary, it is possible to have two overflows of the same
9716 // type (signed/unsigned) without having the other type in between: this
9717 // can happen when the vertex of the parabola is between the iterations
9718 // corresponding to the overflows. This is only possible when the two
9719 // overflows cross k*2^W for the same k. In such a case, if the second one
9720 // left the range (and was the first one to do so), the first overflow
9721 // would have to enter the range, which would mean that either we had left
9722 // the range before or that we started outside of it. Both of these cases
9723 // are contradictions.
9725 // Claim: In the case where SolveForBoundary returns None, the correct
9726 // solution is not some value between the Max for this boundary and the
9727 // Min of the other boundary.
9729 // Justification: Assume that we had such Max_A and Min_B corresponding
9730 // to range boundaries A and B and such that Max_A < Min_B. If there was
9731 // a solution between Max_A and Min_B, it would have to be caused by an
9732 // overflow corresponding to either A or B. It cannot correspond to B,
9733 // since Min_B is the first occurrence of such an overflow. If it
9734 // corresponded to A, it would have to be either a signed or an unsigned
9735 // overflow that is larger than both eliminated overflows for A. But
9736 // between the eliminated overflows and this overflow, the values would
9737 // cover the entire value space, thus crossing the other boundary, which
9738 // is a contradiction.
9740 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
9741 }
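// A small worked example (illustrative): for the chrec {0,+,1,+,2},
// c(n) = n + 2*n*(n-1)/2 = n^2. With Range = [0, 10), c(3) = 9 is still in
// the range while c(4) = 16 is not, so the boundary-crossing iteration
// returned here should be n = 4.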
9743 ScalarEvolution::ExitLimit
9744 ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
9745 bool AllowPredicates) {
9747 // This is only used for loops with an "x != y" exit test. The exit condition
9748 // is now expressed as a single expression, V = x-y. So the exit test is
9749 // effectively V != 0. We know, and take advantage of, the fact that this
9750 // expression is only ever used in a compare-against-zero context.
9752 SmallPtrSet<const SCEVPredicate *, 4> Predicates;
9753 // If the value is a constant
9754 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
9755 // If the value is already zero, the branch will execute zero times.
9756 if (C->getValue()->isZero()) return C;
9757 return getCouldNotCompute(); // Otherwise it will loop infinitely.
9760 const SCEVAddRecExpr *AddRec =
9761 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
9763 if (!AddRec && AllowPredicates)
9764 // Try to make this an AddRec using runtime tests, in the first X
9765 // iterations of this loop, where X is the SCEV expression found by the
9766 // algorithm below.
9767 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
9769 if (!AddRec || AddRec->getLoop() != L)
9770 return getCouldNotCompute();
9772 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
9773 // the quadratic equation to solve it.
9774 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
9775 // We can only use this value if the chrec ends up with an exact zero
9776 // value at this index. When solving for "X*X != 5", for example, we
9777 // should not accept a root of 2.
9778 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
9779 const auto *R = cast<SCEVConstant>(getConstant(S.getValue()));
9780 return ExitLimit(R, R, false, Predicates);
9781 }
9782 return getCouldNotCompute();
9783 }
9785 // Otherwise we can only handle this if it is affine.
9786 if (!AddRec->isAffine())
9787 return getCouldNotCompute();
9789 // If this is an affine expression, the execution count of this branch is
9790 // the minimum unsigned root of the following equation:
9792 // Start + Step*N = 0 (mod 2^BW)
9794 // equivalently:
9796 // Step*N = -Start (mod 2^BW)
9798 // where BW is the common bit width of Start and Step.
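// For example (illustrative): over i8 with Start = 6 and Step = -2, we solve
// -2*N == -6 (mod 256), i.e. 2*N == 6 (mod 256), whose minimum unsigned
// solution is N = 3 -- the IV takes the values 6, 4, 2, 0 and hits zero
// after three backedges.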
9800 // Get the initial value for the loop.
9801 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
9802 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
9804 // For now we handle only constant steps.
9806 // TODO: Handle a nonconstant Step given AddRec<NUW>. If the
9807 // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
9808 // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step.
9809 // We have not yet seen any such cases.
9810 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
9811 if (!StepC || StepC->getValue()->isZero())
9812 return getCouldNotCompute();
9814 // For positive steps (counting up until unsigned overflow):
9815 // N = -Start/Step (as unsigned)
9816 // For negative steps (counting down to zero):
9817 // N = Start/-Step (as unsigned)
9818 // First compute the unsigned distance from zero in the direction of Step.
9819 bool CountDown = StepC->getAPInt().isNegative();
9820 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
9822 // Handle unitary steps, which cannot wrap around.
9823 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
9824 // N = Distance (as unsigned)
9825 if (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne()) {
9826 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
9827 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
9829 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
9830 // we end up with a loop whose backedge-taken count is n - 1. Detect this
9831 // case, and see if we can improve the bound.
9833 // Explicitly handling this here is necessary because getUnsignedRange
9834 // isn't context-sensitive; it doesn't know that we only care about the
9835 // range inside the loop.
9836 const SCEV *Zero = getZero(Distance->getType());
9837 const SCEV *One = getOne(Distance->getType());
9838 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
9839 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
9840 // If Distance + 1 doesn't overflow, we can compute the maximum distance
9841 // as "unsigned_max(Distance + 1) - 1".
9842 ConstantRange CR = getUnsignedRange(DistancePlusOne);
9843 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
9844 }
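// Illustrative instance (assuming i8 and a rotation guard `n != 0`): with
// Distance = n - 1 we get DistancePlusOne = n, the guard proves n != 0, and
// unsigned_max(n) - 1 = 254 then bounds MaxBECount, instead of the generic 255.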
9845 return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates);
9848 // If the condition controls loop exit (the loop exits only if the expression
9849 // is true) and the addition does not wrap, we can use an unsigned divide to
9850 // compute the backedge count. In this case, the step may not divide the
9851 // distance, but we don't care because if the condition is "missed" the loop
9852 // will have undefined behavior due to wrapping.
9853 if (ControlsExit && AddRec->hasNoSelfWrap() &&
9854 loopHasNoAbnormalExits(AddRec->getLoop())) {
9855 const SCEV *Exact =
9856 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
9857 const SCEV *Max = getCouldNotCompute();
9858 if (Exact != getCouldNotCompute()) {
9859 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
9860 Max = getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
9861 }
9862 return ExitLimit(Exact, Max, false, Predicates);
9863 }
9865 // Solve the general equation.
9866 const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(),
9867 getNegativeSCEV(Start), *this);
9869 const SCEV *M = getCouldNotCompute();
9870 if (E != getCouldNotCompute()) {
9871 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
9872 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
9873 }
9874 return ExitLimit(E, M, false, Predicates);
9875 }
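// End-to-end sketch (illustrative): for a loop `for (i8 i = 0; i != 20;
// i += 4)`, the exit expression is V = i - 20 = {-20,+,4}. Solving
// 4*N == 20 (mod 256) gives the minimum unsigned solution N = 5, i.e. five
// backedges are taken before i reaches 20.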
9877 ScalarEvolution::ExitLimit
9878 ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
9879 // Loops that look like: while (X == 0) are very strange indeed. We don't
9880 // handle them yet except for the trivial case. This could be expanded in the
9881 // future as needed.
9883 // If the value is a constant, check to see if it is known to be non-zero
9884 // already. If so, the backedge will execute zero times.
9885 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
9886 if (!C->getValue()->isZero())
9887 return getZero(C->getType());
9888 return getCouldNotCompute(); // Otherwise it will loop infinitely.
9891 // We could implement others, but I really doubt anyone writes loops like
9892 // this, and if they did, they would already be constant folded.
9893 return getCouldNotCompute();
9896 std::pair<const BasicBlock *, const BasicBlock *>
9897 ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
9898 const {
9899 // If the block has a unique predecessor, then there is no path from the
9900 // predecessor to the block that does not go through the direct edge
9901 // from the predecessor to the block.
9902 if (const BasicBlock *Pred = BB->getSinglePredecessor())
9903 return {Pred, BB};
9905 // A loop's header is defined to be a block that dominates the loop.
9906 // If the header has a unique predecessor outside the loop, it must be
9907 // a block that has exactly one successor that can reach the loop.
9908 if (const Loop *L = LI.getLoopFor(BB))
9909 return {L->getLoopPredecessor(), L->getHeader()};
9911 return {nullptr, nullptr};
9914 /// SCEV structural equivalence is usually sufficient for testing whether two
9915 /// expressions are equal; however, for the purposes of looking for a condition
9916 /// guarding a loop, it can be useful to be a little more general, since a
9917 /// front-end may have replicated the controlling expression.
9918 static bool HasSameValue(const SCEV *A, const SCEV *B) {
9919 // Quick check to see if they are the same SCEV.
9920 if (A == B) return true;
9922 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
9923 // Not all instructions that are "identical" compute the same value. For
9924 // instance, two distinct alloca instructions allocating the same type are
9925 // identical and do not read memory, but they compute distinct values.
9926 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
9927 };
9929 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
9930 // two different instructions with the same value. Check for this case.
9931 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
9932 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
9933 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
9934 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
9935 if (ComputesEqualValues(AI, BI))
9936 return true;
9938 // Otherwise assume they may have a different value.
9939 return false;
9940 }
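// Illustrative case (hypothetical IR): two syntactically identical
// `%s = ashr i32 %a, 2` instructions in different blocks are each typically
// modeled as a SCEVUnknown, yet isIdenticalTo lets us conclude they compute
// the same value; two identical `alloca i32` instructions do not, which is
// why the check above is restricted to binary operators and GEPs.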
9942 bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
9943 const SCEV *&LHS, const SCEV *&RHS,
9944 unsigned Depth,
9945 bool ControllingFiniteLoop) {
9946 bool Changed = false;
9947 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
9948 // '0 != 0'.
9949 auto TrivialCase = [&](bool TriviallyTrue) {
9950 LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
9951 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
9952 return true;
9953 };
9954 // If we hit the max recursion limit, bail out.
9955 if (Depth >= 3)
9956 return Changed;
9958 // Canonicalize a constant to the right side.
9959 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
9960 // Check for both operands constant.
9961 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
9962 if (ConstantExpr::getICmp(Pred,
9963 LHSC->getValue(),
9964 RHSC->getValue())->isNullValue())
9965 return TrivialCase(false);
9967 return TrivialCase(true);
9969 // Otherwise swap the operands to put the constant on the right.
9970 std::swap(LHS, RHS);
9971 Pred = ICmpInst::getSwappedPredicate(Pred);
9975 // If we're comparing an addrec with a value which is loop-invariant in the
9976 // addrec's loop, put the addrec on the left. Also make a dominance check,
9977 // as both operands could be addrecs loop-invariant in each other's loop.
9978 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
9979 const Loop *L = AR->getLoop();
9980 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
9981 std::swap(LHS, RHS);
9982 Pred = ICmpInst::getSwappedPredicate(Pred);
9987 // If there's a constant operand, canonicalize comparisons with boundary
9988 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
9989 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
9990 const APInt &RA = RC->getAPInt();
9992 bool SimplifiedByConstantRange = false;
9994 if (!ICmpInst::isEquality(Pred)) {
9995 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
9996 if (ExactCR.isFullSet())
9997 return TrivialCase(true);
9998 else if (ExactCR.isEmptySet())
9999 return TrivialCase(false);
10001 APInt NewRHS;
10002 CmpInst::Predicate NewPred;
10003 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10004 ICmpInst::isEquality(NewPred)) {
10005 // We were able to convert an inequality to an equality.
10006 Pred = NewPred;
10007 RHS = getConstant(NewRHS);
10008 Changed = SimplifiedByConstantRange = true;
10012 if (!SimplifiedByConstantRange) {
10013 switch (Pred) {
10014 default:
10015 break;
10016 case ICmpInst::ICMP_EQ:
10017 case ICmpInst::ICMP_NE:
10018 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10019 if (!RA)
10020 if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS))
10021 if (const SCEVMulExpr *ME =
10022 dyn_cast<SCEVMulExpr>(AE->getOperand(0)))
10023 if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 &&
10024 ME->getOperand(0)->isAllOnesValue()) {
10025 RHS = AE->getOperand(1);
10026 LHS = ME->getOperand(1);
10027 Changed = true;
10028 }
10029 break;
10032 // The "Should have been caught earlier!" messages refer to the fact
10033 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10034 // should have fired on the corresponding cases, and canonicalized the
10035 // check to a trivial case.
10037 case ICmpInst::ICMP_UGE:
10038 assert(!RA.isMinValue() && "Should have been caught earlier!");
10039 Pred = ICmpInst::ICMP_UGT;
10040 RHS = getConstant(RA - 1);
10041 Changed = true;
10042 break;
10043 case ICmpInst::ICMP_ULE:
10044 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10045 Pred = ICmpInst::ICMP_ULT;
10046 RHS = getConstant(RA + 1);
10047 Changed = true;
10048 break;
10049 case ICmpInst::ICMP_SGE:
10050 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10051 Pred = ICmpInst::ICMP_SGT;
10052 RHS = getConstant(RA - 1);
10053 Changed = true;
10054 break;
10055 case ICmpInst::ICMP_SLE:
10056 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10057 Pred = ICmpInst::ICMP_SLT;
10058 RHS = getConstant(RA + 1);
10059 Changed = true;
10060 break;
10061 }
10062 }
10063 }
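// For instance, `x u>= 5` is canonicalized to `x u> 4` above, and `x s<= 7`
// to `x s< 8`; the boundary constants themselves (e.g. `x u>= 0`) were
// already folded to trivial cases by the exact-range check earlier.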
10065 // Check for obvious equality.
10066 if (HasSameValue(LHS, RHS)) {
10067 if (ICmpInst::isTrueWhenEqual(Pred))
10068 return TrivialCase(true);
10069 if (ICmpInst::isFalseWhenEqual(Pred))
10070 return TrivialCase(false);
10073 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10074 // adding or subtracting 1 from one of the operands. This can be done for
10075 // one of two reasons:
10076 // 1) The range of the RHS does not include the (signed/unsigned) boundaries
10077 // 2) The loop is finite, with this comparison controlling the exit. Since the
10078 // loop is finite, the bound cannot include the corresponding boundary
10079 // (otherwise it would loop forever).
10080 switch (Pred) {
10081 case ICmpInst::ICMP_SLE:
10082 if (ControllingFiniteLoop || !getSignedRangeMax(RHS).isMaxSignedValue()) {
10083 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10084 SCEV::FlagNSW);
10085 Pred = ICmpInst::ICMP_SLT;
10086 Changed = true;
10087 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10088 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10089 SCEV::FlagNSW);
10090 Pred = ICmpInst::ICMP_SLT;
10091 Changed = true;
10092 }
10093 break;
10094 case ICmpInst::ICMP_SGE:
10095 if (ControllingFiniteLoop || !getSignedRangeMin(RHS).isMinSignedValue()) {
10096 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10097 SCEV::FlagNSW);
10098 Pred = ICmpInst::ICMP_SGT;
10099 Changed = true;
10100 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10101 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10102 SCEV::FlagNSW);
10103 Pred = ICmpInst::ICMP_SGT;
10104 Changed = true;
10105 }
10106 break;
10107 case ICmpInst::ICMP_ULE:
10108 if (ControllingFiniteLoop || !getUnsignedRangeMax(RHS).isMaxValue()) {
10109 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10110 SCEV::FlagNUW);
10111 Pred = ICmpInst::ICMP_ULT;
10112 Changed = true;
10113 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10114 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10115 Pred = ICmpInst::ICMP_ULT;
10116 Changed = true;
10117 }
10118 break;
10119 case ICmpInst::ICMP_UGE:
10120 if (ControllingFiniteLoop || !getUnsignedRangeMin(RHS).isMinValue()) {
10121 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10122 Pred = ICmpInst::ICMP_UGT;
10123 Changed = true;
10124 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10125 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10126 SCEV::FlagNUW);
10127 Pred = ICmpInst::ICMP_UGT;
10128 Changed = true;
10129 }
10130 break;
10131 default:
10132 break;
10133 }
10135 // TODO: More simplifications are possible here.
10137 // Recursively simplify until we either hit a recursion limit or nothing
10138 // changes.
10140 return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1,
10141 ControllingFiniteLoop);
10146 bool ScalarEvolution::isKnownNegative(const SCEV *S) {
10147 return getSignedRangeMax(S).isNegative();
10150 bool ScalarEvolution::isKnownPositive(const SCEV *S) {
10151 return getSignedRangeMin(S).isStrictlyPositive();
10154 bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
10155 return !getSignedRangeMin(S).isNegative();
10158 bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
10159 return !getSignedRangeMax(S).isStrictlyPositive();
10162 bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
10163 return getUnsignedRangeMin(S) != 0;
10166 std::pair<const SCEV *, const SCEV *>
10167 ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
10168 // Compute SCEV on entry of loop L.
10169 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
10170 if (Start == getCouldNotCompute())
10171 return { Start, Start };
10172 // Compute post increment SCEV for loop L.
10173 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
10174 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
10175 return { Start, PostInc };
10176 }
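// Example (illustrative): for S = {x,+,1}<L>, this returns the pair
// { x, {x+1,+,1}<L> }, i.e. the value on entry to L and the value after the
// first increment.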
10178 bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred,
10179 const SCEV *LHS, const SCEV *RHS) {
10180 // First collect all loops.
10181 SmallPtrSet<const Loop *, 8> LoopsUsed;
10182 getUsedLoops(LHS, LoopsUsed);
10183 getUsedLoops(RHS, LoopsUsed);
10185 if (LoopsUsed.empty())
10186 return false;
10188 // Domination relationship must be a linear order on collected loops.
10190 for (auto *L1 : LoopsUsed)
10191 for (auto *L2 : LoopsUsed)
10192 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
10193 DT.dominates(L2->getHeader(), L1->getHeader())) &&
10194 "Domination relationship is not a linear order");
10197 const Loop *MDL =
10198 *std::max_element(LoopsUsed.begin(), LoopsUsed.end(),
10199 [&](const Loop *L1, const Loop *L2) {
10200 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
10201 });
10203 // Get init and post increment value for LHS.
10204 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
10205 // If LHS contains an unknown non-invariant SCEV then bail out.
10206 if (SplitLHS.first == getCouldNotCompute())
10207 return false;
10208 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
10209 // Get init and post increment value for RHS.
10210 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
10211 // If RHS contains an unknown non-invariant SCEV then bail out.
10212 if (SplitRHS.first == getCouldNotCompute())
10213 return false;
10214 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
10215 // It is possible that init SCEV contains an invariant load but it does
10216 // not dominate MDL and is not available at MDL loop entry, so we should
10217 // check it here.
10218 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
10219 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
10220 return false;
10222 // The backedge guard check seems to be faster than the entry check, so
10223 // checking it first can short-circuit and speed up the whole query.
10224 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
10225 SplitRHS.second) &&
10226 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
10229 bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
10230 const SCEV *LHS, const SCEV *RHS) {
10231 // Canonicalize the inputs first.
10232 (void)SimplifyICmpOperands(Pred, LHS, RHS);
10234 if (isKnownViaInduction(Pred, LHS, RHS))
10235 return true;
10237 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
10238 return true;
10240 // Otherwise see what can be done with some simple reasoning.
10241 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
10244 Optional<bool> ScalarEvolution::evaluatePredicate(ICmpInst::Predicate Pred,
10245 const SCEV *LHS,
10246 const SCEV *RHS) {
10247 if (isKnownPredicate(Pred, LHS, RHS))
10248 return true;
10249 else if (isKnownPredicate(ICmpInst::getInversePredicate(Pred), LHS, RHS))
10250 return false;
10252 return None;
10253 }
10254 bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred,
10255 const SCEV *LHS, const SCEV *RHS,
10256 const Instruction *CtxI) {
10257 // TODO: Analyze guards and assumes from Context's block.
10258 return isKnownPredicate(Pred, LHS, RHS) ||
10259 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
10262 Optional<bool> ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred,
10263 const SCEV *LHS,
10264 const SCEV *RHS,
10265 const Instruction *CtxI) {
10266 Optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
10267 if (KnownWithoutContext)
10268 return KnownWithoutContext;
10270 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
10271 return true;
10272 else if (isBasicBlockEntryGuardedByCond(CtxI->getParent(),
10273 ICmpInst::getInversePredicate(Pred),
10274 LHS, RHS))
10275 return false;
10277 return None;
10278 }
10279 bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
10280 const SCEVAddRecExpr *LHS,
10281 const SCEV *RHS) {
10282 const Loop *L = LHS->getLoop();
10283 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
10284 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
10287 Optional<ScalarEvolution::MonotonicPredicateType>
10288 ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
10289 ICmpInst::Predicate Pred) {
10290 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
10293 // Verify an invariant: inverting the predicate should turn a monotonically
10294 // increasing change to a monotonically decreasing one, and vice versa.
10295 if (Result) {
10296 auto ResultSwapped =
10297 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
10299 assert(ResultSwapped.hasValue() && "should be able to analyze both!");
10300 assert(ResultSwapped.getValue() != Result.getValue() &&
10301 "monotonicity should flip as we flip the predicate");
10308 Optional<ScalarEvolution::MonotonicPredicateType>
10309 ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
10310 ICmpInst::Predicate Pred) {
10311 // A zero step value for LHS means the induction variable is essentially a
10312 // loop invariant value. We don't really depend on the predicate actually
10313 // flipping from false to true (for increasing predicates, and the other way
10314 // around for decreasing predicates), all we care about is that *if* the
10315 // predicate changes then it only changes from false to true.
10317 // A zero step value in itself is not very useful, but there may be places
10318 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
10319 // as general as possible.
10321 // Only handle LE/LT/GE/GT predicates.
10322 if (!ICmpInst::isRelational(Pred))
10323 return None;
10325 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
10326 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
10327 "Should be greater or less!");
10329 // Check that AR does not wrap.
10330 if (ICmpInst::isUnsigned(Pred)) {
10331 if (!LHS->hasNoUnsignedWrap())
10332 return None;
10333 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
10334 }
10335 assert(ICmpInst::isSigned(Pred) &&
10336 "Relational predicate is either signed or unsigned!");
10337 if (!LHS->hasNoSignedWrap())
10338 return None;
10340 const SCEV *Step = LHS->getStepRecurrence(*this);
10342 if (isKnownNonNegative(Step))
10343 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
10345 if (isKnownNonPositive(Step))
10346 return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
10348 return None;
10349 }
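// For example: with LHS = {0,+,1}<nuw> and Pred = ICMP_ULT, the predicate
// `LHS u< RHS` can only change from true to false as the loop advances, so
// the result is MonotonicallyDecreasing; for ICMP_UGT it would be
// MonotonicallyIncreasing.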
10352 Optional<ScalarEvolution::LoopInvariantPredicate>
10353 ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred,
10354 const SCEV *LHS, const SCEV *RHS,
10355 const Loop *L) {
10357 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
10358 if (!isLoopInvariant(RHS, L)) {
10359 if (!isLoopInvariant(LHS, L))
10360 return None;
10362 std::swap(LHS, RHS);
10363 Pred = ICmpInst::getSwappedPredicate(Pred);
10366 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
10367 if (!ArLHS || ArLHS->getLoop() != L)
10368 return None;
10370 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
10371 if (!MonotonicType)
10372 return None;
10373 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
10374 // true as the loop iterates, and the backedge is control dependent on
10375 // "ArLHS `Pred` RHS" == true then we can reason as follows:
10377 // * if the predicate was false in the first iteration then the predicate
10378 // is never evaluated again, since the loop exits without taking the
10379 // backedge.
10380 // * if the predicate was true in the first iteration then it will
10381 // continue to be true for all future iterations since it is
10382 // monotonically increasing.
10384 // For both the above possibilities, we can replace the loop varying
10385 // predicate with its value on the first iteration of the loop (which is
10386 // loop invariant).
10388 // A similar reasoning applies for a monotonically decreasing predicate, by
10389 // replacing true with false and false with true in the above two bullets.
10390 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
10391 auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
10393 if (!isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
10394 return None;
10396 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(), RHS);
10397 }
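// A typical instance (illustrative): if i = {0,+,1}<nsw> and the backedge is
// guarded by `i s< n`, the loop-varying test `i s< n` can be replaced by the
// invariant `0 s< n`, its value on the first iteration.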
10399 Optional<ScalarEvolution::LoopInvariantPredicate>
10400 ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
10401 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
10402 const Instruction *CtxI, const SCEV *MaxIter) {
10403 // Try to prove the following set of facts:
10404 // - The predicate is monotonic in the iteration space.
10405 // - If the check does not fail on the 1st iteration:
10406 // - No overflow will happen during first MaxIter iterations;
10407 // - It will not fail on the MaxIter'th iteration.
10408 // If the check does fail on the 1st iteration, we leave the loop and no
10409 // other checks matter.
10411 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
10412 if (!isLoopInvariant(RHS, L)) {
10413 if (!isLoopInvariant(LHS, L))
10414 return None;
10416 std::swap(LHS, RHS);
10417 Pred = ICmpInst::getSwappedPredicate(Pred);
10420 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
10421 if (!AR || AR->getLoop() != L)
10422 return None;
10424 // The predicate must be relational (i.e. <, <=, >=, >).
10425 if (!ICmpInst::isRelational(Pred))
10426 return None;
10428 // TODO: Support steps other than +/- 1.
10429 const SCEV *Step = AR->getStepRecurrence(*this);
10430 auto *One = getOne(Step->getType());
10431 auto *MinusOne = getNegativeSCEV(One);
10432 if (Step != One && Step != MinusOne)
10433 return None;
10435 // Type mismatch here means that MaxIter is potentially larger than max
10436 // unsigned value in start type, which means we cannot prove no-wrap for the
10437 // indvar.
10438 if (AR->getType() != MaxIter->getType())
10439 return None;
10441 // Value of IV on suggested last iteration.
10442 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
10443 // Does it still meet the requirement?
10444 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
10445 return None;
10446 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
10447 // not exceed max unsigned value of this type), this effectively proves
10448 // that there is no wrap during the iteration. To prove that there is no
10449 // signed/unsigned wrap, we need to check that
10450 // Start <= Last for step = 1 or Start >= Last for step = -1.
10451 ICmpInst::Predicate NoOverflowPred =
10452 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
10453 if (Step == MinusOne)
10454 NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred);
10455 const SCEV *Start = AR->getStart();
10456 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
10457 return None;
10459 // Everything is fine.
10460 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
10463 bool ScalarEvolution::isKnownPredicateViaConstantRanges(
10464 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
10465 if (HasSameValue(LHS, RHS))
10466 return ICmpInst::isTrueWhenEqual(Pred);
10468 // This code is split out from isKnownPredicate because it is called from
10469 // within isLoopEntryGuardedByCond.
10471 auto CheckRanges = [&](const ConstantRange &RangeLHS,
10472 const ConstantRange &RangeRHS) {
10473 return RangeLHS.icmp(Pred, RangeRHS);
10474 };
10476 // The check at the top of the function catches the case where the values are
10477 // known to be equal.
10478 if (Pred == CmpInst::ICMP_EQ)
10479 return false;
10481 if (Pred == CmpInst::ICMP_NE) {
10482 if (CheckRanges(getSignedRange(LHS), getSignedRange(RHS)) ||
10483 CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS)))
10484 return true;
10485 auto *Diff = getMinusSCEV(LHS, RHS);
10486 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
10489 if (CmpInst::isSigned(Pred))
10490 return CheckRanges(getSignedRange(LHS), getSignedRange(RHS));
10492 return CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS));
10495 bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
10496 const SCEV *LHS,
10497 const SCEV *RHS) {
10498 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
10499 // C1 and C2 are constant integers. If either X or Y is not an add expression,
10500 // consider it as X + 0 or Y + 0 respectively. C1 and C2 are returned via
10501 // OutC1 and OutC2.
10502 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
10503 APInt &OutC1, APInt &OutC2,
10504 SCEV::NoWrapFlags ExpectedFlags) {
10505 const SCEV *XNonConstOp, *XConstOp;
10506 const SCEV *YNonConstOp, *YConstOp;
10507 SCEV::NoWrapFlags XFlagsPresent;
10508 SCEV::NoWrapFlags YFlagsPresent;
10510 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
10511 XConstOp = getZero(X->getType());
10512 XNonConstOp = X;
10513 XFlagsPresent = ExpectedFlags;
10514 }
10515 if (!isa<SCEVConstant>(XConstOp) ||
10516 (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
10517 return false;
10519 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
10520 YConstOp = getZero(Y->getType());
10521 YNonConstOp = Y;
10522 YFlagsPresent = ExpectedFlags;
10523 }
10525 if (!isa<SCEVConstant>(YConstOp) ||
10526 (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
10527 return false;
10529 if (YNonConstOp != XNonConstOp)
10530 return false;
10532 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
10533 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
10534 return true;
10535 };
10537 APInt C1, C2;
10539 switch (Pred) {
10540 default:
10541 break;
10545 case ICmpInst::ICMP_SGE:
10546 std::swap(LHS, RHS);
10548 case ICmpInst::ICMP_SLE:
10549 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
10550 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
10551 return true;
10553 break;
10555 case ICmpInst::ICMP_SGT:
10556 std::swap(LHS, RHS);
10558 case ICmpInst::ICMP_SLT:
10559 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
10560 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
10561 return true;
10563 break;
10565 case ICmpInst::ICMP_UGE:
10566 std::swap(LHS, RHS);
10568 case ICmpInst::ICMP_ULE:
10569 // (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2.
10570 if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ule(C2))
10571 return true;
10573 break;
10575 case ICmpInst::ICMP_UGT:
10576 std::swap(LHS, RHS);
10578 case ICmpInst::ICMP_ULT:
10579 // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
10580 if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ult(C2))
10581 return true;
10583 break;
10584 }
10586 return false;
10587 }
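// Worked instance (illustrative): to decide `(X + 1)<nsw> s< (X + 3)<nsw>`,
// both sides match A + C with A = X, C1 = 1 and C2 = 3, and C1 s< C2, so the
// predicate is known to hold regardless of X (nsw rules out wrapping past
// INT_MAX).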
10588 bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
10589 const SCEV *LHS,
10590 const SCEV *RHS) {
10591 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
10592 return false;
10594 // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting
10595 // on the stack can result in exponential time complexity.
10596 SaveAndRestore<bool> Restore(ProvingSplitPredicate, true);
10598 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
10600 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
10601 // isKnownPredicate. isKnownPredicate is more powerful, but also more
10602 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
10603 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
10604 // use isKnownPredicate later if needed.
10605 return isKnownNonNegative(RHS) &&
10606 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
10607 isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
10610 bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB,
10611 ICmpInst::Predicate Pred,
10612 const SCEV *LHS, const SCEV *RHS) {
10613 // No need to even try if we know the module has no guards.
10614 if (!HasGuards)
10615 return false;
10617 return any_of(*BB, [&](const Instruction &I) {
10618 using namespace llvm::PatternMatch;
10620 Value *Condition;
10621 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
10622 m_Value(Condition))) &&
10623 isImpliedCond(Pred, LHS, RHS, Condition, false);
10627 /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
10628 /// protected by a conditional between LHS and RHS. This is used to
10629 /// eliminate casts.
10631 ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
10632 ICmpInst::Predicate Pred,
10633 const SCEV *LHS, const SCEV *RHS) {
10634 // Interpret a null as meaning no loop, where there is obviously no guard
10635 // (interprocedural conditions notwithstanding).
10636 if (!L) return true;
10639 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
10640 "This cannot be done on broken IR!");
10643 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
10644 return true;
10646 BasicBlock *Latch = L->getLoopLatch();
10647 if (!Latch)
10648 return false;
10650 BranchInst *LoopContinuePredicate =
10651 dyn_cast<BranchInst>(Latch->getTerminator());
10652 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
10653 isImpliedCond(Pred, LHS, RHS,
10654 LoopContinuePredicate->getCondition(),
10655 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
10656 return true;
10658 // We don't want more than one activation of the following loops on the stack
10659 // -- that can lead to O(n!) time complexity.
10660 if (WalkingBEDominatingConds)
10661 return false;
10663 SaveAndRestore<bool> ClearOnExit(WalkingBEDominatingConds, true);
10665 // See if we can exploit a trip count to prove the predicate.
10666 const auto &BETakenInfo = getBackedgeTakenInfo(L);
10667 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
10668 if (LatchBECount != getCouldNotCompute()) {
10669 // We know that Latch branches back to the loop header exactly
10670 // LatchBECount times. This means the backedge condition at Latch is
10671 // equivalent to "{0,+,1} u< LatchBECount".
10672 Type *Ty = LatchBECount->getType();
10673 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
10674 const SCEV *LoopCounter =
10675 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
10676 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
10677 LatchBECount))
10678 return true;
10679 }
10681 // Check conditions due to any @llvm.assume intrinsics.
10682 for (auto &AssumeVH : AC.assumptions()) {
10683 if (!AssumeVH)
10684 continue;
10685 auto *CI = cast<CallInst>(AssumeVH);
10686 if (!DT.dominates(CI, Latch->getTerminator()))
10687 continue;
10689 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
10690 return true;
10691 }
10693 // If the loop is not reachable from the entry block, we risk running into an
10694 // infinite loop as we walk up into the dom tree. These loops do not matter
10695 // anyway, so we just return a conservative answer when we see them.
10696 if (!DT.isReachableFromEntry(L->getHeader()))
10697 return false;
10699 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
10700 return true;
10702 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
10703 DTN != HeaderDTN; DTN = DTN->getIDom()) {
10704 assert(DTN && "should reach the loop header before reaching the root!");
10706 BasicBlock *BB = DTN->getBlock();
10707 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
10708 return true;
10710 BasicBlock *PBB = BB->getSinglePredecessor();
10711 if (!PBB)
10712 continue;
10714 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
10715 if (!ContinuePredicate || !ContinuePredicate->isConditional())
10716 continue;
10718 Value *Condition = ContinuePredicate->getCondition();
10720 // If we have an edge `E` within the loop body that dominates the only
10721 // latch, the condition guarding `E` also guards the backedge. This
10722 // reasoning works only for loops with a single latch.
10724 BasicBlockEdge DominatingEdge(PBB, BB);
10725 if (DominatingEdge.isSingleEdge()) {
10726 // We're constructively (and conservatively) enumerating edges within the
10727 // loop body that dominate the latch. The dominator tree better agree
10728 // with us on this:
10729 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
10731 if (isImpliedCond(Pred, LHS, RHS, Condition,
10732 BB != ContinuePredicate->getSuccessor(0)))
10733 return true;
10734 }
10735 }
10737 return false;
10738 }
10740 bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
10741 ICmpInst::Predicate Pred,
10742 const SCEV *LHS,
10743 const SCEV *RHS) {
10745 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
10746 "This cannot be done on broken IR!");
10748 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
10749 // the facts (a >= b && a != b) separately. A typical situation is when the
10750 // non-strict comparison is known from ranges and non-equality is known from
10751 // dominating predicates. If we are proving strict comparison, we always try
10752 // to prove non-equality and non-strict comparison separately.
10753 auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
10754 const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
10755 bool ProvedNonStrictComparison = false;
10756 bool ProvedNonEquality = false;
10758 auto SplitAndProve =
10759 [&](std::function<bool(ICmpInst::Predicate)> Fn) -> bool {
10760 if (!ProvedNonStrictComparison)
10761 ProvedNonStrictComparison = Fn(NonStrictPredicate);
10762 if (!ProvedNonEquality)
10763 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
10764 if (ProvedNonStrictComparison && ProvedNonEquality)
10765 return true;
10766 return false;
10767 };
10769 if (ProvingStrictComparison) {
10770 auto ProofFn = [&](ICmpInst::Predicate P) {
10771 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
10772 };
10773 if (SplitAndProve(ProofFn))
10774 return true;
10775 }
10777 // Try to prove (Pred, LHS, RHS) using isImpliedViaGuard.
10778 auto ProveViaGuard = [&](const BasicBlock *Block) {
10779 if (isImpliedViaGuard(Block, Pred, LHS, RHS))
10780 return true;
10781 if (ProvingStrictComparison) {
10782 auto ProofFn = [&](ICmpInst::Predicate P) {
10783 return isImpliedViaGuard(Block, P, LHS, RHS);
10784 };
10785 if (SplitAndProve(ProofFn))
10786 return true;
10787 }
10788 return false;
10789 };
10791 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
10792 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
10793 const Instruction *CtxI = &BB->front();
10794 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
10795 return true;
10796 if (ProvingStrictComparison) {
10797 auto ProofFn = [&](ICmpInst::Predicate P) {
10798 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
10799 };
10800 if (SplitAndProve(ProofFn))
10801 return true;
10802 }
10803 return false;
10804 };
10806 // Starting at the block's predecessor, climb up the predecessor chain, as
10807 // long as we can find predecessors that have unique successors leading to
10808 // the original block.
10809 const Loop *ContainingLoop = LI.getLoopFor(BB);
10810 const BasicBlock *PredBB;
10811 if (ContainingLoop && ContainingLoop->getHeader() == BB)
10812 PredBB = ContainingLoop->getLoopPredecessor();
10814 PredBB = BB->getSinglePredecessor();
10815 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
10816 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
10817 if (ProveViaGuard(Pair.first))
10818 return true;
10820 const BranchInst *LoopEntryPredicate =
10821 dyn_cast<BranchInst>(Pair.first->getTerminator());
10822 if (!LoopEntryPredicate ||
10823 LoopEntryPredicate->isUnconditional())
10824 continue;
10826 if (ProveViaCond(LoopEntryPredicate->getCondition(),
10827 LoopEntryPredicate->getSuccessor(0) != Pair.second))
10828 return true;
10829 }
10831 // Check conditions due to any @llvm.assume intrinsics.
10832 for (auto &AssumeVH : AC.assumptions()) {
10833 if (!AssumeVH)
10834 continue;
10835 auto *CI = cast<CallInst>(AssumeVH);
10836 if (!DT.dominates(CI, BB))
10837 continue;
10839 if (ProveViaCond(CI->getArgOperand(0), false))
10840 return true;
10841 }
10843 return false;
10844 }
10846 bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
10847 ICmpInst::Predicate Pred,
10848 const SCEV *LHS,
10849 const SCEV *RHS) {
10850 // Interpret a null as meaning no loop, where there is obviously no guard
10851 // (interprocedural conditions notwithstanding).
10852 if (!L)
10853 return false;
10855 // Both LHS and RHS must be available at loop entry.
10856 assert(isAvailableAtLoopEntry(LHS, L) &&
10857 "LHS is not available at Loop Entry");
10858 assert(isAvailableAtLoopEntry(RHS, L) &&
10859 "RHS is not available at Loop Entry");
10861 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
10862 return true;
10864 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
10867 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
10868 const SCEV *RHS,
10869 const Value *FoundCondValue, bool Inverse,
10870 const Instruction *CtxI) {
10871 // A false condition implies anything, so do not bother analyzing it further.
10872 if (FoundCondValue ==
10873 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
10874 return true;
10876 if (!PendingLoopPredicates.insert(FoundCondValue).second)
10877 return false;
10879 auto ClearOnExit =
10880 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
10882 // Recursively handle And and Or conditions.
10883 const Value *Op0, *Op1;
10884 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
10885 if (!Inverse)
10886 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
10887 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
10888 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
10889 if (Inverse)
10890 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
10891 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
10894 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
10895 if (!ICI) return false;
10897 // We have found a conditional branch that dominates the loop or controls
10898 // the loop latch. Check to see if it is the comparison we are looking for.
10899 ICmpInst::Predicate FoundPred;
10900 if (Inverse)
10901 FoundPred = ICI->getInversePredicate();
10902 else
10903 FoundPred = ICI->getPredicate();
10905 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
10906 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
10908 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
10911 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
10912 const SCEV *RHS,
10913 ICmpInst::Predicate FoundPred,
10914 const SCEV *FoundLHS, const SCEV *FoundRHS,
10915 const Instruction *CtxI) {
10916 // Balance the types.
10917 if (getTypeSizeInBits(LHS->getType()) <
10918 getTypeSizeInBits(FoundLHS->getType())) {
10919 // For unsigned and equality predicates, try to prove that both found
10920 // operands fit into narrow unsigned range. If so, try to prove facts in
10922 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
10923 !FoundRHS->getType()->isPointerTy()) {
10924 auto *NarrowType = LHS->getType();
10925 auto *WideType = FoundLHS->getType();
10926 auto BitWidth = getTypeSizeInBits(NarrowType);
10927 const SCEV *MaxValue = getZeroExtendExpr(
10928 getConstant(APInt::getMaxValue(BitWidth)), WideType);
10929 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
10930 MaxValue) &&
10931 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
10932 MaxValue)) {
10933 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
10934 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
10935 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
10936 TruncFoundRHS, CtxI))
10937 return true;
10938 }
10939 }
10941 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
10942 return false;
10943 if (CmpInst::isSigned(Pred)) {
10944 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
10945 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
10946 } else {
10947 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
10948 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
10949 }
10950 } else if (getTypeSizeInBits(LHS->getType()) >
10951 getTypeSizeInBits(FoundLHS->getType())) {
10952 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
10953 return false;
10954 if (CmpInst::isSigned(FoundPred)) {
10955 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
10956 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
10957 } else {
10958 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
10959 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
10960 }
10961 }
10962 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
10963 FoundRHS, CtxI);
10964 }
10966 bool ScalarEvolution::isImpliedCondBalancedTypes(
10967 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
10968 ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS,
10969 const Instruction *CtxI) {
10970 assert(getTypeSizeInBits(LHS->getType()) ==
10971 getTypeSizeInBits(FoundLHS->getType()) &&
10972 "Types should be balanced!");
10973 // Canonicalize the query to match the way instcombine will have
10974 // canonicalized the comparison.
10975 if (SimplifyICmpOperands(Pred, LHS, RHS))
10976 if (LHS == RHS)
10977 return CmpInst::isTrueWhenEqual(Pred);
10978 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
10979 if (FoundLHS == FoundRHS)
10980 return CmpInst::isFalseWhenEqual(FoundPred);
10982 // Check to see if we can make the LHS or RHS match.
10983 if (LHS == FoundRHS || RHS == FoundLHS) {
10984 if (isa<SCEVConstant>(RHS)) {
10985 std::swap(FoundLHS, FoundRHS);
10986 FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
10987 } else {
10988 std::swap(LHS, RHS);
10989 Pred = ICmpInst::getSwappedPredicate(Pred);
10990 }
10991 }
10993 // Check whether the found predicate is the same as the desired predicate.
10994 if (FoundPred == Pred)
10995 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
10997 // Check whether swapping the found predicate makes it the same as the
10998 // desired predicate.
10999 if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
11000 // We can write the implication
11001 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
11002 // using one of the following ways:
11003 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
11004 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
11005 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
11006 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
11007 // Forms 1. and 2. require swapping the operands of one condition. Don't
11008 // do this if it would break canonical constant/addrec ordering.
11009 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
11010 return isImpliedCondOperands(FoundPred, RHS, LHS, FoundLHS, FoundRHS,
11011 CtxI);
11012 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
11013 return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, CtxI);
11015 // There's no clear preference between forms 3. and 4., try both. Avoid
11016 // forming getNotSCEV of pointer values as the resulting subtract is
11017 // not meaningful.
11018 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
11019 isImpliedCondOperands(FoundPred, getNotSCEV(LHS), getNotSCEV(RHS),
11020 FoundLHS, FoundRHS, CtxI))
11021 return true;
11023 if (!FoundLHS->getType()->isPointerTy() &&
11024 !FoundRHS->getType()->isPointerTy() &&
11025 isImpliedCondOperands(Pred, LHS, RHS, getNotSCEV(FoundLHS),
11026 getNotSCEV(FoundRHS), CtxI))
11027 return true;
11029 return false;
11030 }
11032 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
11033 CmpInst::Predicate P2) {
11034 assert(P1 != P2 && "Handled earlier!");
11035 return CmpInst::isRelational(P2) &&
11036 P1 == CmpInst::getFlippedSignednessPredicate(P2);
11037 };
11038 if (IsSignFlippedPredicate(Pred, FoundPred)) {
11039 // Unsigned comparison is the same as signed comparison when both the
11040 // operands are non-negative or negative.
11041 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
11042 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
11043 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11044 // Create local copies that we can freely swap and canonicalize our
11045 // conditions to "le/lt".
11046 ICmpInst::Predicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
11047 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
11048 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
11049 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
11050 CanonicalPred = ICmpInst::getSwappedPredicate(CanonicalPred);
11051 CanonicalFoundPred = ICmpInst::getSwappedPredicate(CanonicalFoundPred);
11052 std::swap(CanonicalLHS, CanonicalRHS);
11053 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
11054 }
11055 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
11057 assert((ICmpInst::isLT(CanonicalFoundPred) ||
11058 ICmpInst::isLE(CanonicalFoundPred)) &&
11060 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
11061 // Use implication:
11062 // x <u y && y >=s 0 --> x <s y.
11063 // If we can prove the left part, the right part is also proven.
11064 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11065 CanonicalRHS, CanonicalFoundLHS,
11066 CanonicalFoundRHS);
11067 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
11068 // Use implication:
11069 // x <s y && y <s 0 --> x <u y.
11070 // If we can prove the left part, the right part is also proven.
11071 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11072 CanonicalRHS, CanonicalFoundLHS,
11073 CanonicalFoundRHS);
11076 // Check if we can make progress by sharpening ranges.
11077 if (FoundPred == ICmpInst::ICMP_NE &&
11078 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
11080 const SCEVConstant *C = nullptr;
11081 const SCEV *V = nullptr;
11083 if (isa<SCEVConstant>(FoundLHS)) {
11084 C = cast<SCEVConstant>(FoundLHS);
11085 V = FoundRHS;
11086 } else {
11087 C = cast<SCEVConstant>(FoundRHS);
11088 V = FoundLHS;
11089 }
11091 // The guarding predicate tells us that C != V. If the known range
11092 // of V is [C, t), we can sharpen the range to [C + 1, t). The
11093 // range we consider has to correspond to same signedness as the
11094 // predicate we're interested in folding.
11096 APInt Min = ICmpInst::isSigned(Pred) ?
11097 getSignedRangeMin(V) : getUnsignedRangeMin(V);
11099 if (Min == C->getAPInt()) {
11100 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
11101 // This is true even if (Min + 1) wraps around -- in case of
11102 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
11104 APInt SharperMin = Min + 1;
11106 switch (Pred) {
11107 case ICmpInst::ICMP_SGE:
11108 case ICmpInst::ICMP_UGE:
11109 // We know V `Pred` SharperMin. If this implies LHS `Pred`
11110 // RHS, we're done.
11111 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
11112 CtxI))
11113 return true;
11115 LLVM_FALLTHROUGH;
11116 case ICmpInst::ICMP_SGT:
11117 case ICmpInst::ICMP_UGT:
11118 // We know from the range information that (V `Pred` Min ||
11119 // V == Min). We know from the guarding condition that !(V
11120 // == Min). This gives us
11122 // V `Pred` Min || V == Min && !(V == Min)
11123 // => V `Pred` Min
11125 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
11127 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
11128 return true;
11130 break;
11131 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
11132 case ICmpInst::ICMP_SLE:
11133 case ICmpInst::ICMP_ULE:
11134 if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
11135 LHS, V, getConstant(SharperMin), CtxI))
11136 return true;
11138 LLVM_FALLTHROUGH;
11139 case ICmpInst::ICMP_SLT:
11140 case ICmpInst::ICMP_ULT:
11141 if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
11142 LHS, V, getConstant(Min), CtxI))
11143 return true;
11145 break;
11153 // Check whether the actual condition is beyond sufficient.
11154 if (FoundPred == ICmpInst::ICMP_EQ)
11155 if (ICmpInst::isTrueWhenEqual(Pred))
11156 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
11157 return true;
11158 if (Pred == ICmpInst::ICMP_NE)
11159 if (!ICmpInst::isTrueWhenEqual(FoundPred))
11160 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
11161 return true;
11163 // Otherwise assume the worst.
11165 return false;
11166 }
11167 bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
11168 const SCEV *&L, const SCEV *&R,
11169 SCEV::NoWrapFlags &Flags) {
11170 const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
11171 if (!AE || AE->getNumOperands() != 2)
11172 return false;
11174 L = AE->getOperand(0);
11175 R = AE->getOperand(1);
11176 Flags = AE->getNoWrapFlags();
11177 return true;
11178 }
11180 Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More,
11181 const SCEV *Less) {
11182 // We avoid subtracting expressions here because this function is usually
11183 // fairly deep in the call stack (i.e. is called many times).
11186 if (More == Less)
11187 return APInt(getTypeSizeInBits(More->getType()), 0);
11189 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
11190 const auto *LAR = cast<SCEVAddRecExpr>(Less);
11191 const auto *MAR = cast<SCEVAddRecExpr>(More);
11193 if (LAR->getLoop() != MAR->getLoop())
11194 return None;
11196 // We look at affine expressions only; not for correctness but to keep
11197 // getStepRecurrence cheap.
11198 if (!LAR->isAffine() || !MAR->isAffine())
11199 return None;
11201 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
11202 return None;
11204 Less = LAR->getStart();
11205 More = MAR->getStart();
11210 if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
11211 const auto &M = cast<SCEVConstant>(More)->getAPInt();
11212 const auto &L = cast<SCEVConstant>(Less)->getAPInt();
11213 return M - L;
11214 }
11216 SCEV::NoWrapFlags Flags;
11217 const SCEV *LLess = nullptr, *RLess = nullptr;
11218 const SCEV *LMore = nullptr, *RMore = nullptr;
11219 const SCEVConstant *C1 = nullptr, *C2 = nullptr;
  // Compare (X + C1) vs X.
  if (splitBinaryAdd(Less, LLess, RLess, Flags))
    if ((C1 = dyn_cast<SCEVConstant>(LLess)))
      if (RLess == More)
        return -(C1->getAPInt());

  // Compare X vs (X + C2).
  if (splitBinaryAdd(More, LMore, RMore, Flags))
    if ((C2 = dyn_cast<SCEVConstant>(LMore)))
      if (RMore == Less)
        return C2->getAPInt();

  // Compare (X + C1) vs (X + C2).
  if (C1 && C2 && RLess == RMore)
    return C2->getAPInt() - C1->getAPInt();

  return None;
}
bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
    const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
  // Try to recognize the following pattern:
  //
  // FoundRHS = ...
  // ...
  // loop:
  //   FoundLHS = {Start,+,W}
  // context_bb: // Basic block from the same loop
  //   known(Pred, FoundLHS, FoundRHS)
  //
  // If some predicate is known in the context of a loop, it is also known on
  // each iteration of this loop, including the first iteration. Therefore, in
  // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
  // prove the original predicate using this fact.
  if (!CtxI)
    return false;
  const BasicBlock *ContextBB = CtxI->getParent();
  // Make sure AR varies in the context block.
  if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
    const Loop *L = AR->getLoop();
    // Make sure that context belongs to the loop and executes on 1st iteration
    // (if it ever executes at all).
    if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
      return false;
    if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
      return false;
    return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
  }

  if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
    const Loop *L = AR->getLoop();
    // Make sure that context belongs to the loop and executes on 1st iteration
    // (if it ever executes at all).
    if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
      return false;
    if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
      return false;
    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
  }

  return false;
}
bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
    const SCEV *FoundLHS, const SCEV *FoundRHS) {
  if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
    return false;

  const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!AddRecLHS)
    return false;

  const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
  if (!AddRecFoundLHS)
    return false;

  // We'd like to let SCEV reason about control dependencies, so we constrain
  // both the inequalities to be about add recurrences on the same loop.  This
  // way we can use isLoopEntryGuardedByCond later.

  const Loop *L = AddRecFoundLHS->getLoop();
  if (L != AddRecLHS->getLoop())
    return false;

  //  FoundLHS u< FoundRHS u< -C  =>  (FoundLHS + C) u< (FoundRHS + C) ... (1)
  //
  //  FoundLHS s< FoundRHS s< INT_MIN - C  =>
  //                          (FoundLHS + C) s< (FoundRHS + C)         ... (2)
  //
  // Informal proof for (2), assuming (1) [*]:
  //
  // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
  //
  // Then
  //
  //      FoundLHS s< FoundRHS s< INT_MIN - C
  // <=>  (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C    [ using (3) ]
  // <=>  (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C)  [ using (1) ]
  // <=>  (FoundLHS + INT_MIN + C + INT_MIN) s<
  //                      (FoundRHS + INT_MIN + C + INT_MIN)    [ using (3) ]
  // <=>  FoundLHS + C s< FoundRHS + C
  //
  // [*]: (1) can be proved by ruling out overflow.
  //
  // [**]: This can be proved by analyzing all the four possibilities:
  //    (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
  //    (A s>= 0, B s>= 0).
  //
  // Note:
  // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
  // will not sign underflow.  For instance, say FoundLHS = (i8 -128), FoundRHS
  // = (i8 -127) and C = (i8 -100).  Then INT_MIN - C = (i8 -28), and FoundRHS
  // s< (INT_MIN - C).  Lack of sign overflow / underflow in "FoundRHS + C" is
  // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
  // C)".
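  //
  // A concrete check of that example in i8 arithmetic (illustrative):
  // FoundLHS + C == -228, which wraps to 28, and FoundRHS + C == -227,
  // which wraps to 29; 28 s< 29 still holds even though both sums wrapped.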

  Optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
  Optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
  if (!LDiff || !RDiff || *LDiff != *RDiff)
    return false;

  if (LDiff->isMinValue())
    return true;

  APInt FoundRHSLimit;

  if (Pred == CmpInst::ICMP_ULT) {
    FoundRHSLimit = -(*RDiff);
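    // E.g. (illustrative, i8): with *RDiff == 5, FoundRHSLimit == -5 == 251,
    // so proving "FoundRHS u< 251" guarantees that FoundRHS + 5 does not
    // wrap, which is exactly precondition (1).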
  } else {
    assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
    FoundRHSLimit =
        APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
  }

  // Try to prove (1) or (2), as needed.
  return isAvailableAtLoopEntry(FoundRHS, L) &&
         isLoopEntryGuardedByCond(L, Pred, FoundRHS,
                                  getConstant(FoundRHSLimit));
}
bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
                                        const SCEV *LHS, const SCEV *RHS,
                                        const SCEV *FoundLHS,
                                        const SCEV *FoundRHS, unsigned Depth) {
  const PHINode *LPhi = nullptr, *RPhi = nullptr;

  auto ClearOnExit = make_scope_exit([&]() {
    if (LPhi) {
      bool Erased = PendingMerges.erase(LPhi);
      assert(Erased && "Failed to erase LPhi!");
      (void)Erased;
    }
    if (RPhi) {
      bool Erased = PendingMerges.erase(RPhi);
      assert(Erased && "Failed to erase RPhi!");
      (void)Erased;
    }
  });

  // Find the respective Phis and check that they are not already pending.
  if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
    if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
      if (!PendingMerges.insert(Phi).second)
        return false;
      LPhi = Phi;
    }

  if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
    if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
      // If we detect a loop of Phi nodes being processed by this method, for
      // example:
      //
      //   %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
      //   %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
      //
      // we don't want to deal with a case that complex, so return the
      // conservative answer false.
      if (!PendingMerges.insert(Phi).second)
        return false;
      RPhi = Phi;
    }

  // If none of LHS, RHS is a Phi, nothing to do here.
  if (!LPhi && !RPhi)
    return false;

  // If there is a SCEVUnknown Phi we are interested in, make it left.
  if (!LPhi) {
    std::swap(LHS, RHS);
    std::swap(FoundLHS, FoundRHS);
    std::swap(LPhi, RPhi);
    Pred = ICmpInst::getSwappedPredicate(Pred);
  }

  assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
  const BasicBlock *LBB = LPhi->getParent();
  const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);

  auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
    return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
           isImpliedCondOperandsViaRanges(Pred, S1, S2, FoundLHS, FoundRHS) ||
           isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
  };

  if (RPhi && RPhi->getParent() == LBB) {
    // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
    // If we compare two Phis from the same block, and for each entry block
    // the predicate is true for incoming values from this block, then the
    // predicate is also true for the Phis.
    for (const BasicBlock *IncBB : predecessors(LBB)) {
      const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
      const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
      if (!ProvedEasily(L, R))
        return false;
    }
  } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
    // Case two: RHS is also a Phi from the same basic block, and it is an
    // AddRec. It means that there is a loop which has both AddRec and Unknown
    // PHIs; for it we can compare incoming values of AddRec from above the
    // loop and latch with their respective incoming values of LPhi.
    // TODO: Generalize to handle loops with many inputs in a header.
    if (LPhi->getNumIncomingValues() != 2) return false;

    auto *RLoop = RAR->getLoop();
    auto *Predecessor = RLoop->getLoopPredecessor();
    assert(Predecessor && "Loop with AddRec with no predecessor?");
    const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
    if (!ProvedEasily(L1, RAR->getStart()))
      return false;
    auto *Latch = RLoop->getLoopLatch();
    assert(Latch && "Loop with AddRec with no latch?");
    const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
    if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
      return false;
  } else {
    // In all other cases go over the inputs of LHS and compare each of them
    // to RHS: the predicate is true for (LHS, RHS) if it is true for all such
    // pairs. At this point RHS is either a non-Phi, or it is a Phi from some
    // block different from LBB.
    for (const BasicBlock *IncBB : predecessors(LBB)) {
      // Check that RHS is available in this block.
      if (!dominates(RHS, IncBB))
        return false;
      const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
      // Make sure L does not refer to a value from a potentially previous
      // iteration of a loop.
      if (!properlyDominates(L, IncBB))
        return false;
      if (!ProvedEasily(L, RHS))
        return false;
    }
  }
  return true;
}
bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred,
                                                    const SCEV *LHS,
                                                    const SCEV *RHS,
                                                    const SCEV *FoundLHS,
                                                    const SCEV *FoundRHS) {
  // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue).  First, make
  // sure that we are dealing with the same LHS.
  if (RHS == FoundRHS) {
    std::swap(LHS, RHS);
    std::swap(FoundLHS, FoundRHS);
    Pred = ICmpInst::getSwappedPredicate(Pred);
  }
  if (LHS != FoundLHS)
    return false;

  auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
  if (!SUFoundRHS)
    return false;

  Value *Shiftee, *ShiftValue;

  using namespace PatternMatch;
  if (match(SUFoundRHS->getValue(),
            m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
    auto *ShifteeS = getSCEV(Shiftee);
    // Prove one of the following:
    // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
    // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
    // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
    //   ---> LHS <s RHS
    // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
    //   ---> LHS <=s RHS
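    //
    // Worked instance of the first rule (illustrative, i8 unsigned):
    // shiftee == 40 and shiftvalue == 2 give (shiftee >> shiftvalue) == 10;
    // from "LHS <u 10" and "40 <=u RHS" we get LHS <u 10 <=u 40 <=u RHS.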
    if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
      return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
    if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
      if (isKnownNonNegative(ShifteeS))
        return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
  }
  return false;
}
bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
                                            const SCEV *LHS, const SCEV *RHS,
                                            const SCEV *FoundLHS,
                                            const SCEV *FoundRHS,
                                            const Instruction *CtxI) {
  if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
    return true;

  if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
    return true;

  if (isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS))
    return true;

  if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
                                          CtxI))
    return true;

  return isImpliedCondOperandsHelper(Pred, LHS, RHS,
                                     FoundLHS, FoundRHS);
}
/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
template <typename MinMaxExprType>
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
                                 const SCEV *Candidate) {
  const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
  if (!MinMaxExpr)
    return false;

  return is_contained(MinMaxExpr->operands(), Candidate);
}
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
                                           ICmpInst::Predicate Pred,
                                           const SCEV *LHS, const SCEV *RHS) {
  // If both sides are affine addrecs for the same loop, with equal
  // steps, and we know the recurrences don't wrap, then we only
  // need to check the predicate on the starting values.

  if (!ICmpInst::isRelational(Pred))
    return false;

  const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!LAR)
    return false;
  const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
  if (!RAR)
    return false;
  if (LAR->getLoop() != RAR->getLoop())
    return false;
  if (!LAR->isAffine() || !RAR->isAffine())
    return false;

  if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
    return false;

  SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
      SCEV::FlagNSW : SCEV::FlagNUW;
  if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
    return false;

  return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
}
/// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
/// expression?
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
                                        ICmpInst::Predicate Pred,
                                        const SCEV *LHS, const SCEV *RHS) {
  switch (Pred) {
  default:
    return false;

  case ICmpInst::ICMP_SGE:
    std::swap(LHS, RHS);
    LLVM_FALLTHROUGH;
  case ICmpInst::ICMP_SLE:
    return
        // min(A, ...) <= A
        IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
        // A <= max(A, ...)
        IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);

  case ICmpInst::ICMP_UGE:
    std::swap(LHS, RHS);
    LLVM_FALLTHROUGH;
  case ICmpInst::ICMP_ULE:
    return
        // min(A, ...) <= A
        // FIXME: what about umin_seq?
        IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
        // A <= max(A, ...)
        IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
  }

  llvm_unreachable("covered switch fell through?!");
}
bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
                                             const SCEV *LHS, const SCEV *RHS,
                                             const SCEV *FoundLHS,
                                             const SCEV *FoundRHS,
                                             unsigned Depth) {
  assert(getTypeSizeInBits(LHS->getType()) ==
             getTypeSizeInBits(RHS->getType()) &&
         "LHS and RHS have different sizes?");
  assert(getTypeSizeInBits(FoundLHS->getType()) ==
             getTypeSizeInBits(FoundRHS->getType()) &&
         "FoundLHS and FoundRHS have different sizes?");
  // We want to avoid hurting the compile time with analysis of too big trees.
  if (Depth > MaxSCEVOperationsImplicationDepth)
    return false;

  // We only want to work with GT comparison so far.
  if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
    Pred = CmpInst::getSwappedPredicate(Pred);
    std::swap(LHS, RHS);
    std::swap(FoundLHS, FoundRHS);
  }

  // For unsigned, try to reduce it to the corresponding signed comparison.
  if (Pred == ICmpInst::ICMP_UGT)
    // We can replace unsigned predicate with its signed counterpart if all
    // involved values are non-negative.
    // TODO: We could have better support for unsigned.
    if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
      // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
      // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
      // use this fact to prove that LHS and RHS are non-negative.
      const SCEV *MinusOne = getMinusOne(LHS->getType());
      if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
                                FoundRHS) &&
          isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
                                FoundRHS))
        Pred = ICmpInst::ICMP_SGT;
    }

  if (Pred != ICmpInst::ICMP_SGT)
    return false;

  auto GetOpFromSExt = [&](const SCEV *S) {
    if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
      return Ext->getOperand();
    // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
    // the constant in some cases.
    return S;
  };

  // Acquire values from extensions.
  auto *OrigLHS = LHS;
  auto *OrigFoundLHS = FoundLHS;
  LHS = GetOpFromSExt(LHS);
  FoundLHS = GetOpFromSExt(FoundLHS);

  // Whether the SGT predicate can be proved trivially or using the found
  // context.
  auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
    return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
           isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
                                  FoundRHS, Depth + 1);
  };

  if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
    // We want to avoid creation of any new non-constant SCEV. Since we are
    // going to compare the operands to RHS, we should be certain that we don't
    // need any size extensions for this. So let's decline all cases when the
    // sizes of types of LHS and RHS do not match.
    // TODO: Maybe try to get RHS from sext to catch more cases?
    if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
      return false;

    // Should not overflow.
    if (!LHSAddExpr->hasNoSignedWrap())
      return false;

    auto *LL = LHSAddExpr->getOperand(0);
    auto *LR = LHSAddExpr->getOperand(1);
    auto *MinusOne = getMinusOne(RHS->getType());

    // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
    auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
      return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
    };
    // Try to prove the following rule:
    // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
    // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
    if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
      return true;
  } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
    Value *LL, *LR;
    // FIXME: Once we have SDiv implemented, we can get rid of this matching.

    using namespace llvm::PatternMatch;

    if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
      // Rules for division.
      // We are going to perform some comparisons with Denominator and its
      // derivative expressions. In the general case, creating a SCEV for it
      // may lead to a complex analysis of the entire graph, and in particular
      // it can request trip count recalculation for the same loop, which
      // would then be cached as SCEVCouldNotCompute to avoid infinite
      // recursion. To avoid this, we only want to create SCEVs that are
      // constants in this section. So we bail if Denominator is not a
      // constant.
      if (!isa<ConstantInt>(LR))
        return false;

      auto *Denominator = cast<SCEVConstant>(getSCEV(LR));

      // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
      // then a SCEV for the numerator already exists and matches with FoundLHS.
      auto *Numerator = getExistingSCEV(LL);
      if (!Numerator || Numerator->getType() != FoundLHS->getType())
        return false;

      // Make sure that the numerator matches with FoundLHS and the denominator
      // is positive.
      if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
        return false;

      auto *DTy = Denominator->getType();
      auto *FRHSTy = FoundRHS->getType();
      if (DTy->isPointerTy() != FRHSTy->isPointerTy())
        // One of the types is a pointer and the other one is not. We cannot
        // extend them properly to a wider type, so let us just reject this
        // case.
        // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
        // to avoid this check.
        return false;

      // Given that:
      // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
      auto *WTy = getWiderType(DTy, FRHSTy);
      auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
      auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);

      // Try to prove the following rule:
      // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
      // For example, given that FoundLHS > 2, FoundLHS is at least 3. If we
      // divide it by Denominator < 4, we will have at least 1.
      auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
      if (isKnownNonPositive(RHS) &&
          IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
        return true;

      // Try to prove the following rule:
      // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
      // For example, given that FoundLHS > -3, FoundLHS is at least -2.
      // If we divide it by Denominator > 2, then:
      // 1. If FoundLHS is negative, then the result is 0.
      // 2. If FoundLHS is non-negative, then the result is non-negative.
      // Either way, the result is non-negative.
      auto *MinusOne = getMinusOne(WTy);
      auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
      if (isKnownNegative(RHS) &&
          IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
        return true;
    }
  }

  // If our expression contained SCEVUnknown Phis, and we split it down and now
  // need to prove something for them, try to prove the predicate for every
  // possible incoming value of those Phis.
  if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
    return true;

  return false;
}
static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
                                        const SCEV *LHS, const SCEV *RHS) {
  // zext x u<= sext x, sext x s<= zext x
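  //
  // E.g. (illustrative) for x == i8 -1 widened to i16: zext x == 255 and
  // sext x == 0xFFFF, so "zext x u<= sext x" (255 u<= 65535) and
  // "sext x s<= zext x" (-1 s<= 255) both hold.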
  switch (Pred) {
  case ICmpInst::ICMP_SGE:
    std::swap(LHS, RHS);
    LLVM_FALLTHROUGH;
  case ICmpInst::ICMP_SLE: {
    // If operand >=s 0 then ZExt == SExt.  If operand <s 0 then SExt <s ZExt.
    const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
    const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
    if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
      return true;
    return false;
  }
  case ICmpInst::ICMP_UGE:
    std::swap(LHS, RHS);
    LLVM_FALLTHROUGH;
  case ICmpInst::ICMP_ULE: {
    // If operand >=s 0 then ZExt == SExt.  If operand <s 0 then ZExt <u SExt.
    const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
    const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
    if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
      return true;
    return false;
  }
  default:
    return false;
  };
}
bool
ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
                                                 const SCEV *LHS,
                                                 const SCEV *RHS) {
  return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
         isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
         IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
         IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
         isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
}
bool
ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
                                             const SCEV *LHS, const SCEV *RHS,
                                             const SCEV *FoundLHS,
                                             const SCEV *FoundRHS) {
  switch (Pred) {
  default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
  case ICmpInst::ICMP_EQ:
  case ICmpInst::ICMP_NE:
    if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
      return true;
    break;
  case ICmpInst::ICMP_SLT:
  case ICmpInst::ICMP_SLE:
    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
      return true;
    break;
  case ICmpInst::ICMP_SGT:
  case ICmpInst::ICMP_SGE:
    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
      return true;
    break;
  case ICmpInst::ICMP_ULT:
  case ICmpInst::ICMP_ULE:
    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
      return true;
    break;
  case ICmpInst::ICMP_UGT:
  case ICmpInst::ICMP_UGE:
    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
      return true;
    break;
  }

  // Maybe it can be proved via operations?
  if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
    return true;

  return false;
}
bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
                                                     const SCEV *LHS,
                                                     const SCEV *RHS,
                                                     const SCEV *FoundLHS,
                                                     const SCEV *FoundRHS) {
  if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
    // The restriction on `FoundRHS` can be lifted easily -- it exists only to
    // reduce the compile time impact of this optimization.
    return false;

  Optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
  if (!Addend)
    return false;

  const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();

  // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
  // antecedent "`FoundLHS` `Pred` `FoundRHS`".
  ConstantRange FoundLHSRange =
      ConstantRange::makeExactICmpRegion(Pred, ConstFoundRHS);

  // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
  ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));

  // We can also compute the range of values for `LHS` that satisfy the
  // consequent, "`LHS` `Pred` `RHS`":
  const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
  // The antecedent implies the consequent if every value of `LHS` that
  // satisfies the antecedent also satisfies the consequent.
  return LHSRange.icmp(Pred, ConstRHS);
}
bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
                                        bool IsSigned) {
  assert(isKnownPositive(Stride) && "Positive stride expected!");

  unsigned BitWidth = getTypeSizeInBits(RHS->getType());
  const SCEV *One = getOne(Stride->getType());

  if (IsSigned) {
    APInt MaxRHS = getSignedRangeMax(RHS);
    APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
    APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));

    // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
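    // E.g. (illustrative, i8): MaxRHS == 120 and MaxStrideMinusOne == 15
    // give 127 - 15 == 112 s< 120, so the IV could step past INT8_MAX.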
    return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
  }

  APInt MaxRHS = getUnsignedRangeMax(RHS);
  APInt MaxValue = APInt::getMaxValue(BitWidth);
  APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));

  // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
  return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
}
bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
                                        bool IsSigned) {

  unsigned BitWidth = getTypeSizeInBits(RHS->getType());
  const SCEV *One = getOne(Stride->getType());

  if (IsSigned) {
    APInt MinRHS = getSignedRangeMin(RHS);
    APInt MinValue = APInt::getSignedMinValue(BitWidth);
    APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));

    // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
    return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
  }

  APInt MinRHS = getUnsignedRangeMin(RHS);
  APInt MinValue = APInt::getMinValue(BitWidth);
  APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));

  // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
  return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
}
const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
  // umin(N, 1) + floor((N - umin(N, 1)) / D)
  // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
  // expression fixes the case of N == 0.
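  //
  // E.g. (illustrative) for N == 7, D == 3: umin(7, 1) == 1 and
  // floor((7 - 1) / 3) == 2, giving 3 == ceil(7 / 3); for N == 0 both
  // terms are zero, so the result is 0 rather than a wrapped "1 + ...".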
  const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
  const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
  return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
}
const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
                                                    const SCEV *Stride,
                                                    const SCEV *End,
                                                    unsigned BitWidth,
                                                    bool IsSigned) {
  // The logic in this function assumes we can represent a positive stride.
  // If we can't, the backedge-taken count must be zero.
  if (IsSigned && BitWidth == 1)
    return getZero(Stride->getType());

  // This code has only been closely audited for negative strides in the
  // unsigned comparison case; it may be correct for the signed comparison
  // case, but that needs to be established.
  assert((!IsSigned || !isKnownNonPositive(Stride)) &&
         "Stride is expected strictly positive for signed case!");

  // Calculate the maximum backedge count based on the range of values
  // permitted by Start, End, and Stride.
  APInt MinStart =
      IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);

  APInt MinStride =
      IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);

  // We assume either the stride is positive, or the backedge-taken count
  // is zero. So force StrideForMaxBECount to be at least one.
  APInt One(BitWidth, 1);
  APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
                                       : APIntOps::umax(One, MinStride);

  APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
                            : APInt::getMaxValue(BitWidth);
  APInt Limit = MaxValue - (StrideForMaxBECount - 1);
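  // E.g. (illustrative, unsigned i8): with StrideForMaxBECount == 4,
  // Limit == 255 - 3 == 252; clamping MaxEnd to 252 keeps the later
  // "MaxEnd - MinStart + (Stride - 1)" arithmetic from wrapping.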

  // Although End can be a MAX expression we estimate MaxEnd considering only
  // the case End = RHS of the loop termination condition. This is safe because
  // in the other case (End - Start) is zero, leading to a zero maximum
  // backedge taken count.
  APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
                          : APIntOps::umin(getUnsignedRangeMax(End), Limit);

  // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
  MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
                    : APIntOps::umax(MaxEnd, MinStart);

  return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
                         getConstant(StrideForMaxBECount) /* Step */);
}
ScalarEvolution::ExitLimit
ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                  const Loop *L, bool IsSigned,
                                  bool ControlsExit, bool AllowPredicates) {
  SmallPtrSet<const SCEVPredicate *, 4> Predicates;

  const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
  bool PredicatedIV = false;

  auto canAssumeNoSelfWrap = [&](const SCEVAddRecExpr *AR) {
    // Can we prove this loop *must* be UB if overflow of IV occurs?
    // Reasoning goes as follows:
    // * Suppose the IV did self wrap.
    // * If Stride evenly divides the iteration space, then once wrap
    //   occurs, the loop must revisit the same values.
    // * We know that RHS is invariant, and that none of those values
    //   caused this exit to be taken previously.  Thus, this exit is
    //   dynamically dead.
    // * If this is the sole exit, then a dead exit implies the loop
    //   must be infinite if there are no abnormal exits.
    // * If the loop were infinite, then it must either not be mustprogress
    //   or have side effects. Otherwise, it must be UB.
    // * It can't (by assumption) be UB, so we have contradicted our
    //   premise and can conclude the IV did not in fact self-wrap.
    if (!isLoopInvariant(RHS, L))
      return false;

    auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
    if (!StrideC || !StrideC->getAPInt().isPowerOf2())
      return false;

    if (!ControlsExit || !loopHasNoAbnormalExits(L))
      return false;

    return loopIsFiniteByAssumption(L);
  };

  if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
    if (AR && AR->getLoop() == L && AR->isAffine()) {
      auto canProveNUW = [&]() {
        if (!isLoopInvariant(RHS, L))
          return false;

        if (!isKnownNonZero(AR->getStepRecurrence(*this)))
          // We need the sequence defined by AR to strictly increase in the
          // unsigned integer domain for the logic below to hold.
          return false;

        const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
        const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
        // If RHS <=u Limit, then there must exist a value V in the sequence
        // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
        // V <=u UINT_MAX.  Thus, we must exit the loop before unsigned
        // overflow occurs.  This limit also implies that a signed comparison
        // (in the wide bitwidth) is equivalent to an unsigned comparison as
        // the high bits on both sides must be zero.
        APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
        APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
        Limit = Limit.zext(OuterBitWidth);
        return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
      };
      auto Flags = AR->getNoWrapFlags();
      if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
        Flags = setFlags(Flags, SCEV::FlagNUW);

      setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
      if (AR->hasNoUnsignedWrap()) {
        // Emulate what getZeroExtendExpr would have done during construction
        // if we'd been able to infer the fact just above at that time.
        const SCEV *Step = AR->getStepRecurrence(*this);
        Type *Ty = ZExt->getType();
        auto *S = getAddRecExpr(
            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
            getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
        IV = dyn_cast<SCEVAddRecExpr>(S);
      }
    }
  }

  if (!IV && AllowPredicates) {
    // Try to make this an AddRec using runtime tests, in the first X
    // iterations of this loop, where X is the SCEV expression found by the
    // algorithm below.
    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
    PredicatedIV = true;
  }

  // Avoid weird loops.
  if (!IV || IV->getLoop() != L || !IV->isAffine())
    return getCouldNotCompute();

  // A precondition of this method is that the condition being analyzed
  // reaches an exiting branch which dominates the latch.  Given that, we can
  // assume that an increment which violates the nowrap specification and
  // produces poison must cause undefined behavior when the resulting poison
  // value is branched upon and thus we can conclude that the backedge is
  // taken no more often than would be required to produce that poison value.
  // Note that a well defined loop can exit on the iteration which violates
  // the nowrap specification if there is another exit (either explicit or
  // implicit/exceptional) which causes the loop to execute before the
  // exiting instruction we're analyzing would trigger UB.
  auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
  bool NoWrap = ControlsExit && IV->getNoWrapFlags(WrapType);
  ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;

  const SCEV *Stride = IV->getStepRecurrence(*this);

  bool PositiveStride = isKnownPositive(Stride);

  // Avoid negative or zero stride values.
  if (!PositiveStride) {
    // We can compute the correct backedge taken count for loops with unknown
    // strides if we can prove that the loop is not an infinite loop with side
    // effects.  Here's the loop structure we are trying to handle -
    //
    // i = start
    // do {
    //   A[i] = i;
    //   i += s;
    // } while (i < end);
    //
    // The backedge taken count for such loops is evaluated as -
    // (max(end, start + stride) - start - 1) /u stride
    //
    // The additional preconditions that we need to check to prove correctness
    // of the above formula are as follows -
    //
    // a) IV is either nuw or nsw depending upon signedness (indicated by the
    //    flag IsSigned).
    // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
    //    no side effects within the loop).
    // c) loop has a single static exit (with no abnormal exits).
    //
    // Precondition a) implies that if the stride is negative, this is a single
    // trip loop.  The backedge taken count formula reduces to zero in this
    // case.
    //
    // Preconditions b) and c) combine to imply that if rhs is invariant in L,
    // then a zero stride means the backedge can't be taken without executing
    // undefined behavior.
    //
    // The positive stride case is the same as isKnownPositive(Stride)
    // returning true (original behavior of the function).
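    //
    // Worked instance of the formula above (illustrative): start == 0,
    // stride == 4, end == 10 gives (max(10, 4) - 0 - 1) /u 4 == 9 /u 4 == 2,
    // matching the two backedges taken at i == 0 and i == 4.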
    if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
        !loopHasNoAbnormalExits(L))
      return getCouldNotCompute();

    // This bailout is protecting the logic in computeMaxBECountForLT which
    // has not yet been sufficiently audited or tested with negative strides.
    // We used to filter out all known-non-positive cases here; we're in the
    // process of being less restrictive bit by bit.
    if (IsSigned && isKnownNonPositive(Stride))
      return getCouldNotCompute();

    if (!isKnownNonZero(Stride)) {
      // If we have a step of zero, and RHS isn't invariant in L, we don't know
      // if it might eventually be greater than start and if so, on which
      // iteration.  We can't even produce a useful upper bound.
      if (!isLoopInvariant(RHS, L))
        return getCouldNotCompute();

      // We allow a potentially zero stride, but we need to divide by stride
      // below.  Since the loop can't be infinite and this check must control
      // the sole exit, we can infer the exit must be taken on the first
      // iteration (e.g. backedge count = 0) if the stride is zero.  Given
      // that, we know the numerator in the divides below must be zero, so we
      // can pick an arbitrary non-zero value for the denominator (e.g. stride)
      // and produce the right result.
      // FIXME: Handle the case where Stride is poison?
      auto wouldZeroStrideBeUB = [&]() {
        // Proof by contradiction.  Suppose the stride were zero.  If we can
        // prove that the backedge *is* taken on the first iteration, then
        // since we know this condition controls the sole exit, we must have
        // an infinite loop.  We can't have a (well defined) infinite loop per
        // check just above.
        // Note: The (Start - Stride) term is used to get the start' term from
        // (start' + stride,+,stride).  Remember that we only care about the
        // result of this expression when stride == 0 at runtime.
        auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
        return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
      };
      if (!wouldZeroStrideBeUB()) {
        Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
      }
    }
  } else if (!Stride->isOne() && !NoWrap) {
    auto isUBOnWrap = [&]() {
      // From no-self-wrap, we need to then prove no-(un)signed-wrap.  This
      // follows trivially from the fact that every (un)signed-wrapped, but
      // not self-wrapped value must be LT than the last value before
      // (un)signed wrap.  Since we know that last value didn't exit, nor
      // will any smaller one.
      return canAssumeNoSelfWrap(IV);
    };

    // Avoid proven overflow cases: this will ensure that the backedge taken
    // count will not generate any unsigned overflow.  Relaxed no-overflow
    // conditions exploit NoWrapFlags, allowing to optimize in presence of
    // undefined behaviors like the case of C language.
    if (canIVOverflowOnLT(RHS, Stride, IsSigned) && !isUBOnWrap())
      return getCouldNotCompute();
  }

  // On all paths just preceding, we established the following invariant:
  // IV can be assumed not to overflow up to and including the exiting
  // iteration.  We proved this in one of two ways:
  // 1) We can show overflow doesn't occur before the exiting iteration
  //    1a) canIVOverflowOnLT, and b) step of one
  // 2) We can show that if overflow occurs, the loop must execute UB
  //    before any possible exit.
  // Note that we have not yet proved RHS invariant (in general).

  const SCEV *Start = IV->getStart();

  // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
  // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
  // Use integer-typed versions for actual computation; we can't subtract
  // pointers in general.
  const SCEV *OrigStart = Start;
  const SCEV *OrigRHS = RHS;
  if (Start->getType()->isPointerTy()) {
    Start = getLosslessPtrToIntExpr(Start);
    if (isa<SCEVCouldNotCompute>(Start))
      return Start;
  }
  if (RHS->getType()->isPointerTy()) {
    RHS = getLosslessPtrToIntExpr(RHS);
    if (isa<SCEVCouldNotCompute>(RHS))
      return RHS;
  }

  // When the RHS is not invariant, we do not know the end bound of the loop
  // and cannot calculate the ExactBECount needed by ExitLimit.  However, we
  // can calculate the MaxBECount, given the start, stride and max value for
  // the end bound of the loop (RHS), and the fact that IV does not overflow
  // (which is checked above).
  if (!isLoopInvariant(RHS, L)) {
    const SCEV *MaxBECount = computeMaxBECountForLT(
        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
    return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
                     false /*MaxOrZero*/, Predicates);
  }

  // We use the expression (max(End,Start)-Start)/Stride to describe the
  // backedge count: if the backedge is taken at least once, max(End,Start)
  // is End and so the result is as above, and if not, max(End,Start) is
  // Start, so we get a backedge count of zero.
  const SCEV *BECount = nullptr;
  auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
  assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
  assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
  assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
  // Can we prove max(RHS,Start) > Start - Stride?
  if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
      isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
    // In this case, we can use a refined formula for computing backedge
    // taken count.  The general formula remains:
    //   "End-Start /uceiling Stride" where "End = max(RHS,Start)"
    // We want to use the alternate formula:
    //   "((End - 1) - (Start - Stride)) /u Stride"
    // Let's do a quick case analysis to show these are equivalent under
    // our precondition that max(RHS,Start) > Start - Stride.
    // * For RHS <= Start, the backedge-taken count must be zero.
    //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
    //   "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
    //   "(Stride - 1) /u Stride" which is indeed zero for all non-zero
    //   values of Stride.  For 0 stride, we've used umax(1, Stride) above,
    //   reducing this to the stride of 1 case.
    // * For RHS >= Start, the backedge count must be
    //   "RHS-Start /uceil Stride".
    //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
    //   "((RHS - 1) - (Start - Stride)) /u Stride" which reassociates to
    //   "((RHS - (Start - Stride) - 1) /u Stride".
    // Our preconditions trivially imply no overflow in that form.
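    //
    // Worked instance (illustrative): Start == 2, Stride == 3, RHS == 11
    // gives ((11 - 1) - (2 - 3)) /u 3 == 11 /u 3 == 3, matching
    // ceil((11 - 2) / 3) == 3 from the general formula.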
    const SCEV *MinusOne = getMinusOne(Stride->getType());
    const SCEV *Numerator =
        getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
    BECount = getUDivExpr(Numerator, Stride);
  }

  const SCEV *BECountIfBackedgeTaken = nullptr;
  if (!BECount) {
    auto canProveRHSGreaterThanEqualStart = [&]() {
      auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
      if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart))
        return true;

      // (RHS > Start - 1) implies RHS >= Start.
      // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
      //   "Start - 1" doesn't overflow.
      // * For signed comparison, if Start - 1 does overflow, it's equal
      //   to INT_MAX, and "RHS >s INT_MAX" is trivially false.
      // * For unsigned comparison, if Start - 1 does overflow, it's equal
      //   to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
      //
      // FIXME: Should isLoopEntryGuardedByCond do this for us?
      auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
      auto *StartMinusOne = getAddExpr(OrigStart,
                                       getMinusOne(OrigStart->getType()));
      return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
    };

    // If we know that RHS >= Start in the context of loop, then we know that
    // max(RHS, Start) = RHS at this point.
    const SCEV *End;
    if (canProveRHSGreaterThanEqualStart()) {
      End = RHS;
    } else {
      // If RHS < Start, the backedge will be taken zero times.  So in
      // general, we can write the backedge-taken count as:
      //
      //     RHS >= Start ? ceil(RHS - Start) / Stride : 0
      //
      // We convert it to the following to make it more convenient for SCEV:
      //
      //     ceil(max(RHS, Start) - Start) / Stride
      End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);

      // See what would happen if we assume the backedge is taken.  This is
      // used to compute MaxBECount.
      BECountIfBackedgeTaken =
          getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
    }

    // At this point, we know:
    //
    // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
    // 2. The index variable doesn't overflow.
    //
    // Therefore, we know N exists such that
    // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
    // doesn't overflow.
    //
    // Using this information, try to prove whether the addition in
    // "(Start - End) + (Stride - 1)" has unsigned overflow.
    const SCEV *One = getOne(Stride->getType());
    bool MayAddOverflow = [&] {
      if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
        if (StrideC->getAPInt().isPowerOf2()) {
          // Suppose Stride is a power of two, and Start/End are unsigned
          // integers.  Let UMAX be the largest representable unsigned
          // integer.
          //
          // By the preconditions of this function, we know
          // "(Start + Stride * N) >= End", and this doesn't overflow.
          // As a formula:
          //
          //   End <= (Start + Stride * N) <= UMAX
          //
          // Subtracting Start from all the terms:
          //
          //   End - Start <= Stride * N <= UMAX - Start
          //
          // Since Start is unsigned, UMAX - Start <= UMAX.  Therefore:
          //
          //   End - Start <= Stride * N <= UMAX
          //
          // Stride * N is a multiple of Stride.  Therefore,
          //
          //   End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
          //
          // Since Stride is a power of two, UMAX + 1 is divisible by
          // Stride.  Therefore, UMAX mod Stride == Stride - 1.  So we can
          // write:
          //
          //   End - Start <= Stride * N <= UMAX - Stride + 1
          //
          // Dropping the middle term:
          //
          //   End - Start <= UMAX - Stride + 1
          //
          // Adding Stride - 1 to both sides:
          //
          //   (End - Start) + (Stride - 1) <= UMAX
          //
          // In other words, the addition doesn't have unsigned overflow.
          //
          // A similar proof works if we treat Start/End as signed values.
          // Just rewrite steps before "End - Start <= Stride * N <= UMAX" to
          // use signed max instead of unsigned max.  Note that we're trying
          // to prove a lack of unsigned overflow in either case.
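          //
          // Numeric spot-check (illustrative, i8, Stride == 8):
          // UMAX == 255, UMAX mod 8 == 7 == Stride - 1, and the largest
          // multiple of 8 that fits is 248 == UMAX - Stride + 1.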
          return false;
        }
      }
      if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
        // If Start is equal to Stride, (End - Start) + (Stride - 1) ==
        // End - 1.
        // If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1 <u End.
        // If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End - 1 <s End.
        //
        // If Start is equal to Stride - 1, (End - Start) + Stride - 1 == End.
        return false;
      }
      return true;
    }();

    const SCEV *Delta = getMinusSCEV(End, Start);
    if (!MayAddOverflow) {
      // floor((D + (S - 1)) / S)
      // We prefer this formulation if it's legal because it's fewer
      // operations.
      BECount =
          getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
    } else {
      BECount = getUDivCeilSCEV(Delta, Stride);
    }
  }

  const SCEV *MaxBECount;
  bool MaxOrZero = false;
  if (isa<SCEVConstant>(BECount)) {
    MaxBECount = BECount;
  } else if (BECountIfBackedgeTaken &&
             isa<SCEVConstant>(BECountIfBackedgeTaken)) {
    // If we know exactly how many times the backedge will be taken if it's
    // taken at least once, then the backedge count will either be that or
    // zero.
    MaxBECount = BECountIfBackedgeTaken;
    MaxOrZero = true;
  } else {
    MaxBECount = computeMaxBECountForLT(
        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
  }

  if (isa<SCEVCouldNotCompute>(MaxBECount) &&
      !isa<SCEVCouldNotCompute>(BECount))
    MaxBECount = getConstant(getUnsignedRangeMax(BECount));

  return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates);
}
ScalarEvolution::ExitLimit
ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
                                     const Loop *L, bool IsSigned,
                                     bool ControlsExit, bool AllowPredicates) {
  SmallPtrSet<const SCEVPredicate *, 4> Predicates;
  // We handle only IV > Invariant.
  if (!isLoopInvariant(RHS, L))
    return getCouldNotCompute();

  const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!IV && AllowPredicates)
    // Try to make this an AddRec using runtime tests, in the first X
    // iterations of this loop, where X is the SCEV expression found by the
    // algorithm below.
    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);

  // Avoid weird loops.
  if (!IV || IV->getLoop() != L || !IV->isAffine())
    return getCouldNotCompute();

  auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
  bool NoWrap = ControlsExit && IV->getNoWrapFlags(WrapType);
  ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;

  const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));

  // Avoid negative or zero stride values.
  if (!isKnownPositive(Stride))
    return getCouldNotCompute();

  // Avoid proven overflow cases: this will ensure that the backedge taken
  // count will not generate any unsigned overflow.  Relaxed no-overflow
  // conditions exploit NoWrapFlags, allowing to optimize in presence of
  // undefined behaviors like the case of C language.
  if (!Stride->isOne() && !NoWrap)
    if (canIVOverflowOnGT(RHS, Stride, IsSigned))
      return getCouldNotCompute();

  const SCEV *Start = IV->getStart();
  const SCEV *End = RHS;
  if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
    // If we know that Start >= RHS in the context of loop, then we know that
    // min(RHS, Start) = RHS at this point.
    if (isLoopEntryGuardedByCond(
            L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
      End = RHS;
    else
      End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
  }

  if (Start->getType()->isPointerTy()) {
    Start = getLosslessPtrToIntExpr(Start);
    if (isa<SCEVCouldNotCompute>(Start))
      return Start;
  }
  if (End->getType()->isPointerTy()) {
    End = getLosslessPtrToIntExpr(End);
    if (isa<SCEVCouldNotCompute>(End))
      return End;
  }

  // Compute ((Start - End) + (Stride - 1)) / Stride.
  // FIXME: This can overflow.  Holding off on fixing this for now;
  // howManyGreaterThans will hopefully be gone soon.
  const SCEV *One = getOne(Stride->getType());
  const SCEV *BECount = getUDivExpr(
      getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);

  APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
                            : getUnsignedRangeMax(Start);

  APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
                             : getUnsignedRangeMin(Stride);

  unsigned BitWidth = getTypeSizeInBits(LHS->getType());
  APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
                         : APInt::getMinValue(BitWidth) + (MinStride - 1);

  // Although End can be a MIN expression we estimate MinEnd considering only
  // the case End = RHS.  This is safe because in the other case (Start - End)
  // is zero, leading to a zero maximum backedge taken count.
  APInt MinEnd =
      IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
               : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);

  const SCEV *MaxBECount = isa<SCEVConstant>(BECount)
                               ? BECount
                               : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
                                                 getConstant(MinStride));

  if (isa<SCEVCouldNotCompute>(MaxBECount))
    MaxBECount = BECount;

  return ExitLimit(BECount, MaxBECount, false, Predicates);
}
const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
                                                    ScalarEvolution &SE) const {
  if (Range.isFullSet())  // Infinite loop.
    return SE.getCouldNotCompute();

  // If the start is a non-zero constant, shift the range to simplify things.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
    if (!SC->getValue()->isZero()) {
      SmallVector<const SCEV *, 4> Operands(operands());
      Operands[0] = SE.getZero(SC->getType());
      const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
                                             getNoWrapFlags(FlagNW));
      if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
        return ShiftedAddRec->getNumIterationsInRange(
            Range.subtract(SC->getAPInt()), SE);
      // This is strange and shouldn't happen.
      return SE.getCouldNotCompute();
    }

  // The only time we can solve this is when we have all constant indices.
  // Otherwise, we cannot determine the overflow conditions.
  if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
    return SE.getCouldNotCompute();

  // Okay, at this point we know that all elements of the chrec are constants
  // and that the start element is zero.

  // First check to see if the range contains zero.  If not, the first
  // iteration exits.
  unsigned BitWidth = SE.getTypeSizeInBits(getType());
  if (!Range.contains(APInt(BitWidth, 0)))
    return SE.getZero(getType());

  if (isAffine()) {
    // If this is an affine expression then we have this situation:
    //   Solve {0,+,A} in Range  ===  Ax in Range
    //
    // We know that zero is in the range.  If A is positive then we know that
    // the upper value of the range must be the first possible exit value.
    // If A is negative then the lower of the range is the last possible loop
    // value.  Also note that we already checked for a full range.
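    //
    // E.g. (illustrative): {0,+,3} in Range [0, 10) has A == 3, End == 9,
    // ExitVal == (9 + 3) /u 3 == 4; the chrec evaluates to 12 at iteration
    // 4 (outside the range) and to 9 at iteration 3 (inside), so 4 is
    // returned below.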
    APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
    APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();

    // The exit value should be (End+A)/A.
    APInt ExitVal = (End + A).udiv(A);
    ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);

    // Evaluate at the exit value.  If we really did fall out of the valid
    // range, then we computed our trip count, otherwise wrap around or other
    // things must have happened.
    ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
    if (Range.contains(Val->getValue()))
      return SE.getCouldNotCompute();  // Something strange happened.

    // Ensure that the previous value is in the range.
    assert(Range.contains(
           EvaluateConstantChrecAtConstant(this,
           ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
           "Linear scev computation is off in a bad way!");
    return SE.getConstant(ExitValue);
  }

  if (isQuadratic()) {
    if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
      return SE.getConstant(S.getValue());
  }

  return SE.getCouldNotCompute();
}
const SCEVAddRecExpr *
SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
  assert(getNumOperands() > 1 && "AddRec with zero step?");
  // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
  // but in this case we cannot guarantee that the value returned will be an
  // AddRec because SCEV does not have a fixed point where it stops
  // simplification: it is legal to return ({rec1} + {rec2}).  For example, it
  // may happen if we reach arithmetic depth limit while simplifying.  So we
  // construct the returned value explicitly.
  SmallVector<const SCEV *, 3> Ops;
  // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
  // (this + Step) is {A+B,+,B+C,+,...,+,N}.
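  // E.g. (illustrative): for {1,+,2} the post-increment expression computed
  // below is {1+2,+,2} == {3,+,2}.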
  for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
    Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
  // We know that the last operand is not a constant zero (otherwise it would
  // have been popped out earlier).  This guarantees us that if the result has
  // the same last operand, then it will also not be popped out, meaning that
  // the returned value will be an AddRec.
  const SCEV *Last = getOperand(getNumOperands() - 1);
  assert(!Last->isZero() && "Recurrence with zero step?");
  Ops.push_back(Last);
  return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
                                               SCEV::FlagAnyWrap));
}
// Return true when S contains at least one undef value.
bool ScalarEvolution::containsUndefs(const SCEV *S) const {
  return SCEVExprContains(S, [](const SCEV *S) {
    if (const auto *SU = dyn_cast<SCEVUnknown>(S))
      return isa<UndefValue>(SU->getValue());
    return false;
  });
}
/// Return the size of an element read or written by Inst.
const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
  Type *Ty;
  if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
    Ty = Store->getValueOperand()->getType();
  else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
    Ty = Load->getType();
  else
    return nullptr;

  Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty));
  return getSizeOfExpr(ETy, Ty);
}
//===----------------------------------------------------------------------===//
//                   SCEVCallbackVH Class Implementation
//===----------------------------------------------------------------------===//
12657 void ScalarEvolution::SCEVCallbackVH::deleted() {
12658 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
12659 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
12660 SE->ConstantEvolutionLoopExitValue.erase(PN);
12661 SE->eraseValueFromMap(getValPtr());
12662 // this now dangles!
12665 void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
12666 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
12668 // Forget all the expressions associated with users of the old value,
12669 // so that future queries will recompute the expressions using the new
12671 Value *Old = getValPtr();
12672 SmallVector<User *, 16> Worklist(Old->users());
12673 SmallPtrSet<User *, 8> Visited;
12674   while (!Worklist.empty()) {
12675     User *U = Worklist.pop_back_val();
12676     // Deleting the Old value will cause this to dangle. Postpone
12677     // that until everything else is done.
12678     if (U == Old)
12679       continue;
12680     if (!Visited.insert(U).second)
12681       continue;
12682     if (PHINode *PN = dyn_cast<PHINode>(U))
12683       SE->ConstantEvolutionLoopExitValue.erase(PN);
12684     SE->eraseValueFromMap(U);
12685     llvm::append_range(Worklist, U->users());
12686   }
12687 // Delete the Old value.
12688 if (PHINode *PN = dyn_cast<PHINode>(Old))
12689 SE->ConstantEvolutionLoopExitValue.erase(PN);
12690 SE->eraseValueFromMap(Old);
12691   // this now dangles!
12692 }
12694 ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
12695 : CallbackVH(V), SE(se) {}
12697 //===----------------------------------------------------------------------===//
12698 // ScalarEvolution Class Implementation
12699 //===----------------------------------------------------------------------===//
12701 ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
12702                                  AssumptionCache &AC, DominatorTree &DT,
12703                                  LoopInfo &LI)
12704     : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
12705 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
12706 LoopDispositions(64), BlockDispositions(64) {
12707 // To use guards for proving predicates, we need to scan every instruction in
12708 // relevant basic blocks, and not just terminators. Doing this is a waste of
12709 // time if the IR does not actually contain any calls to
12710 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
12712 // This pessimizes the case where a pass that preserves ScalarEvolution wants
12713 // to _add_ guards to the module when there weren't any before, and wants
12714 // ScalarEvolution to optimize based on those guards. For now we prefer to be
12715 // efficient in lieu of being smart in that rather obscure case.
12717 auto *GuardDecl = F.getParent()->getFunction(
12718 Intrinsic::getName(Intrinsic::experimental_guard));
12719   HasGuards = GuardDecl && !GuardDecl->use_empty();
12720 }
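// Illustrative IR for the quick check above (assumed shape, not part of this
// file): the module-level declaration
//   declare void @llvm.experimental.guard(i1, ...)
// only has uses when some block actually contains calls such as
//   call void (i1, ...) @llvm.experimental.guard(i1 %cond) [ "deopt"() ]
// so HasGuards is a cheap proxy for "guards may be present somewhere".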
12722 ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
12723 : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
12724 LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
12725 ValueExprMap(std::move(Arg.ValueExprMap)),
12726 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
12727 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
12728 PendingMerges(std::move(Arg.PendingMerges)),
12729 MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)),
12730 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
12731 PredicatedBackedgeTakenCounts(
12732 std::move(Arg.PredicatedBackedgeTakenCounts)),
12733 BECountUsers(std::move(Arg.BECountUsers)),
12734 ConstantEvolutionLoopExitValue(
12735 std::move(Arg.ConstantEvolutionLoopExitValue)),
12736 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
12737 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
12738 LoopDispositions(std::move(Arg.LoopDispositions)),
12739 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
12740 BlockDispositions(std::move(Arg.BlockDispositions)),
12741 SCEVUsers(std::move(Arg.SCEVUsers)),
12742 UnsignedRanges(std::move(Arg.UnsignedRanges)),
12743 SignedRanges(std::move(Arg.SignedRanges)),
12744 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
12745 UniquePreds(std::move(Arg.UniquePreds)),
12746 SCEVAllocator(std::move(Arg.SCEVAllocator)),
12747 LoopUsers(std::move(Arg.LoopUsers)),
12748 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
12749 FirstUnknown(Arg.FirstUnknown) {
12750   Arg.FirstUnknown = nullptr;
12751 }
12753 ScalarEvolution::~ScalarEvolution() {
12754 // Iterate through all the SCEVUnknown instances and call their
12755 // destructors, so that they release their references to their values.
12756   for (SCEVUnknown *U = FirstUnknown; U;) {
12757     SCEVUnknown *Tmp = U;
12758     U = U->Next;
12759     Tmp->~SCEVUnknown();
12760   }
12761   FirstUnknown = nullptr;
12763   ExprValueMap.clear();
12764   ValueExprMap.clear();
12765   HasRecMap.clear();
12766 BackedgeTakenCounts.clear();
12767 PredicatedBackedgeTakenCounts.clear();
12769 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
12770 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
12771 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
12772 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
12773   assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
12774 }
12776 bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
12777   return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
12778 }
12780 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
12781                           const Loop *L) {
12782   // Print all inner loops first
12783   for (Loop *I : *L)
12784     PrintLoopInfo(OS, SE, I);
12786   OS << "Loop ";
12787   L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
12788   OS << ": ";
12790   SmallVector<BasicBlock *, 8> ExitingBlocks;
12791   L->getExitingBlocks(ExitingBlocks);
12792   if (ExitingBlocks.size() != 1)
12793     OS << "<multiple exits> ";
12795   if (SE->hasLoopInvariantBackedgeTakenCount(L))
12796     OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L) << "\n";
12797   else
12798     OS << "Unpredictable backedge-taken count.\n";
12800   if (ExitingBlocks.size() > 1)
12801     for (BasicBlock *ExitingBlock : ExitingBlocks) {
12802       OS << "  exit count for " << ExitingBlock->getName() << ": "
12803          << *SE->getExitCount(L, ExitingBlock) << "\n";
12804     }
12806   OS << "Loop ";
12807   L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
12808   OS << ": ";
12810   if (!isa<SCEVCouldNotCompute>(SE->getConstantMaxBackedgeTakenCount(L))) {
12811     OS << "max backedge-taken count is " << *SE->getConstantMaxBackedgeTakenCount(L);
12812     if (SE->isBackedgeTakenCountMaxOrZero(L))
12813       OS << ", actual taken count either this or zero.";
12814   } else {
12815     OS << "Unpredictable max backedge-taken count. ";
12816   }
12818   OS << "\n"
12819         "Loop ";
12820   L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
12821   OS << ": ";
12823   SCEVUnionPredicate Pred;
12824   auto PBT = SE->getPredicatedBackedgeTakenCount(L, Pred);
12825   if (!isa<SCEVCouldNotCompute>(PBT)) {
12826     OS << "Predicated backedge-taken count is " << *PBT << "\n";
12827     OS << " Predicates:\n";
12828     Pred.print(OS, 4);
12829   } else {
12830     OS << "Unpredictable predicated backedge-taken count. ";
12831   }
12832   OS << "\n";
12834   if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
12835     OS << "Loop ";
12836     L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
12837     OS << ": ";
12838     OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
12839   }
12840 }
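// Sample output (illustrative; the exact SCEVs depend on the input IR):
//   Loop %loop: backedge-taken count is (-1 + %n)
//   Loop %loop: max backedge-taken count is -2
//   Loop %loop: Predicated backedge-taken count is (-1 + %n)
// The trip count of a loop is always one more than its backedge-taken count.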
12842 static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) {
12843   switch (LD) {
12844   case ScalarEvolution::LoopVariant:
12845     return "Variant";
12846   case ScalarEvolution::LoopInvariant:
12847     return "Invariant";
12848   case ScalarEvolution::LoopComputable:
12849     return "Computable";
12850   }
12851   llvm_unreachable("Unknown ScalarEvolution::LoopDisposition kind!");
12852 }
12854 void ScalarEvolution::print(raw_ostream &OS) const {
12855 // ScalarEvolution's implementation of the print method is to print
12856 // out SCEV values of all instructions that are interesting. Doing
12857 // this potentially causes it to create new SCEV objects though,
12858 // which technically conflicts with the const qualifier. This isn't
12859 // observable from outside the class though, so casting away the
12860 // const isn't dangerous.
12861 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
12863   if (ClassifyExpressions) {
12864     OS << "Classifying expressions for: ";
12865     F.printAsOperand(OS, /*PrintType=*/false);
12866     OS << "\n";
12867     for (Instruction &I : instructions(F))
12868       if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
12869         OS << I << '\n';
12870         OS << "  -->  ";
12871         const SCEV *SV = SE.getSCEV(&I);
12872         SV->print(OS);
12873         if (!isa<SCEVCouldNotCompute>(SV)) {
12874           OS << " U: ";
12875           SE.getUnsignedRange(SV).print(OS);
12876           OS << " S: ";
12877           SE.getSignedRange(SV).print(OS);
12878         }
12880         const Loop *L = LI.getLoopFor(I.getParent());
12882         const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
12883         if (AtUse != SV) {
12884           OS << "  -->  ";
12885           AtUse->print(OS);
12886           if (!isa<SCEVCouldNotCompute>(AtUse)) {
12887             OS << " U: ";
12888             SE.getUnsignedRange(AtUse).print(OS);
12889             OS << " S: ";
12890             SE.getSignedRange(AtUse).print(OS);
12891           }
12892         }
12894         if (L) {
12895           OS << "\t\t" "Exits: ";
12896           const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
12897           if (!SE.isLoopInvariant(ExitValue, L)) {
12898             OS << "<<Unknown>>";
12899           } else {
12900             OS << *ExitValue;
12901           }
12903           bool First = true;
12904           for (auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
12905             if (First) {
12906               OS << "\t\t" "LoopDispositions: { ";
12907               First = false;
12908             } else {
12909               OS << ", ";
12910             }
12912             Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
12913             OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, Iter));
12914           }
12916           for (auto *InnerL : depth_first(L)) {
12917             if (InnerL == L)
12918               continue;
12919             if (First) {
12920               OS << "\t\t" "LoopDispositions: { ";
12921               First = false;
12922             } else {
12923               OS << ", ";
12924             }
12926             InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
12927             OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, InnerL));
12928           }
12930           OS << " }";
12931         }
12933         OS << "\n";
12934       }
12935   }
12937   OS << "Determining loop execution counts for: ";
12938   F.printAsOperand(OS, /*PrintType=*/false);
12939   OS << "\n";
12940   for (Loop *I : LI)
12941     PrintLoopInfo(OS, &SE, I);
12942 }
12944 ScalarEvolution::LoopDisposition
12945 ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
12946 auto &Values = LoopDispositions[S];
12947   for (auto &V : Values) {
12948     if (V.getPointer() == L)
12949       return V.getInt();
12950   }
12951   Values.emplace_back(L, LoopVariant);
12952   LoopDisposition D = computeLoopDisposition(S, L);
12953   auto &Values2 = LoopDispositions[S];
12954   for (auto &V : llvm::reverse(Values2)) {
12955     if (V.getPointer() == L) {
12956       V.setInt(D);
12957       break;
12958     }
12959   }
12960   return D;
12961 }
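// Note on the two-step caching above: a LoopVariant placeholder is inserted
// before computeLoopDisposition runs, so any re-entrant query on S
// terminates, and the entry is then located again afterwards because the
// recursive computation may have grown (and reallocated) the SmallVector.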
12963 ScalarEvolution::LoopDisposition
12964 ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
12965   switch (S->getSCEVType()) {
12966   case scConstant:
12967     return LoopInvariant;
12968   case scPtrToInt:
12969   case scTruncate:
12970   case scZeroExtend:
12971   case scSignExtend:
12972     return getLoopDisposition(cast<SCEVCastExpr>(S)->getOperand(), L);
12973   case scAddRecExpr: {
12974     const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
12976     // If L is the addrec's loop, it's computable.
12977     if (AR->getLoop() == L)
12978       return LoopComputable;
12980     // Add recurrences are never invariant in the function-body (null loop).
12981     if (!L)
12982       return LoopVariant;
12984     // Everything that is not defined at loop entry is variant.
12985     if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
12986       return LoopVariant;
12987     assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
12988            " dominate the contained loop's header?");
12990     // This recurrence is invariant w.r.t. L if AR's loop contains L.
12991     if (AR->getLoop()->contains(L))
12992       return LoopInvariant;
12994     // This recurrence is variant w.r.t. L if any of its operands
12995     // are variant.
12996     for (auto *Op : AR->operands())
12997       if (!isLoopInvariant(Op, L))
12998         return LoopVariant;
13000     // Otherwise it's loop-invariant.
13001     return LoopInvariant;
13002   }
13003   case scAddExpr:
13004   case scMulExpr:
13005   case scUMaxExpr:
13006   case scSMaxExpr:
13007   case scUMinExpr:
13008   case scSMinExpr:
13009   case scSequentialUMinExpr: {
13010     bool HasVarying = false;
13011     for (auto *Op : cast<SCEVNAryExpr>(S)->operands()) {
13012       LoopDisposition D = getLoopDisposition(Op, L);
13013       if (D == LoopVariant)
13014         return LoopVariant;
13015       if (D == LoopComputable)
13016         HasVarying = true;
13017     }
13018     return HasVarying ? LoopComputable : LoopInvariant;
13019   }
13020   case scUDivExpr: {
13021     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
13022     LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L);
13023     if (LD == LoopVariant)
13024       return LoopVariant;
13025     LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L);
13026     if (RD == LoopVariant)
13027       return LoopVariant;
13028     return (LD == LoopInvariant && RD == LoopInvariant) ?
13029            LoopInvariant : LoopComputable;
13030   }
13031   case scUnknown:
13032     // All non-instruction values are loop invariant. All instructions are loop
13033     // invariant if they are not contained in the specified loop.
13034     // Instructions are never considered invariant in the function body
13035     // (null loop) because they are defined within the "loop".
13036     if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
13037       return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
13038     return LoopInvariant;
13039   case scCouldNotCompute:
13040     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
13041   }
13042   llvm_unreachable("Unknown SCEV kind!");
13043 }
13045 bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
13046   return getLoopDisposition(S, L) == LoopInvariant;
13047 }
13049 bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
13050   return getLoopDisposition(S, L) == LoopComputable;
13051 }
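// Illustrative example: for the AddRec {0,+,1}<%L>, getLoopDisposition
// returns LoopComputable w.r.t. %L itself, LoopInvariant w.r.t. any loop
// nested inside %L, and LoopVariant w.r.t. a loop enclosing %L, matching the
// case analysis in computeLoopDisposition above.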
13053 ScalarEvolution::BlockDisposition
13054 ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13055   auto &Values = BlockDispositions[S];
13056   for (auto &V : Values) {
13057     if (V.getPointer() == BB)
13058       return V.getInt();
13059   }
13060   Values.emplace_back(BB, DoesNotDominateBlock);
13061   BlockDisposition D = computeBlockDisposition(S, BB);
13062   auto &Values2 = BlockDispositions[S];
13063   for (auto &V : llvm::reverse(Values2)) {
13064     if (V.getPointer() == BB) {
13065       V.setInt(D);
13066       break;
13067     }
13068   }
13069   return D;
13070 }
13072 ScalarEvolution::BlockDisposition
13073 ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13074   switch (S->getSCEVType()) {
13075   case scConstant:
13076     return ProperlyDominatesBlock;
13077   case scPtrToInt:
13078   case scTruncate:
13079   case scZeroExtend:
13080   case scSignExtend:
13081     return getBlockDisposition(cast<SCEVCastExpr>(S)->getOperand(), BB);
13082   case scAddRecExpr: {
13083     // This uses a "dominates" query instead of "properly dominates" query
13084     // to test for proper dominance too, because the instruction which
13085     // produces the addrec's value is a PHI, and a PHI effectively properly
13086     // dominates its entire containing block.
13087     const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13088     if (!DT.dominates(AR->getLoop()->getHeader(), BB))
13089       return DoesNotDominateBlock;
13091     // Fall through into SCEVNAryExpr handling.
13092     LLVM_FALLTHROUGH;
13093   }
13094   case scAddExpr:
13095   case scMulExpr:
13096   case scUMaxExpr:
13097   case scSMaxExpr:
13098   case scUMinExpr:
13099   case scSMinExpr:
13100   case scSequentialUMinExpr: {
13101     const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
13102     bool Proper = true;
13103     for (const SCEV *NAryOp : NAry->operands()) {
13104       BlockDisposition D = getBlockDisposition(NAryOp, BB);
13105       if (D == DoesNotDominateBlock)
13106         return DoesNotDominateBlock;
13107       if (D == DominatesBlock)
13108         Proper = false;
13109     }
13110     return Proper ? ProperlyDominatesBlock : DominatesBlock;
13111   }
13112   case scUDivExpr: {
13113     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
13114     const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS();
13115     BlockDisposition LD = getBlockDisposition(LHS, BB);
13116     if (LD == DoesNotDominateBlock)
13117       return DoesNotDominateBlock;
13118     BlockDisposition RD = getBlockDisposition(RHS, BB);
13119     if (RD == DoesNotDominateBlock)
13120       return DoesNotDominateBlock;
13121     return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ?
13122            ProperlyDominatesBlock : DominatesBlock;
13123   }
13124   case scUnknown:
13125     if (Instruction *I =
13126             dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
13127       if (I->getParent() == BB)
13128         return DominatesBlock;
13129       if (DT.properlyDominates(I->getParent(), BB))
13130         return ProperlyDominatesBlock;
13131       return DoesNotDominateBlock;
13132     }
13133     return ProperlyDominatesBlock;
13134   case scCouldNotCompute:
13135     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
13136   }
13137   llvm_unreachable("Unknown SCEV kind!");
13138 }
13140 bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
13141   return getBlockDisposition(S, BB) >= DominatesBlock;
13142 }
13144 bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
13145   return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
13146 }
13148 bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
13149   return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
13150 }
13152 void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
13153                                                 bool Predicated) {
13154   auto &BECounts =
13155       Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
13156   auto It = BECounts.find(L);
13157   if (It != BECounts.end()) {
13158     for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
13159       if (!isa<SCEVConstant>(ENT.ExactNotTaken)) {
13160         auto UserIt = BECountUsers.find(ENT.ExactNotTaken);
13161         assert(UserIt != BECountUsers.end());
13162         UserIt->second.erase({L, Predicated});
13163       }
13164     }
13165     BECounts.erase(It);
13166   }
13167 }
13169 void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
13170 SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
13171 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
13173 while (!Worklist.empty()) {
13174 const SCEV *Curr = Worklist.pop_back_val();
13175 auto Users = SCEVUsers.find(Curr);
13176 if (Users != SCEVUsers.end())
13177 for (auto *User : Users->second)
13178 if (ToForget.insert(User).second)
13179           Worklist.push_back(User);
13180   }
13182 for (auto *S : ToForget)
13183 forgetMemoizedResultsImpl(S);
13185   for (auto I = PredicatedSCEVRewrites.begin();
13186        I != PredicatedSCEVRewrites.end();) {
13187     std::pair<const SCEV *, const Loop *> Entry = I->first;
13188     if (ToForget.count(Entry.first))
13189       PredicatedSCEVRewrites.erase(I++);
13190     else
13191       ++I;
13192   }
13193 }
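// Illustrative example: forgetting %n also forgets (1 + %n) and any cached
// backedge-taken count built from it, because SCEVUsers stores the reverse
// dependency edges (%n -> (1 + %n) -> ...) that the worklist above follows
// transitively before the per-expression state is dropped.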
13195 void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
13196 LoopDispositions.erase(S);
13197 BlockDispositions.erase(S);
13198 UnsignedRanges.erase(S);
13199 SignedRanges.erase(S);
13200 HasRecMap.erase(S);
13201 MinTrailingZerosCache.erase(S);
13203   auto ExprIt = ExprValueMap.find(S);
13204   if (ExprIt != ExprValueMap.end()) {
13205     for (auto &ValueAndOffset : ExprIt->second) {
13206       if (ValueAndOffset.second == nullptr) {
13207         auto ValueIt = ValueExprMap.find_as(ValueAndOffset.first);
13208         if (ValueIt != ValueExprMap.end())
13209           ValueExprMap.erase(ValueIt);
13210       }
13211     }
13212     ExprValueMap.erase(ExprIt);
13213   }
13215 auto ScopeIt = ValuesAtScopes.find(S);
13216 if (ScopeIt != ValuesAtScopes.end()) {
13217 for (const auto &Pair : ScopeIt->second)
13218 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
13219 erase_value(ValuesAtScopesUsers[Pair.second],
13220 std::make_pair(Pair.first, S));
13221     ValuesAtScopes.erase(ScopeIt);
13222   }
13224 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
13225 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
13226 for (const auto &Pair : ScopeUserIt->second)
13227 erase_value(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
13228     ValuesAtScopesUsers.erase(ScopeUserIt);
13229   }
13231 auto BEUsersIt = BECountUsers.find(S);
13232 if (BEUsersIt != BECountUsers.end()) {
13233 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
13234 auto Copy = BEUsersIt->second;
13235 for (const auto &Pair : Copy)
13236 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
13237     BECountUsers.erase(BEUsersIt);
13238   }
13239 }
13241 void
13242 ScalarEvolution::getUsedLoops(const SCEV *S,
13243 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
13244 struct FindUsedLoops {
13245 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
13246 : LoopsUsed(LoopsUsed) {}
13247 SmallPtrSetImpl<const Loop *> &LoopsUsed;
13248     bool follow(const SCEV *S) {
13249       if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
13250         LoopsUsed.insert(AR->getLoop());
13251       return true;
13252     }
13254     bool isDone() const { return false; }
13255   };
13257   FindUsedLoops F(LoopsUsed);
13258   SCEVTraversal<FindUsedLoops>(F).visitAll(S);
13259 }
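// Illustrative example: for ((1 + %n) + {0,+,1}<%L1> + {0,+,2}<%L2>), the
// traversal above inserts %L1 and %L2 into LoopsUsed; follow() records the
// loop of every AddRec it visits and never terminates early (isDone is
// always false).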
13261 void ScalarEvolution::verify() const {
13262 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13263 ScalarEvolution SE2(F, TLI, AC, DT, LI);
13265 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
13267   // Maps SCEV expressions from one ScalarEvolution "universe" to another.
13268   struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
13269     SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
13271     const SCEV *visitConstant(const SCEVConstant *Constant) {
13272       return SE.getConstant(Constant->getAPInt());
13273     }
13275     const SCEV *visitUnknown(const SCEVUnknown *Expr) {
13276       return SE.getUnknown(Expr->getValue());
13277     }
13279     const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
13280       return SE.getCouldNotCompute();
13281     }
13282   };
13284   SCEVMapper SCM(SE2);
13286 while (!LoopStack.empty()) {
13287 auto *L = LoopStack.pop_back_val();
13288 llvm::append_range(LoopStack, *L);
13290 auto *CurBECount = SCM.visit(
13291 const_cast<ScalarEvolution *>(this)->getBackedgeTakenCount(L));
13292 auto *NewBECount = SE2.getBackedgeTakenCount(L);
13294 if (CurBECount == SE2.getCouldNotCompute() ||
13295 NewBECount == SE2.getCouldNotCompute()) {
13296       // NB! This situation is legal, but is very suspicious -- whatever pass
13297       // changed the loop to make a trip count go from could-not-compute to
13298       // computable or vice-versa *should have* invalidated SCEV. However, we
13299       // choose not to assert here (for now) since we don't want false
13300       // positives.
13301       continue;
13302     }
13304     if (containsUndefs(CurBECount) || containsUndefs(NewBECount)) {
13305 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
13306 // not propagate undef aggressively). This means we can (and do) fail
13307 // verification in cases where a transform makes the trip count of a loop
13308 // go from "undef" to "undef+1" (say). The transform is fine, since in
13309 // both cases the loop iterates "undef" times, but SCEV thinks we
13310       // increased the trip count of the loop by 1 incorrectly.
13311       continue;
13312     }
13314 if (SE.getTypeSizeInBits(CurBECount->getType()) >
13315 SE.getTypeSizeInBits(NewBECount->getType()))
13316 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
13317 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
13318 SE.getTypeSizeInBits(NewBECount->getType()))
13319 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
13321 const SCEV *Delta = SE2.getMinusSCEV(CurBECount, NewBECount);
13323 // Unless VerifySCEVStrict is set, we only compare constant deltas.
13324 if ((VerifySCEVStrict || isa<SCEVConstant>(Delta)) && !Delta->isZero()) {
13325 dbgs() << "Trip Count for " << *L << " Changed!\n";
13326 dbgs() << "Old: " << *CurBECount << "\n";
13327 dbgs() << "New: " << *NewBECount << "\n";
13328 dbgs() << "Delta: " << *Delta << "\n";
13333 // Collect all valid loops currently in LoopInfo.
13334 SmallPtrSet<Loop *, 32> ValidLoops;
13335 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
13336   while (!Worklist.empty()) {
13337     Loop *L = Worklist.pop_back_val();
13338     if (ValidLoops.contains(L))
13339       continue;
13340     ValidLoops.insert(L);
13341     Worklist.append(L->begin(), L->end());
13342   }
13343   for (auto &KV : ValueExprMap) {
13344 #ifndef NDEBUG
13345 // Check for SCEV expressions referencing invalid/deleted loops.
13346 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
13347       assert(ValidLoops.contains(AR->getLoop()) &&
13348              "AddRec references invalid loop");
13349     }
13350 #endif
13352     // Check that the value is also part of the reverse map.
13353     auto It = ExprValueMap.find(KV.second);
13354     if (It == ExprValueMap.end() || !It->second.contains({KV.first, nullptr})) {
13355       dbgs() << "Value " << *KV.first
13356              << " is in ValueExprMap but not in ExprValueMap\n";
13357       std::abort();
13358     }
13359   }
13361 for (const auto &KV : ExprValueMap) {
13362 for (const auto &ValueAndOffset : KV.second) {
13363       if (ValueAndOffset.second != nullptr)
13364         continue;
13366       auto It = ValueExprMap.find_as(ValueAndOffset.first);
13367       if (It == ValueExprMap.end()) {
13368         dbgs() << "Value " << *ValueAndOffset.first
13369                << " is in ExprValueMap but not in ValueExprMap\n";
13370         std::abort();
13371       }
13372       if (It->second != KV.first) {
13373         dbgs() << "Value " << *ValueAndOffset.first
13374                << " mapped to " << *It->second
13375                << " rather than " << *KV.first << "\n";
13376         std::abort();
13377       }
13378     }
13379   }
13381 // Verify integrity of SCEV users.
13382 for (const auto &S : UniqueSCEVs) {
13383 SmallVector<const SCEV *, 4> Ops;
13384 collectUniqueOps(&S, Ops);
13385     for (const auto *Op : Ops) {
13386       // We do not store dependencies of constants.
13387       if (isa<SCEVConstant>(Op))
13388         continue;
13389       auto It = SCEVUsers.find(Op);
13390       if (It != SCEVUsers.end() && It->second.count(&S))
13391         continue;
13392       dbgs() << "Use of operand " << *Op << " by user " << S
13393              << " is not being tracked!\n";
13394       std::abort();
13395     }
13396   }
13398 // Verify integrity of ValuesAtScopes users.
13399 for (const auto &ValueAndVec : ValuesAtScopes) {
13400 const SCEV *Value = ValueAndVec.first;
13401 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
13402 const Loop *L = LoopAndValueAtScope.first;
13403 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
13404       if (!isa<SCEVConstant>(ValueAtScope)) {
13405         auto It = ValuesAtScopesUsers.find(ValueAtScope);
13406         if (It != ValuesAtScopesUsers.end() &&
13407             is_contained(It->second, std::make_pair(L, Value)))
13408           continue;
13409         dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
13410                << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
13411         std::abort();
13412       }
13413     }
13414   }
13416 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
13417 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
13418 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
13419 const Loop *L = LoopAndValue.first;
13420 const SCEV *Value = LoopAndValue.second;
13421 assert(!isa<SCEVConstant>(Value));
13422       auto It = ValuesAtScopes.find(Value);
13423       if (It != ValuesAtScopes.end() &&
13424           is_contained(It->second, std::make_pair(L, ValueAtScope)))
13425         continue;
13426       dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
13427              << *ValueAtScope << " missing in ValuesAtScopes\n";
13428       std::abort();
13429     }
13430   }
13432 // Verify integrity of BECountUsers.
13433   auto VerifyBECountUsers = [&](bool Predicated) {
13434     auto &BECounts =
13435         Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
13436     for (const auto &LoopAndBEInfo : BECounts) {
13437       for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
13438         if (!isa<SCEVConstant>(ENT.ExactNotTaken)) {
13439           auto UserIt = BECountUsers.find(ENT.ExactNotTaken);
13440           if (UserIt != BECountUsers.end() &&
13441               UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
13442             continue;
13443           dbgs() << "Value " << *ENT.ExactNotTaken << " for loop "
13444                  << *LoopAndBEInfo.first << " missing from BECountUsers\n";
13445           std::abort();
13446         }
13447       }
13448     }
13449   };
13450   VerifyBECountUsers(/* Predicated */ false);
13451   VerifyBECountUsers(/* Predicated */ true);
13452 }
13454 bool ScalarEvolution::invalidate(
13455 Function &F, const PreservedAnalyses &PA,
13456 FunctionAnalysisManager::Invalidator &Inv) {
13457 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
13458 // of its dependencies is invalidated.
13459 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
13460 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
13461 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
13462 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
13463          Inv.invalidate<LoopAnalysis>(F, PA);
13464 }
13466 AnalysisKey ScalarEvolutionAnalysis::Key;
13468 ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
13469 FunctionAnalysisManager &AM) {
13470 return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F),
13471 AM.getResult<AssumptionAnalysis>(F),
13472 AM.getResult<DominatorTreeAnalysis>(F),
13473                          AM.getResult<LoopAnalysis>(F));
13474 }
13476 PreservedAnalyses
13477 ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
13478 AM.getResult<ScalarEvolutionAnalysis>(F).verify();
13479   return PreservedAnalyses::all();
13480 }
13482 PreservedAnalyses
13483 ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
13484 // For compatibility with opt's -analyze feature under legacy pass manager
13485 // which was not ported to NPM. This keeps tests using
13486 // update_analyze_test_checks.py working.
13487 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
13488 << F.getName() << "':\n";
13489 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
13490   return PreservedAnalyses::all();
13491 }
13493 INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution",
13494 "Scalar Evolution Analysis", false, true)
13495 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
13496 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
13497 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
13498 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
13499 INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
13500 "Scalar Evolution Analysis", false, true)
13502 char ScalarEvolutionWrapperPass::ID = 0;
13504 ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {
13505   initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry());
13506 }
13508 bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
13509   SE.reset(new ScalarEvolution(
13510       F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
13511       getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
13512       getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
13513       getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
13514   return false;
13515 }
13517 void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
13519 void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
13520   SE->print(OS);
13521 }
13523 void ScalarEvolutionWrapperPass::verifyAnalysis() const {
13524   if (!VerifySCEV)
13525     return;
13527   SE->verify();
13528 }
13530 void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
13531 AU.setPreservesAll();
13532 AU.addRequiredTransitive<AssumptionCacheTracker>();
13533 AU.addRequiredTransitive<LoopInfoWrapperPass>();
13534 AU.addRequiredTransitive<DominatorTreeWrapperPass>();
13535   AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
13536 }
13538 const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
13539                                                         const SCEV *RHS) {
13540   FoldingSetNodeID ID;
13541 assert(LHS->getType() == RHS->getType() &&
13542 "Type mismatch between LHS and RHS");
13543 // Unique this node based on the arguments
13544 ID.AddInteger(SCEVPredicate::P_Equal);
13545 ID.AddPointer(LHS);
13546 ID.AddPointer(RHS);
13547 void *IP = nullptr;
13548   if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
13549     return S;
13550   SCEVEqualPredicate *Eq = new (SCEVAllocator)
13551       SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS);
13552   UniquePreds.InsertNode(Eq, IP);
13553   return Eq;
13554 }
13556 const SCEVPredicate *ScalarEvolution::getWrapPredicate(
13557 const SCEVAddRecExpr *AR,
13558 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
13559 FoldingSetNodeID ID;
13560 // Unique this node based on the arguments
13561   ID.AddInteger(SCEVPredicate::P_Wrap);
13562   ID.AddPointer(AR);
13563   ID.AddInteger(AddedFlags);
13564   void *IP = nullptr;
13565   if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
13566     return S;
13567   auto *OF = new (SCEVAllocator)
13568       SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
13569   UniquePreds.InsertNode(OF, IP);
13570   return OF;
13571 }
13573 namespace {
13575 class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
13576 public:
13578 /// Rewrites \p S in the context of a loop L and the SCEV predication
13579 /// infrastructure.
13581 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
13582 /// equivalences present in \p Pred.
13584 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
13585 /// \p NewPreds such that the result will be an AddRecExpr.
13586 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
13587 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
13588 SCEVUnionPredicate *Pred) {
13589 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
13590     return Rewriter.visit(S);
13591   }
13593   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
13594     if (Pred) {
13595       auto ExprPreds = Pred->getPredicatesForExpr(Expr);
13596       for (auto *Pred : ExprPreds)
13597         if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
13598           if (IPred->getLHS() == Expr)
13599             return IPred->getRHS();
13600     }
13601     return convertToAddRecWithPreds(Expr);
13602   }
13604 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
13605 const SCEV *Operand = visit(Expr->getOperand());
13606 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
13607 if (AR && AR->getLoop() == L && AR->isAffine()) {
13608 // This couldn't be folded because the operand didn't have the nuw
13609 // flag. Add the nusw flag as an assumption that we could make.
13610 const SCEV *Step = AR->getStepRecurrence(SE);
13611 Type *Ty = Expr->getType();
13612 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
13613 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
13614 SE.getSignExtendExpr(Step, Ty), L,
13615                               AR->getNoWrapFlags());
13616     }
13617     return SE.getZeroExtendExpr(Operand, Expr->getType());
13618   }
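  // Illustrative example: visiting (zext i32 {0,+,1}<%L> to i64) queues an
  // IncrementNUSW wrap predicate for {0,+,1}<%L> and returns the i64 AddRec
  // {0,+,1}<%L>, so the expression stays affine under the collected
  // predicates instead of remaining an opaque zext.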
13620 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
13621 const SCEV *Operand = visit(Expr->getOperand());
13622 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
13623 if (AR && AR->getLoop() == L && AR->isAffine()) {
13624 // This couldn't be folded because the operand didn't have the nsw
13625 // flag. Add the nssw flag as an assumption that we could make.
13626 const SCEV *Step = AR->getStepRecurrence(SE);
13627 Type *Ty = Expr->getType();
13628 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
13629 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
13630 SE.getSignExtendExpr(Step, Ty), L,
13631                               AR->getNoWrapFlags());
13632     }
13633     return SE.getSignExtendExpr(Operand, Expr->getType());
13634   }
13636 private:
13637   explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
13638 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
13639 SCEVUnionPredicate *Pred)
13640 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
13642   bool addOverflowAssumption(const SCEVPredicate *P) {
13643     if (!NewPreds) {
13644       // Check if we've already made this assumption.
13645       return Pred && Pred->implies(P);
13646     }
13647     NewPreds->insert(P);
13648     return true;
13649   }
13651 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
13652 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
13653 auto *A = SE.getWrapPredicate(AR, AddedFlags);
13654     return addOverflowAssumption(A);
13655   }
13657 // If \p Expr represents a PHINode, we try to see if it can be represented
13658 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
13659 // to add this predicate as a runtime overflow check, we return the AddRec.
13660 // If \p Expr does not meet these conditions (is not a PHI node, or we
13661   // couldn't create an AddRec for it, or couldn't add the predicate), we just
13662   // return \p Expr.
13663   const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
13664     if (!isa<PHINode>(Expr->getValue()))
13665       return Expr;
13666     Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
13667         PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
13668     if (!PredicatedRewrite)
13669       return Expr;
13670     for (auto *P : PredicatedRewrite->second){
13671       // Wrap predicates from outer loops are not supported.
13672       if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
13673         auto *AR = cast<const SCEVAddRecExpr>(WP->getExpr());
13674         if (L != AR->getLoop())
13675           return Expr;
13676       }
13677       if (!addOverflowAssumption(P))
13678         return Expr;
13679     }
13680     return PredicatedRewrite->first;
13681   }
13683 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
13684   SCEVUnionPredicate *Pred;
13685   const Loop *L;
13686 };
13688 } // end anonymous namespace
13690 const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
13691 SCEVUnionPredicate &Preds) {
13692   return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
13693 }
13695 const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
13696 const SCEV *S, const Loop *L,
13697 SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
13698 SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
13699 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
13700   auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
13701   if (!AddRec)
13702     return nullptr;
13705   // Since the transformation was successful, we can now transfer the SCEV
13706   // predicates.
13707   for (auto *P : TransformPreds)
13708     Preds.insert(P);
13710   return AddRec;
13711 }
13713 /// SCEV predicates
13714 SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
13715 SCEVPredicateKind Kind)
13716 : FastID(ID), Kind(Kind) {}
13718 SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
13719 const SCEV *LHS, const SCEV *RHS)
13720 : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {
13721 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
13722   assert(LHS != RHS && "LHS and RHS are the same SCEV");
13723 }
13725 bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
13726   const auto *Op = dyn_cast<SCEVEqualPredicate>(N);
13728   if (!Op)
13729     return false;
13731   return Op->LHS == LHS && Op->RHS == RHS;
13732 }
13734 bool SCEVEqualPredicate::isAlwaysTrue() const { return false; }
13736 const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; }
13738 void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
13739   OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
13740 }
13742 SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
13743 const SCEVAddRecExpr *AR,
13744 IncrementWrapFlags Flags)
13745 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
13747 const SCEV *SCEVWrapPredicate::getExpr() const { return AR; }
13749 bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
13750 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
13752   return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
13753 }
13755 bool SCEVWrapPredicate::isAlwaysTrue() const {
13756 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
13757 IncrementWrapFlags IFlags = Flags;
13759 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
13760 IFlags = clearFlags(IFlags, IncrementNSSW);
13762   return IFlags == IncrementAnyWrap;
13763 }
13765 void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
13766 OS.indent(Depth) << *getExpr() << " Added Flags: ";
13767   if (SCEVWrapPredicate::IncrementNUSW & getFlags())
13768     OS << "<nusw>";
13769   if (SCEVWrapPredicate::IncrementNSSW & getFlags())
13770     OS << "<nssw>";
13771   OS << "\n";
13772 }
13774 SCEVWrapPredicate::IncrementWrapFlags
13775 SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
13776 ScalarEvolution &SE) {
13777 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
13778 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
13780 // We can safely transfer the NSW flag as NSSW.
13781 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
13782 ImpliedFlags = IncrementNSSW;
13784 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
13785 // If the increment is positive, the SCEV NUW flag will also imply the
13786 // WrapPredicate NUSW flag.
13787 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
13788 if (Step->getValue()->getValue().isNonNegative())
13789         ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
13790   }
13792   return ImpliedFlags;
13793 }
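// Illustrative example: for {0,+,1}<nuw><nsw><%L>, NSW yields IncrementNSSW,
// and NUW together with the non-negative constant step 1 additionally yields
// IncrementNUSW, so a wrap predicate over either flag is statically always
// true for this recurrence.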
13795 /// Union predicates don't get cached so create a dummy set ID for it.
13796 SCEVUnionPredicate::SCEVUnionPredicate()
13797 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
13799 bool SCEVUnionPredicate::isAlwaysTrue() const {
13800 return all_of(Preds,
13801                 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
13802 }
13804 ArrayRef<const SCEVPredicate *>
13805 SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
13806 auto I = SCEVToPreds.find(Expr);
13807 if (I == SCEVToPreds.end())
13808     return ArrayRef<const SCEVPredicate *>();
13809   return I->second;
13810 }
13812 bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
13813 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
13814 return all_of(Set->Preds,
13815 [this](const SCEVPredicate *I) { return this->implies(I); });
13817   auto ScevPredsIt = SCEVToPreds.find(N->getExpr());
13818   if (ScevPredsIt == SCEVToPreds.end())
13819     return false;
13820   auto &SCEVPreds = ScevPredsIt->second;
13822   return any_of(SCEVPreds,
13823                 [N](const SCEVPredicate *I) { return I->implies(N); });
13824 }
13826 const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; }
13828 void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
13829 for (auto Pred : Preds)
13830     Pred->print(OS, Depth);
13831 }
13833 void SCEVUnionPredicate::add(const SCEVPredicate *N) {
13834 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
13835     for (auto Pred : Set->Preds)
13836       add(Pred);
13837     return;
13838   }
13840   if (implies(N))
13841     return;
13843   const SCEV *Key = N->getExpr();
13844 assert(Key && "Only SCEVUnionPredicate doesn't have an "
13845 " associated expression!");
13847   SCEVToPreds[Key].push_back(N);
13848   Preds.push_back(N);
13849 }
13851 PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
13852                                                      Loop &L)
13853     : SE(SE), L(L) {}
13855 void ScalarEvolution::registerUser(const SCEV *User,
13856 ArrayRef<const SCEV *> Ops) {
13857 for (auto *Op : Ops)
13858 // We do not expect that forgetting cached data for SCEVConstants will ever
13859 // open any prospects for sharpening or introduce any correctness issues,
13860 // so we don't bother storing their dependencies.
13861 if (!isa<SCEVConstant>(Op))
13862       SCEVUsers[Op].insert(User);
13863 }
13865 const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
13866 const SCEV *Expr = SE.getSCEV(V);
13867 RewriteEntry &Entry = RewriteMap[Expr];
13869 // If we already have an entry and the version matches, return it.
13870 if (Entry.second && Generation == Entry.first)
13871 return Entry.second;
13873 // We found an entry but it's stale. Rewrite the stale entry
13874   // according to the current predicate.
13875   if (Entry.second)
13876     Expr = Entry.second;
13878   const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds);
13879   Entry = {Generation, NewSCEV};
13881   return NewSCEV;
13882 }
13884 const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
13885 if (!BackedgeCount) {
13886 SCEVUnionPredicate BackedgePred;
13887 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, BackedgePred);
13888     addPredicate(BackedgePred);
13889   }
13890   return BackedgeCount;
13891 }
13893 void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
13894   if (Preds.implies(&Pred))
13895     return;
13896   Preds.add(&Pred);
13897   updateGeneration();
13898 }
13900 const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const {
13901   return Preds;
13902 }
13904 void PredicatedScalarEvolution::updateGeneration() {
13905 // If the generation number wrapped recompute everything.
13906 if (++Generation == 0) {
13907 for (auto &II : RewriteMap) {
13908 const SCEV *Rewritten = II.second.second;
13909       II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)};
13910     }
13911   }
13912 }
13914 void PredicatedScalarEvolution::setNoOverflow(
13915 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
13916 const SCEV *Expr = getSCEV(V);
13917 const auto *AR = cast<SCEVAddRecExpr>(Expr);
13919 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
13921 // Clear the statically implied flags.
13922 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
13923 addPredicate(*SE.getWrapPredicate(AR, Flags));
13925   auto II = FlagsMap.insert({V, Flags});
13926   if (!II.second)
13927     II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
13928 }
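// Illustrative example: setNoOverflow(%iv, IncrementNUSW) on an AddRec that
// is already <nuw> with a non-negative constant step adds a predicate that
// isAlwaysTrue(), because getImpliedFlags clears the statically known part
// first; only genuinely new wrap assumptions survive into Preds.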
13930 bool PredicatedScalarEvolution::hasNoOverflow(
13931 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
13932 const SCEV *Expr = getSCEV(V);
13933 const auto *AR = cast<SCEVAddRecExpr>(Expr);
13935 Flags = SCEVWrapPredicate::clearFlags(
13936 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
13938 auto II = FlagsMap.find(V);
13940 if (II != FlagsMap.end())
13941 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
13943   return Flags == SCEVWrapPredicate::IncrementAnyWrap;
13944 }
13946 const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
13947 const SCEV *Expr = this->getSCEV(V);
13948 SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
13949   auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
13951   if (!New)
13952     return nullptr;
13954   for (auto *P : NewPreds)
13955     Preds.add(P);
13957   updateGeneration();
13958   RewriteMap[SE.getSCEV(V)] = {Generation, New};
13959   return New;
13960 }
13962 PredicatedScalarEvolution::PredicatedScalarEvolution(
13963 const PredicatedScalarEvolution &Init)
13964 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds),
13965 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
13966 for (auto I : Init.FlagsMap)
13967     FlagsMap.insert(I);
13968 }
13970 void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
13971   // For each block.
13972   for (auto *BB : L.getBlocks())
13973     for (auto &I : *BB) {
13974       if (!SE.isSCEVable(I.getType()))
13975         continue;
13977       auto *Expr = SE.getSCEV(&I);
13978       auto II = RewriteMap.find(Expr);
13980       if (II == RewriteMap.end())
13981         continue;
13983       // Don't print things that are not interesting.
13984       if (II->second.second == Expr)
13985         continue;
13987       OS.indent(Depth) << "[PSE]" << I << ":\n";
13988       OS.indent(Depth + 2) << *Expr << "\n";
13989       OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
13990     }
13991 }
13993 // Match the mathematical pattern A - (A / B) * B, where A and B can be
13994 // arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
13995 // for URem with constant power-of-2 second operands.
13996 // It's not always easy, as A and B can be folded; if A is X / 2 and B is 4,
13997 // then A / B becomes X / 8.
13998 bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
13999 const SCEV *&RHS) {
14000 // Try to match 'zext (trunc A to iB) to iY', which is used
14001 // for URem with constant power-of-2 second operands. Make sure the size of
14002 // the operand A matches the size of the whole expressions.
14003 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
14004 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
14005 LHS = Trunc->getOperand();
14006 // Bail out if the type of the LHS is larger than the type of the
14007 // expression for now.
14008       if (getTypeSizeInBits(LHS->getType()) >
14009           getTypeSizeInBits(Expr->getType()))
14010         return false;
14011       if (LHS->getType() != Expr->getType())
14012         LHS = getZeroExtendExpr(LHS, Expr->getType());
14013       RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
14014                         << getTypeSizeInBits(Trunc->getType()));
14015       return true;
14016     }
14017   const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
14018   if (Add == nullptr || Add->getNumOperands() != 2)
14019     return false;
14021   const SCEV *A = Add->getOperand(1);
14022 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
14024   if (Mul == nullptr)
14025     return false;
14027 const auto MatchURemWithDivisor = [&](const SCEV *B) {
14028 // (SomeExpr + (-(SomeExpr / B) * B)).
14029     if (Expr == getURemExpr(A, B)) {
14030       LHS = A;
14031       RHS = B;
14032       return true;
14033     }
14034     return false;
14035   };
14037 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
14038 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
14039 return MatchURemWithDivisor(Mul->getOperand(1)) ||
14040 MatchURemWithDivisor(Mul->getOperand(2));
14042 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
14043 if (Mul->getNumOperands() == 2)
14044 return MatchURemWithDivisor(Mul->getOperand(1)) ||
14045 MatchURemWithDivisor(Mul->getOperand(0)) ||
14046 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
14047            MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
14048   return false;
14049 }
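// Worked example (illustrative): an i8 urem by the constant 4 is usually
// canonicalized to (zext (trunc %a to i2) to i8); the first pattern above
// matches it with LHS = %a and RHS = 4 (i.e. 1 << 2). The additive pattern
// handles non-power-of-2 divisors, e.g. %a urem 3 expressed as
// (%a + (-3 * (%a /u 3))).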
14051 const SCEV *
14052 ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) {
14053 SmallVector<BasicBlock*, 16> ExitingBlocks;
14054 L->getExitingBlocks(ExitingBlocks);
14056 // Form an expression for the maximum exit count possible for this loop. We
14057 // merge the max and exact information to approximate a version of
14058 // getConstantMaxBackedgeTakenCount which isn't restricted to just constants.
14059 SmallVector<const SCEV*, 4> ExitCounts;
14060 for (BasicBlock *ExitingBB : ExitingBlocks) {
14061 const SCEV *ExitCount = getExitCount(L, ExitingBB);
14062 if (isa<SCEVCouldNotCompute>(ExitCount))
14063 ExitCount = getExitCount(L, ExitingBB,
14064 ScalarEvolution::ConstantMaximum);
14065 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
14066 assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
14067 "We should only have known counts for exiting blocks that "
14068 "dominate latch!");
14069       ExitCounts.push_back(ExitCount);
14070     }
14071   }
14072   if (ExitCounts.empty())
14073     return getCouldNotCompute();
14074   return getUMinFromMismatchedTypes(ExitCounts);
14075 }
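// Illustrative example: for a loop with one exit bounded by %n and another
// bounded by the constant 100, the symbolic maximum backedge-taken count
// computed above is (%n umin 100), even when the exact count is unknown
// because neither exit provably fires last.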
14077 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
14078 /// in the map. It skips AddRecExpr because we cannot guarantee that the
14079 /// replacement is loop invariant in the loop of the AddRec.
14081 /// At the moment only rewriting SCEVUnknown and SCEVZeroExtendExpr is
14082 /// supported.
14083 class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
14084   const DenseMap<const SCEV *, const SCEV *> &Map;
14086 public:
14087   SCEVLoopGuardRewriter(ScalarEvolution &SE,
14088 DenseMap<const SCEV *, const SCEV *> &M)
14089 : SCEVRewriteVisitor(SE), Map(M) {}
14091 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
14093 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14094 auto I = Map.find(Expr);
14095     if (I == Map.end())
14096       return Expr;
14097     return I->second;
14098   }
14100 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14101 auto I = Map.find(Expr);
14102 if (I == Map.end())
14103       return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
14104           Expr);
14105     return I->second;
14106   }
14107 };
14109 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
14110 SmallVector<const SCEV *> ExprsToRewrite;
14111   auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
14112                               const SCEV *RHS,
14113                               DenseMap<const SCEV *, const SCEV *>
14114                                   &RewriteMap) {
14115 // WARNING: It is generally unsound to apply any wrap flags to the proposed
14116 // replacement SCEV which isn't directly implied by the structure of that
14117 // SCEV. In particular, using contextual facts to imply flags is *NOT*
14118 // legal. See the scoping rules for flags in the header to understand why.
14120 // If LHS is a constant, apply information to the other expression.
14121 if (isa<SCEVConstant>(LHS)) {
14122 std::swap(LHS, RHS);
14123       Predicate = CmpInst::getSwappedPredicate(Predicate);
14124     }
14126 // Check for a condition of the form (-C1 + X < C2). InstCombine will
14127     // create this form when combining two checks of the form (X u< C2 + C1)
14128     // and (X >=u C1).
14129 auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap,
14130 &ExprsToRewrite]() {
14131 auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
14132       if (!AddExpr || AddExpr->getNumOperands() != 2)
14133         return false;
14135 auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
14136 auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
14137 auto *C2 = dyn_cast<SCEVConstant>(RHS);
14138       if (!C1 || !C2 || !LHSUnknown)
14139         return false;
14141       auto ExactRegion =
14142           ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
14143               .sub(C1->getAPInt());
14145 // Bail out, unless we have a non-wrapping, monotonic range.
14146       if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
14147         return false;
14148 auto I = RewriteMap.find(LHSUnknown);
14149 const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
14150 RewriteMap[LHSUnknown] = getUMaxExpr(
14151 getConstant(ExactRegion.getUnsignedMin()),
14152 getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax())));
14153       ExprsToRewrite.push_back(LHSUnknown);
14154       return true;
14155     };
14156     if (MatchRangeCheckIdiom())
14157       return;
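    // Worked example (illustrative): for a guard ((-4 + %x) u< 8), which
    // InstCombine forms from (%x u>= 4) and (%x u< 12), C1 = -4 and C2 = 8
    // give ExactRegion = [4,12), so %x is rewritten to
    // (4 umax (%x umin 11)).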
14159 // If we have LHS == 0, check if LHS is computing a property of some unknown
14160 // SCEV %v which we can rewrite %v to express explicitly.
14161 const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
14162 if (Predicate == CmpInst::ICMP_EQ && RHSC &&
14163 RHSC->getValue()->isNullValue()) {
14164 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
14165 // explicitly express that.
14166 const SCEV *URemLHS = nullptr;
14167 const SCEV *URemRHS = nullptr;
14168 if (matchURem(LHS, URemLHS, URemRHS)) {
14169 if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
14170 auto Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS);
14171 RewriteMap[LHSUnknown] = Multiple;
14172           ExprsToRewrite.push_back(LHSUnknown);
14173           return;
14174         }
14175       }
14176     }
14178     // Do not apply information for constants or if RHS contains an AddRec.
14179     if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
14180       return;
14182 // If RHS is SCEVUnknown, make sure the information is applied to it.
14183 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
14184 std::swap(LHS, RHS);
14185       Predicate = CmpInst::getSwappedPredicate(Predicate);
14186     }
14188 // Limit to expressions that can be rewritten.
14189     if (!isa<SCEVUnknown>(LHS) && !isa<SCEVZeroExtendExpr>(LHS))
14190       return;
14192 // Check whether LHS has already been rewritten. In that case we want to
14193 // chain further rewrites onto the already rewritten value.
14194 auto I = RewriteMap.find(LHS);
14195 const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS;
14197 const SCEV *RewrittenRHS = nullptr;
14198     switch (Predicate) {
14199     case CmpInst::ICMP_ULT:
14200       RewrittenRHS =
14201           getUMinExpr(RewrittenLHS, getMinusSCEV(RHS, getOne(RHS->getType())));
14202       break;
14203     case CmpInst::ICMP_SLT:
14204       RewrittenRHS =
14205           getSMinExpr(RewrittenLHS, getMinusSCEV(RHS, getOne(RHS->getType())));
14206       break;
14207     case CmpInst::ICMP_ULE:
14208       RewrittenRHS = getUMinExpr(RewrittenLHS, RHS);
14209       break;
14210     case CmpInst::ICMP_SLE:
14211       RewrittenRHS = getSMinExpr(RewrittenLHS, RHS);
14212       break;
14213     case CmpInst::ICMP_UGT:
14214       RewrittenRHS =
14215           getUMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType())));
14216       break;
14217     case CmpInst::ICMP_SGT:
14218       RewrittenRHS =
14219           getSMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType())));
14220       break;
14221     case CmpInst::ICMP_UGE:
14222       RewrittenRHS = getUMaxExpr(RewrittenLHS, RHS);
14223       break;
14224     case CmpInst::ICMP_SGE:
14225       RewrittenRHS = getSMaxExpr(RewrittenLHS, RHS);
14226       break;
14227     case CmpInst::ICMP_EQ:
14228       if (isa<SCEVConstant>(RHS))
14229         RewrittenRHS = RHS;
14230       break;
14231     case CmpInst::ICMP_NE:
14232       if (isa<SCEVConstant>(RHS) &&
14233           cast<SCEVConstant>(RHS)->getValue()->isNullValue())
14234         RewrittenRHS = getUMaxExpr(RewrittenLHS, getOne(RHS->getType()));
14235       break;
14236     default:
14237       break;
14238     }
14240     if (RewrittenRHS) {
14241       RewriteMap[LHS] = RewrittenRHS;
14242       if (LHS == RewrittenLHS)
14243         ExprsToRewrite.push_back(LHS);
14244     }
14245   };
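  // Worked example (illustrative): a dominating branch on (%i u< %n) records
  // RewriteMap[%i] = (%i umin (-1 + %n)), so a later applyLoopGuards query
  // on (1 + %i) returns (1 + (%i umin (-1 + %n))).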
14246 // First, collect conditions from dominating branches. Starting at the loop
14247 // predecessor, climb up the predecessor chain, as long as there are
14248 // predecessors that can be found that have unique successors leading to the
14249 // original header.
14250 // TODO: share this logic with isLoopEntryGuardedByCond.
14251 SmallVector<std::pair<Value *, bool>> Terms;
14252 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
14253 L->getLoopPredecessor(), L->getHeader());
14254 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
14256 const BranchInst *LoopEntryPredicate =
14257 dyn_cast<BranchInst>(Pair.first->getTerminator());
14258     if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
14259       continue;
14261 Terms.emplace_back(LoopEntryPredicate->getCondition(),
14262                        LoopEntryPredicate->getSuccessor(0) == Pair.second);
14263   }
14265 // Now apply the information from the collected conditions to RewriteMap.
14266   // Conditions are processed in reverse order, so the earliest condition is
14267   // processed first. This ensures the SCEVs with the shortest dependency chains
14268 // are constructed first.
14269 DenseMap<const SCEV *, const SCEV *> RewriteMap;
14270 for (auto &E : reverse(Terms)) {
14271 bool EnterIfTrue = E.second;
14272 SmallVector<Value *, 8> Worklist;
14273 SmallPtrSet<Value *, 8> Visited;
14274 Worklist.push_back(E.first);
14275 while (!Worklist.empty()) {
14276 Value *Cond = Worklist.pop_back_val();
14277       if (!Visited.insert(Cond).second)
14278         continue;
14280       if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14281         auto Predicate =
14282             EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
14283         CollectCondition(Predicate, getSCEV(Cmp->getOperand(0)),
14284                          getSCEV(Cmp->getOperand(1)), RewriteMap);
14285         continue;
14286       }
14288       Value *L, *R;
14289       if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
14290                       : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
14291         Worklist.push_back(L);
14292         Worklist.push_back(R);
14293       }
14294     }
14295   }
14297 // Also collect information from assumptions dominating the loop.
14298   for (auto &AssumeVH : AC.assumptions()) {
14299     if (!AssumeVH)
14300       continue;
14301 auto *AssumeI = cast<CallInst>(AssumeVH);
14302 auto *Cmp = dyn_cast<ICmpInst>(AssumeI->getOperand(0));
14303     if (!Cmp || !DT.dominates(AssumeI, L->getHeader()))
14304       continue;
14305 CollectCondition(Cmp->getPredicate(), getSCEV(Cmp->getOperand(0)),
14306                      getSCEV(Cmp->getOperand(1)), RewriteMap);
14307   }
14309   if (RewriteMap.empty())
14310     return Expr;
14312   // Now that all rewrite information is collected, rewrite the collected
14313 // expressions with the information in the map. This applies information to
14314 // sub-expressions.
14315 if (ExprsToRewrite.size() > 1) {
14316 for (const SCEV *Expr : ExprsToRewrite) {
14317 const SCEV *RewriteTo = RewriteMap[Expr];
14318 RewriteMap.erase(Expr);
14319 SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
14320       RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
14321     }
14322   }
14324 SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
14325   return Rewriter.visit(Expr);
14326 }
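// Usage sketch (illustrative caller-side code, not part of this file):
//   const SCEV *BTC = SE.getBackedgeTakenCount(L);
//   if (!isa<SCEVCouldNotCompute>(BTC))
//     BTC = SE.applyLoopGuards(BTC, L); // tighten using dominating guards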