//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the scalar evolution analysis
// engine, which is used primarily to analyze expressions involving induction
// variables in loops.
//
// There are several aspects to this library. First is the representation of
// scalar expressions as subclasses of the SCEV class. These classes are used
// to represent certain types of subexpressions that we can handle. We only
// create one SCEV of a particular shape, so pointer-comparisons for equality
// are legal.
//
// One important aspect of the SCEV objects is that they are never cyclic, even
// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
// the PHI node is one of the idioms that we can represent (e.g., a polynomial
// recurrence) then we represent it directly as a recurrence node, otherwise we
// represent it as a SCEVUnknown node.
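//
// For example, the canonical induction variable of a loop such as
// "for (i = 0; i != n; ++i)" is typically represented as the affine
// recurrence {0,+,1}<%loop>: it starts at 0 and steps by 1 on each
// iteration of %loop.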
//
// In addition to being able to represent expressions of various types, we also
// have folders that are used to build the *canonical* representation for a
// particular expression. These folders are capable of using a variety of
// rewrite rules to simplify the expressions.
//
// Once the folders are defined, we can implement the more interesting
// higher-level code, such as the code that recognizes PHI nodes of various
// types, computes the execution count of a loop, etc.
//
// TODO: We should use these routines and value representations to implement
// dependence analysis!
//
//===----------------------------------------------------------------------===//
//
// There are several good references for the techniques used in this analysis.
//
//  Chains of recurrences -- a method to expedite the evaluation
//  of closed-form functions
//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
//
//  On computational properties of chains of recurrences
//  Eugene V. Zima
//
//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
//  Robert A. van Engelen
//
//  Efficient Symbolic Analysis for Optimizing Compilers
//  Robert A. van Engelen
//
//  Using the chains of recurrences algebra for data dependence testing and
//  induction variable substitution
//  MS Thesis, Johnie Birch
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/Verifier.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "scalar-evolution"
STATISTIC(NumTripCountsComputed,
          "Number of loops with predictable loop counts");
STATISTIC(NumTripCountsNotComputed,
          "Number of loops without predictable loop counts");
STATISTIC(NumBruteForceTripCountsComputed,
          "Number of loops with trip counts computed by force");

#ifdef EXPENSIVE_CHECKS
bool llvm::VerifySCEV = true;
#else
bool llvm::VerifySCEV = false;
#endif

static cl::opt<unsigned>
    MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                            cl::desc("Maximum number of iterations SCEV will "
                                     "symbolically execute a constant "
                                     "derived loop"),
                            cl::init(100));

static cl::opt<bool, true> VerifySCEVOpt(
    "verify-scev", cl::Hidden, cl::location(VerifySCEV),
    cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
static cl::opt<bool> VerifySCEVStrict(
    "verify-scev-strict", cl::Hidden,
    cl::desc("Enable stricter verification when -verify-scev is passed"));
static cl::opt<bool>
    VerifySCEVMap("verify-scev-maps", cl::Hidden,
                  cl::desc("Verify no dangling value in ScalarEvolution's "
                           "ExprValueMap (slow)"));

static cl::opt<bool> VerifyIR(
    "scev-verify-ir", cl::Hidden,
    cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
    cl::init(false));

static cl::opt<unsigned> MulOpsInlineThreshold(
    "scev-mulops-inline-threshold", cl::Hidden,
    cl::desc("Threshold for inlining multiplication operands into a SCEV"),
    cl::init(32));

static cl::opt<unsigned> AddOpsInlineThreshold(
    "scev-addops-inline-threshold", cl::Hidden,
    cl::desc("Threshold for inlining addition operands into a SCEV"),
    cl::init(500));

static cl::opt<unsigned> MaxSCEVCompareDepth(
    "scalar-evolution-max-scev-compare-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
    cl::init(32));

static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
    "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
    cl::init(2));

static cl::opt<unsigned> MaxValueCompareDepth(
    "scalar-evolution-max-value-compare-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive value complexity comparisons"),
    cl::init(2));

static cl::opt<unsigned>
    MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
                  cl::desc("Maximum depth of recursive arithmetics"),
                  cl::init(32));

static cl::opt<unsigned> MaxConstantEvolvingDepth(
    "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));

static cl::opt<unsigned>
    MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
                 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
                 cl::init(8));

static cl::opt<unsigned>
    MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
                  cl::desc("Max coefficients in AddRec during evolving"),
                  cl::init(16));

static cl::opt<unsigned>
    HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
                      cl::desc("Size of the expression which is considered huge"),
                      cl::init(4096));

static cl::opt<bool>
    ClassifyExpressions("scalar-evolution-classify-expressions",
                        cl::Hidden, cl::init(true),
                        cl::desc("When printing analysis, include information "
                                 "on every instruction"));

static cl::opt<bool> UseExpensiveRangeSharpening(
    "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
    cl::init(false),
    cl::desc("Use more powerful methods of sharpening expression ranges. May "
             "be costly in terms of compile time"));

static cl::opt<unsigned> MaxPhiSCCAnalysisSize(
    "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
    cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
             "Phi strongly connected components"),
    cl::init(8));

static cl::opt<bool>
    EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
                            cl::desc("Handle <= and >= in finite loops"),
                            cl::init(true));
//===----------------------------------------------------------------------===//
// SCEV class definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Implementation of the SCEV class.
//===----------------------------------------------------------------------===//

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
LLVM_DUMP_METHOD void SCEV::dump() const {
  print(dbgs());
  dbgs() << '\n';
}
#endif
void SCEV::print(raw_ostream &OS) const {
  switch (getSCEVType()) {
  case scConstant:
    cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
    return;
  case scPtrToInt: {
    const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
    const SCEV *Op = PtrToInt->getOperand();
    OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
       << *PtrToInt->getType() << ")";
    return;
  }
  case scTruncate: {
    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
    const SCEV *Op = Trunc->getOperand();
    OS << "(trunc " << *Op->getType() << " " << *Op << " to "
       << *Trunc->getType() << ")";
    return;
  }
  case scZeroExtend: {
    const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
    const SCEV *Op = ZExt->getOperand();
    OS << "(zext " << *Op->getType() << " " << *Op << " to "
       << *ZExt->getType() << ")";
    return;
  }
  case scSignExtend: {
    const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
    const SCEV *Op = SExt->getOperand();
    OS << "(sext " << *Op->getType() << " " << *Op << " to "
       << *SExt->getType() << ")";
    return;
  }
  case scAddRecExpr: {
    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
    OS << "{" << *AR->getOperand(0);
    for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
      OS << ",+," << *AR->getOperand(i);
    OS << "}<";
    if (AR->hasNoUnsignedWrap())
      OS << "nuw><";
    if (AR->hasNoSignedWrap())
      OS << "nsw><";
    if (AR->hasNoSelfWrap() &&
        !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
      OS << "nw><";
    AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
    OS << ">";
    return;
  }
  case scAddExpr:
  case scMulExpr:
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
  case scSequentialUMinExpr: {
    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
    const char *OpStr = nullptr;
    switch (NAry->getSCEVType()) {
    case scAddExpr: OpStr = " + "; break;
    case scMulExpr: OpStr = " * "; break;
    case scUMaxExpr: OpStr = " umax "; break;
    case scSMaxExpr: OpStr = " smax "; break;
    case scUMinExpr:
      OpStr = " umin ";
      break;
    case scSMinExpr:
      OpStr = " smin ";
      break;
    case scSequentialUMinExpr:
      OpStr = " umin_seq ";
      break;
    default:
      llvm_unreachable("There are no other nary expression types.");
    }
    OS << "(";
    ListSeparator LS(OpStr);
    for (const SCEV *Op : NAry->operands())
      OS << LS << *Op;
    OS << ")";
    switch (NAry->getSCEVType()) {
    case scAddExpr:
    case scMulExpr:
      if (NAry->hasNoUnsignedWrap())
        OS << "<nuw>";
      if (NAry->hasNoSignedWrap())
        OS << "<nsw>";
      break;
    default:
      // Nothing to print for other nary expressions.
      break;
    }
    return;
  }
  case scUDivExpr: {
    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
    OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
    return;
  }
  case scUnknown: {
    const SCEVUnknown *U = cast<SCEVUnknown>(this);
    Type *AllocTy;
    if (U->isSizeOf(AllocTy)) {
      OS << "sizeof(" << *AllocTy << ")";
      return;
    }
    if (U->isAlignOf(AllocTy)) {
      OS << "alignof(" << *AllocTy << ")";
      return;
    }

    Type *CTy;
    Constant *FieldNo;
    if (U->isOffsetOf(CTy, FieldNo)) {
      OS << "offsetof(" << *CTy << ", ";
      FieldNo->printAsOperand(OS, false);
      OS << ")";
      return;
    }

    // Otherwise just print it normally.
    U->getValue()->printAsOperand(OS, false);
    return;
  }
  case scCouldNotCompute:
    OS << "***COULDNOTCOMPUTE***";
    return;
  }
  llvm_unreachable("Unknown SCEV kind!");
}
Type *SCEV::getType() const {
  switch (getSCEVType()) {
  case scConstant:
    return cast<SCEVConstant>(this)->getType();
  case scPtrToInt:
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
    return cast<SCEVCastExpr>(this)->getType();
  case scAddRecExpr:
    return cast<SCEVAddRecExpr>(this)->getType();
  case scMulExpr:
    return cast<SCEVMulExpr>(this)->getType();
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
    return cast<SCEVMinMaxExpr>(this)->getType();
  case scSequentialUMinExpr:
    return cast<SCEVSequentialMinMaxExpr>(this)->getType();
  case scAddExpr:
    return cast<SCEVAddExpr>(this)->getType();
  case scUDivExpr:
    return cast<SCEVUDivExpr>(this)->getType();
  case scUnknown:
    return cast<SCEVUnknown>(this)->getType();
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}
bool SCEV::isZero() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isZero();
  return false;
}

bool SCEV::isOne() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isOne();
  return false;
}

bool SCEV::isAllOnesValue() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isMinusOne();
  return false;
}

bool SCEV::isNonConstantNegative() const {
  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
  if (!Mul) return false;

  // If there is a constant factor, it will be first.
  const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
  if (!SC) return false;

  // Return true if the value is negative; this matches things like (-42 * V).
  return SC->getAPInt().isNegative();
}

SCEVCouldNotCompute::SCEVCouldNotCompute() :
  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}

bool SCEVCouldNotCompute::classof(const SCEV *S) {
  return S->getSCEVType() == scCouldNotCompute;
}
const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
  FoldingSetNodeID ID;
  ID.AddInteger(scConstant);
  ID.AddPointer(V);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
  return getConstant(ConstantInt::get(getContext(), Val));
}

const SCEV *
ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
  return getConstant(ConstantInt::get(ITy, V, isSigned));
}
SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
                           const SCEV *op, Type *ty)
    : SCEV(ID, SCEVTy, computeExpressionSize(op)), Ty(ty) {
  Operands[0] = op;
}

SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
                                   Type *ITy)
    : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
  assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
         "Must be a non-bit-width-changing pointer-to-integer cast!");
}

SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
                                           SCEVTypes SCEVTy, const SCEV *op,
                                           Type *ty)
    : SCEVCastExpr(ID, SCEVTy, op, ty) {}

SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
                                   Type *ty)
    : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot truncate non-integer value!");
}

SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, Type *ty)
    : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot zero extend non-integer value!");
}

SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, Type *ty)
    : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot sign extend non-integer value!");
}
void SCEVUnknown::deleted() {
  // Clear this SCEVUnknown from various maps.
  SE->forgetMemoizedResults(this);

  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Release the value.
  setValPtr(nullptr);
}

void SCEVUnknown::allUsesReplacedWith(Value *New) {
  // Clear this SCEVUnknown from various maps.
  SE->forgetMemoizedResults(this);

  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Replace the value pointer in case someone is still using this SCEVUnknown.
  setValPtr(New);
}
bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue() &&
            CE->getNumOperands() == 2)
          if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
            if (CI->isOne()) {
              AllocTy = cast<GEPOperator>(CE)->getSourceElementType();
              return true;
            }

  return false;
}

bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue()) {
          Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
          if (StructType *STy = dyn_cast<StructType>(Ty))
            if (!STy->isPacked() &&
                CE->getNumOperands() == 3 &&
                CE->getOperand(1)->isNullValue()) {
              if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
                if (CI->isOne() &&
                    STy->getNumElements() == 2 &&
                    STy->getElementType(0)->isIntegerTy(1)) {
                  AllocTy = STy->getElementType(1);
                  return true;
                }
            }
        }

  return false;
}

bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getNumOperands() == 3 &&
            CE->getOperand(0)->isNullValue() &&
            CE->getOperand(1)->isNullValue()) {
          Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
          // Ignore vector types here so that ScalarEvolutionExpander doesn't
          // emit getelementptrs that index into vectors.
          if (Ty->isStructTy() || Ty->isArrayTy()) {
            CTy = Ty;
            FieldNo = CE->getOperand(2);
            return true;
          }
        }

  return false;
}
//===----------------------------------------------------------------------===//
// SCEV Utilities
//===----------------------------------------------------------------------===//
/// Compare the two values \p LV and \p RV in terms of their "complexity" where
/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
/// operands in SCEV expressions. \p EqCache is a set of pairs of values that
/// have been previously deemed to be "equally complex" by this routine. It is
/// intended to avoid exponential time complexity in cases like:
///
///   CompareValueComplexity(%f, %c)
///
/// Since we do not continue running this routine on expression trees once we
/// have seen unequal values, there is no need to track them in the cache.
static int
CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
                       const LoopInfo *const LI, Value *LV, Value *RV,
                       unsigned Depth) {
  if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
    return 0;
  // Order pointer values after integer values. This helps SCEVExpander form
  // GEPs.
  bool LIsPointer = LV->getType()->isPointerTy(),
       RIsPointer = RV->getType()->isPointerTy();
  if (LIsPointer != RIsPointer)
    return (int)LIsPointer - (int)RIsPointer;

  // Compare getValueID values.
  unsigned LID = LV->getValueID(), RID = RV->getValueID();
  if (LID != RID)
    return (int)LID - (int)RID;

  // Sort arguments by their position.
  if (const auto *LA = dyn_cast<Argument>(LV)) {
    const auto *RA = cast<Argument>(RV);
    unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
    return (int)LArgNo - (int)RArgNo;
  }

  if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
    const auto *RGV = cast<GlobalValue>(RV);

    const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
      auto LT = GV->getLinkage();
      return !(GlobalValue::isPrivateLinkage(LT) ||
               GlobalValue::isInternalLinkage(LT));
    };

    // Use the names to distinguish the two values, but only if the
    // names are semantically important.
    if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
      return LGV->getName().compare(RGV->getName());
  }

  // For instructions, compare their loop depth, and their operand count.
  if (const auto *LInst = dyn_cast<Instruction>(LV)) {
    const auto *RInst = cast<Instruction>(RV);

    // Compare loop depths.
    const BasicBlock *LParent = LInst->getParent(),
                     *RParent = RInst->getParent();
    if (LParent != RParent) {
      unsigned LDepth = LI->getLoopDepth(LParent),
               RDepth = LI->getLoopDepth(RParent);
      if (LDepth != RDepth)
        return (int)LDepth - (int)RDepth;
    }

    // Compare the number of operands.
    unsigned LNumOps = LInst->getNumOperands(),
             RNumOps = RInst->getNumOperands();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    for (unsigned Idx : seq(0u, LNumOps)) {
      int Result =
          CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
                                 RInst->getOperand(Idx), Depth + 1);
      if (Result != 0)
        return Result;
    }
  }

  EqCacheValue.unionSets(LV, RV);
  return 0;
}
// Return negative, zero, or positive, if LHS is less than, equal to, or greater
// than RHS, respectively. A three-way result allows recursive comparisons to be
// memoized.
// If the max analysis depth was reached, return None, assuming we do not know
// if they are equivalent for sure.
static Optional<int>
CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
                      EquivalenceClasses<const Value *> &EqCacheValue,
                      const LoopInfo *const LI, const SCEV *LHS,
                      const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
  // Fast-path: SCEVs are uniqued so we can do a quick equality check.
  if (LHS == RHS)
    return 0;

  // Primarily, sort the SCEVs by their getSCEVType().
  SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
  if (LType != RType)
    return (int)LType - (int)RType;

  if (EqCacheSCEV.isEquivalent(LHS, RHS))
    return 0;

  if (Depth > MaxSCEVCompareDepth)
    return None;

  // Aside from the getSCEVType() ordering, the particular ordering
  // isn't very important except that it's beneficial to be consistent,
  // so that (a + b) and (b + a) don't end up as different expressions.
  switch (LType) {
  case scUnknown: {
    const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
    const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);

    int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
                                   RU->getValue(), Depth + 1);
    if (X == 0)
      EqCacheSCEV.unionSets(LHS, RHS);
    return X;
  }

  case scConstant: {
    const SCEVConstant *LC = cast<SCEVConstant>(LHS);
    const SCEVConstant *RC = cast<SCEVConstant>(RHS);

    // Compare constant values.
    const APInt &LA = LC->getAPInt();
    const APInt &RA = RC->getAPInt();
    unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
    if (LBitWidth != RBitWidth)
      return (int)LBitWidth - (int)RBitWidth;
    return LA.ult(RA) ? -1 : 1;
  }
  case scAddRecExpr: {
    const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
    const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);

    // There is always a dominance between two recs that are used by one SCEV,
    // so we can safely sort recs by loop header dominance. We require such
    // order in getAddExpr.
    const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
    if (LLoop != RLoop) {
      const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
      assert(LHead != RHead && "Two loops share the same header?");
      if (DT.dominates(LHead, RHead))
        return 1;
      assert(DT.dominates(RHead, LHead) &&
             "No dominance between recurrences used by one SCEV?");
      return -1;
    }

    // Addrec complexity grows with operand count.
    unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    // Lexicographically compare.
    for (unsigned i = 0; i != LNumOps; ++i) {
      auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
                                     LA->getOperand(i), RA->getOperand(i), DT,
                                     Depth + 1);
      if (X != 0)
        return X;
    }
    EqCacheSCEV.unionSets(LHS, RHS);
    return 0;
  }
  case scAddExpr:
  case scMulExpr:
  case scSMaxExpr:
  case scUMaxExpr:
  case scSMinExpr:
  case scUMinExpr:
  case scSequentialUMinExpr: {
    const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
    const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);

    // Lexicographically compare n-ary expressions.
    unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    for (unsigned i = 0; i != LNumOps; ++i) {
      auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
                                     LC->getOperand(i), RC->getOperand(i), DT,
                                     Depth + 1);
      if (X != 0)
        return X;
    }
    EqCacheSCEV.unionSets(LHS, RHS);
    return 0;
  }
  case scUDivExpr: {
    const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
    const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);

    // Lexicographically compare udiv expressions.
    auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
                                   RC->getLHS(), DT, Depth + 1);
    if (X != 0)
      return X;
    X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
                              RC->getRHS(), DT, Depth + 1);
    if (X == 0)
      EqCacheSCEV.unionSets(LHS, RHS);
    return X;
  }

  case scPtrToInt:
  case scTruncate:
  case scZeroExtend:
  case scSignExtend: {
    const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
    const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);

    // Compare cast expressions by operand.
    auto X =
        CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getOperand(),
                              RC->getOperand(), DT, Depth + 1);
    if (X == 0)
      EqCacheSCEV.unionSets(LHS, RHS);
    return X;
  }

  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}
/// Given a list of SCEV objects, order them by their complexity, and group
/// objects of the same complexity together by value. When this routine is
/// finished, we know that any duplicates in the vector are consecutive and that
/// complexity is monotonically increasing.
///
/// Note that we take special precautions to ensure that we get deterministic
/// results from this routine. In other words, we don't want the results of
/// this to depend on where the addresses of various SCEV objects happened to
/// land in memory.
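///
/// For example, after grouping, an operand list such as (%x + %y + %x) has
/// its duplicate %x operands adjacent, so a caller like getAddExpr can spot
/// them in a single scan and fold them into (2 * %x).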
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
                              LoopInfo *LI, DominatorTree &DT) {
  if (Ops.size() < 2) return; // Noop

  EquivalenceClasses<const SCEV *> EqCacheSCEV;
  EquivalenceClasses<const Value *> EqCacheValue;

  // Whether LHS has provably less complexity than RHS.
  auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
    auto Complexity =
        CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
    return Complexity && *Complexity < 0;
  };
  if (Ops.size() == 2) {
    // This is the common case, which also happens to be trivially simple.
    // Special case it.
    const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
    if (IsLessComplex(RHS, LHS))
      std::swap(LHS, RHS);
    return;
  }

  // Do the rough sort by complexity.
  llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
    return IsLessComplex(LHS, RHS);
  });

  // Now that we are sorted by complexity, group elements of the same
  // complexity. Note that this is, at worst, N^2, but the vector is likely to
  // be extremely short in practice. Note that we take this approach because we
  // do not want to depend on the addresses of the objects we are grouping.
  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
    const SCEV *S = Ops[i];
    unsigned Complexity = S->getSCEVType();

    // If there are any objects of the same complexity and same value as this
    // one, group them.
    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
      if (Ops[j] == S) { // Found a duplicate.
        // Move it to immediately after i'th element.
        std::swap(Ops[i+1], Ops[j]);
        ++i;   // no need to rescan it.
        if (i == e-2) return;  // Done!
      }
    }
  }
}
/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
/// least HugeExprThreshold nodes).
static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
  return any_of(Ops, [](const SCEV *S) {
    return S->getExpressionSize() >= HugeExprThreshold;
  });
}

//===----------------------------------------------------------------------===//
// Simple SCEV method implementations
//===----------------------------------------------------------------------===//
/// Compute BC(It, K). The result has width W. Assumes K > 0.
static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
                                       ScalarEvolution &SE,
                                       Type *ResultTy) {
  // Handle the simplest case efficiently.
  if (K == 1)
    return SE.getTruncateOrZeroExtend(It, ResultTy);

  // We are using the following formula for BC(It, K):
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
  //
  // Suppose, W is the bitwidth of the return value. We must be prepared for
  // overflow. Hence, we must assure that the result of our computation is
  // equal to the accurate one modulo 2^W. Unfortunately, division isn't
  // safe in modular arithmetic.
  //
  // However, this code doesn't use exactly that formula; the formula it uses
  // is something like the following, where T is the number of factors of 2 in
  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
  // exponentiation:
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
  //
  // This formula is trivially equivalent to the previous formula. However,
  // this formula can be implemented much more efficiently. The trick is that
  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
  // arithmetic. To do exact division in modular arithmetic, all we have
  // to do is multiply by the inverse. Therefore, this step can be done at
  // width W.
  //
  // The next issue is how to safely do the division by 2^T. The way this
  // is done is by doing the multiplication step at a width of at least W + T
  // bits. This way, the bottom W+T bits of the product are accurate. Then,
  // when we perform the division by 2^T (which is equivalent to a right shift
  // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
  // truncated out after the division by 2^T.
  //
  // In comparison to just directly using the first formula, this technique
  // is much more efficient; using the first formula requires W * K bits,
  // but this formula needs less than W + K bits. Also, the first formula
  // requires a division step, whereas this formula only requires multiplies
  // and shifts.
  //
  // It doesn't matter whether the subtraction step is done in the calculation
  // width or the input iteration count's width; if the subtraction overflows,
  // the result must be zero anyway. We prefer here to do it in the width of
  // the induction variable because it helps a lot for certain cases; CodeGen
  // isn't smart enough to ignore the overflow, which leads to much less
  // efficient code if the width of the subtraction is wider than the native
  // machine width.
  //
  // (It's possible to not widen at all by pulling out factors of 2 before
  // the multiplication; for example, K=2 can be calculated as
  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
  // extra arithmetic, so it's not an obvious win, and it gets
  // much more complicated for K > 3.)
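  //
  // As a concrete illustration, take K = 3: K! = 6 = 2^1 * 3, so T = 1 and
  // K! / 2^T = 3. For W = 32, the product It*(It-1)*(It-2) is evaluated at
  // W + T = 33 bits, divided by 2^1, truncated back to 32 bits, and finally
  // multiplied by the multiplicative inverse of 3 modulo 2^32.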
  // Protection from insane SCEVs; this bound is conservative,
  // but it probably doesn't matter.
  if (K > 1000)
    return SE.getCouldNotCompute();

  unsigned W = SE.getTypeSizeInBits(ResultTy);

  // Calculate K! / 2^T and T; we divide out the factors of two before
  // multiplying for calculating K! / 2^T to avoid overflow.
  // Other overflow doesn't matter because we only care about the bottom
  // W bits of the result.
  APInt OddFactorial(W, 1);
  unsigned T = 1;
  for (unsigned i = 3; i <= K; ++i) {
    APInt Mult(W, i);
    unsigned TwoFactors = Mult.countTrailingZeros();
    T += TwoFactors;
    Mult.lshrInPlace(TwoFactors);
    OddFactorial *= Mult;
  }

  // We need at least W + T bits for the multiplication step
  unsigned CalculationBits = W + T;

  // Calculate 2^T, at width T+W.
  APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);

  // Calculate the multiplicative inverse of K! / 2^T;
  // this multiplication factor will perform the exact division by
  // K! / 2^T.
  APInt Mod = APInt::getSignedMinValue(W+1);
  APInt MultiplyFactor = OddFactorial.zext(W+1);
  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
  MultiplyFactor = MultiplyFactor.trunc(W);

  // Calculate the product, at width T+W
  IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
                                                CalculationBits);
  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
  for (unsigned i = 1; i != K; ++i) {
    const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
    Dividend = SE.getMulExpr(Dividend,
                             SE.getTruncateOrZeroExtend(S, CalculationTy));
  }

  // Divide by 2^T
  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));

  // Truncate the result, and divide by K! / 2^T.
  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}
/// Return the value of this chain of recurrences at the specified iteration
/// number. We can evaluate this recurrence by multiplying each element in the
/// chain by the binomial coefficient corresponding to it. In other words, we
/// can evaluate {A,+,B,+,C,+,D} as:
///
///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
///
/// where BC(It, k) stands for binomial coefficient.
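///
/// For instance, the affine recurrence {A,+,B} evaluates to A + B*BC(It, 1)
/// = A + B*It, and {A,+,B,+,C} evaluates to A + B*It + C*(It*(It-1)/2).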
const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
                                                ScalarEvolution &SE) const {
  return evaluateAtIteration(makeArrayRef(op_begin(), op_end()), It, SE);
}

const SCEV *
SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
                                    const SCEV *It, ScalarEvolution &SE) {
  assert(Operands.size() > 0);
  const SCEV *Result = Operands[0];
  for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
    // The computation is correct in the face of overflow provided that the
    // multiplication is performed _after_ the evaluation of the binomial
    // coefficient.
    const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
    if (isa<SCEVCouldNotCompute>(Coeff))
      return Coeff;

    Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
  }
  return Result;
}
//===----------------------------------------------------------------------===//
// SCEV Expression folder implementations
//===----------------------------------------------------------------------===//
const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
                                                     unsigned Depth) {
  assert(Depth <= 1 &&
         "getLosslessPtrToIntExpr() should self-recurse at most once.");

  // We could be called with an integer-typed operand during SCEV rewrites.
  // Since the operand is an integer already, just perform zext/trunc/self cast.
  if (!Op->getType()->isPointerTy())
    return Op;

  // What would be an ID for such a SCEV cast expression?
  FoldingSetNodeID ID;
  ID.AddInteger(scPtrToInt);
  ID.AddPointer(Op);

  void *IP = nullptr;

  // Is there already an expression for such a cast?
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
    return S;

  // It isn't legal for optimizations to construct new ptrtoint expressions
  // for non-integral pointers.
  if (getDataLayout().isNonIntegralPointerType(Op->getType()))
    return getCouldNotCompute();

  Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());

  // We can only trivially model ptrtoint if SCEV's effective (integer) type
  // is sufficiently wide to represent all possible pointer values.
  // We could theoretically teach SCEV to truncate wider pointers, but
  // that isn't implemented for now.
  if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
      getDataLayout().getTypeSizeInBits(IntPtrTy))
    return getCouldNotCompute();

  // If not, is this expression something we can't reduce any further?
  if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
    // Perform some basic constant folding. If the operand of the ptr2int cast
    // is a null pointer, don't create a ptr2int SCEV expression (that will be
    // left as-is), but produce a zero constant.
    // NOTE: We could handle a more general case, but lack motivational cases.
    if (isa<ConstantPointerNull>(U->getValue()))
      return getZero(IntPtrTy);

    // Create an explicit cast node.
    // We can reuse the existing insert position since if we get here,
    // we won't have made any changes which would invalidate it.
    SCEV *S = new (SCEVAllocator)
        SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }

  assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
                       "non-SCEVUnknown's.");

  // Otherwise, we've got some expression that is more complex than just a
  // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
  // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
  // only, and the expressions must otherwise be integer-typed.
  // So sink the cast down to the SCEVUnknown's.
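  //
  // For example, a pointer-typed sum (8 + %ptr) is rewritten to
  // (8 + ptrtoint(%ptr)), so that the cast wraps only the SCEVUnknown leaf
  // and all of the surrounding arithmetic is integer-typed.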
  /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
  /// which computes a pointer-typed value, and rewrites the whole expression
  /// tree so that *all* the computations are done on integers, and the only
  /// pointer-typed operands in the expression are SCEVUnknown.
  class SCEVPtrToIntSinkingRewriter
      : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
    using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;

  public:
    SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}

    static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
      SCEVPtrToIntSinkingRewriter Rewriter(SE);
      return Rewriter.visit(Scev);
    }

    const SCEV *visit(const SCEV *S) {
      Type *STy = S->getType();
      // If the expression is not pointer-typed, just keep it as-is.
      if (!STy->isPointerTy())
        return S;
      // Else, recursively sink the cast down into it.
      return Base::visit(S);
    }

    const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
      SmallVector<const SCEV *, 2> Operands;
      bool Changed = false;
      for (auto *Op : Expr->operands()) {
        Operands.push_back(visit(Op));
        Changed |= Op != Operands.back();
      }
      return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
    }

    const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
      SmallVector<const SCEV *, 2> Operands;
      bool Changed = false;
      for (auto *Op : Expr->operands()) {
        Operands.push_back(visit(Op));
        Changed |= Op != Operands.back();
      }
      return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
    }

    const SCEV *visitUnknown(const SCEVUnknown *Expr) {
      assert(Expr->getType()->isPointerTy() &&
             "Should only reach pointer-typed SCEVUnknown's.");
      return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
    }
  };

  // And actually perform the cast sinking.
  const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
  assert(IntOp->getType()->isIntegerTy() &&
         "We must have succeeded in sinking the cast, "
         "and ending up with an integer-typed expression!");
  return IntOp;
}
const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
  assert(Ty->isIntegerTy() && "Target type must be an integer type!");

  const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
  if (isa<SCEVCouldNotCompute>(IntOp))
    return IntOp;

  return getTruncateOrZeroExtend(IntOp, Ty);
}
const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
                                             unsigned Depth) {
  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
         "This is not a truncating conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
  Ty = getEffectiveSCEVType(Ty);

  FoldingSetNodeID ID;
  ID.AddInteger(scTruncate);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));

  // trunc(trunc(x)) --> trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
    return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);

  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);

  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);

  if (Depth > MaxCastDepth) {
    SCEV *S =
        new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }

  // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
  // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
  // if after transforming we have at most one truncate, not counting truncates
  // that replace other casts.
  if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
    auto *CommOp = cast<SCEVCommutativeExpr>(Op);
    SmallVector<const SCEV *, 4> Operands;
    unsigned numTruncs = 0;
    for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
         ++i) {
      const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
      if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
          isa<SCEVTruncateExpr>(S))
        numTruncs++;
      Operands.push_back(S);
    }
    if (numTruncs < 2) {
      if (isa<SCEVAddExpr>(Op))
        return getAddExpr(Operands);
      else if (isa<SCEVMulExpr>(Op))
        return getMulExpr(Operands);
      else
        llvm_unreachable("Unexpected SCEV type for Op.");
    }
    // Although we checked in the beginning that ID is not in the cache, it is
    // possible that it was inserted into the cache during the recursion above.
    // So if we find it, just return it.
    if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
      return S;
  }

  // If the input value is a chrec scev, truncate the chrec's operands.
  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    for (const SCEV *Op : AddRec->operands())
      Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
    return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
  }

  // Return zero if truncating to known zeros.
  uint32_t MinTrailingZeros = GetMinTrailingZeros(Op);
  if (MinTrailingZeros >= getTypeSizeInBits(Ty))
    return getZero(Ty);

  // The cast wasn't folded; create an explicit cast node. We can reuse
  // the existing insert position since if we get here, we won't have
  // made any changes which would invalidate it.
  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                 Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, Op);
  return S;
}
// Get the limit of a recurrence such that incrementing by Step cannot cause
// signed overflow as long as the value of the recurrence within the
// loop does not exceed this limit before incrementing.
static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
                                                 ICmpInst::Predicate *Pred,
                                                 ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  if (SE->isKnownPositive(Step)) {
    *Pred = ICmpInst::ICMP_SLT;
    return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
                           SE->getSignedRangeMax(Step));
  }
  if (SE->isKnownNegative(Step)) {
    *Pred = ICmpInst::ICMP_SGT;
    return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
                           SE->getSignedRangeMin(Step));
  }
  return nullptr;
}
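// For example, for an 8-bit recurrence whose step is known to be at most 3,
// the limit is -128 - 3 == 125 in wrapped arithmetic: any value that is
// ICMP_SLT 125 can be incremented by the step without signed overflow.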
// Get the limit of a recurrence such that incrementing by Step cannot cause
// unsigned overflow as long as the value of the recurrence within the loop does
// not exceed this limit before incrementing.
static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
                                                   ICmpInst::Predicate *Pred,
                                                   ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  *Pred = ICmpInst::ICMP_ULT;

  return SE->getConstant(APInt::getMinValue(BitWidth) -
                         SE->getUnsignedRangeMax(Step));
}
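// Likewise, for an 8-bit step whose unsigned range max is 3, the limit is
// 0 - 3 == 253 in wrapped arithmetic: any value that is ICMP_ULT 253 can be
// incremented by the step without unsigned wrap.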
namespace {

struct ExtendOpTraitsBase {
  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
                                                          unsigned);
};

// Used to make code generic over signed and unsigned overflow.
template <typename ExtendOp> struct ExtendOpTraits {
  // Members present:
  //
  // static const SCEV::NoWrapFlags WrapType;
  //
  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
  //
  // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
  //                                            ICmpInst::Predicate *Pred,
  //                                            ScalarEvolution *SE);
};

template <>
struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getSignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;

template <>
struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getUnsignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;

} // end anonymous namespace
// The recurrence AR has been shown to have no signed/unsigned wrap or something
// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
// easily prove NSW/NUW for its preincrement or postincrement sibling. This
// allows normalizing a sign/zero extended AddRec as such:
//   {sext/zext(Step + Start),+,Step} => {(Step + sext/zext(Start)),+,Step}
// As a result, the expression "Step + sext/zext(PreIncAR)" is congruent with
// "sext/zext(PostIncAR)".
template <typename ExtendOpTy>
static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE, unsigned Depth) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  const Loop *L = AR->getLoop();
  const SCEV *Start = AR->getStart();
  const SCEV *Step = AR->getStepRecurrence(*SE);

  // Check for a simple looking step prior to loop entry.
  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
  if (!SA)
    return nullptr;

  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
  // subtraction is expensive. For this purpose, perform a quick and dirty
  // difference, by checking for Step in the operand list.
  SmallVector<const SCEV *, 4> DiffOps;
  for (const SCEV *Op : SA->operands())
    if (Op != Step)
      DiffOps.push_back(Op);

  if (DiffOps.size() == SA->getNumOperands())
    return nullptr;

  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
  // `Step`:

  // 1. NSW/NUW flags on the step increment.
  auto PreStartFlags =
    ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
      SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));

  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
  // "S+X does not sign/unsign-overflow".
  //
  const SCEV *BECount = SE->getBackedgeTakenCount(L);
  if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
      !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
    return PreStart;

  // 2. Direct overflow check on the step operation's expression.
  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
  const SCEV *OperandExtendedStart =
      SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
                     (SE->*GetExtendExpr)(Step, WideTy, Depth));
  if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
    if (PreAR && AR->getNoWrapFlags(WrapType)) {
      // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
      // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
      // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
      SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
    }
    return PreStart;
  }

  // 3. Loop precondition.
  ICmpInst::Predicate Pred;
  const SCEV *OverflowLimit =
      ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);

  if (OverflowLimit &&
      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
    return PreStart;

  return nullptr;
}
// Get the normalized zero or sign extended expression for this AddRec's Start.
template <typename ExtendOpTy>
static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE,
                                        unsigned Depth) {
  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
  if (!PreStart)
    return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);

  return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
                                             Depth),
                        (SE->*GetExtendExpr)(PreStart, Ty, Depth));
}
// Try to prove away overflow by looking at "nearby" add recurrences. A
// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
//
// Formally:
//
//     {S,+,X} == {S-T,+,X} + T
//  => Ext({S,+,X}) == Ext({S-T,+,X} + T)
//
//  If ({S-T,+,X} + T) does not overflow  ... (1)
//
//  RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
//
//  If {S-T,+,X} does not overflow  ... (2)
//
//  RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
//      == {Ext(S-T)+Ext(T),+,Ext(X)}
//
//  If (S-T)+T does not overflow  ... (3)
//
//  RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
//      == {Ext(S),+,Ext(X)} == LHS
//
// Thus, if (1), (2) and (3) are true for some T, then
//   Ext({S,+,X}) == {Ext(S),+,Ext(X)}
//
// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
// does not overflow" restricted to the 0th iteration. Therefore we only need
// to check for (1) and (2).
//
// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
// is `Delta` (defined below).
template <typename ExtendOpTy>
bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
                                                const SCEV *Step,
                                                const Loop *L) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;

  // We restrict `Start` to a constant to prevent SCEV from spending too much
  // time here. It is correct (but more expensive) to continue with a
  // non-constant `Start` and do a general SCEV subtraction to compute
  // `PreStart` below.
  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
  if (!StartC)
    return false;

  APInt StartAI = StartC->getAPInt();

  for (unsigned Delta : {-2, -1, 1, 2}) {
    const SCEV *PreStart = getConstant(StartAI - Delta);

    FoldingSetNodeID ID;
    ID.AddInteger(scAddRecExpr);
    ID.AddPointer(PreStart);
    ID.AddPointer(Step);
    ID.AddPointer(L);
    void *IP = nullptr;
    const auto *PreAR =
        static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));

    // Give up if we don't already have the add recurrence we need because
    // actually constructing an add recurrence is relatively expensive.
    if (PreAR && PreAR->getNoWrapFlags(WrapType)) {  // proves (2)
      const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
      ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
      const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
          DeltaS, &Pred, this);
      if (Limit && isKnownPredicate(Pred, PreAR, Limit))  // proves (1)
        return true;
    }
  }

  return false;
}
// Finds an integer D for an expression (C + x + y + ...) such that the top
// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
// the (C + x + y + ...) expression is \p WholeAddExpr.
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
                                            const SCEVConstant *ConstantTerm,
                                            const SCEVAddExpr *WholeAddExpr) {
  const APInt &C = ConstantTerm->getAPInt();
  const unsigned BitWidth = C.getBitWidth();
  // Find number of trailing zeros of (x + y + ...) w/o the C first:
  uint32_t TZ = BitWidth;
  for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
    TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I)));
  if (TZ) {
    // Set D to be as many least significant bits of C as possible while still
    // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
    return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
  }
  return APInt(BitWidth, 0);
}
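// For example, if C == 13 (0b1101) and every non-constant operand of the add
// has at least two trailing zero bits, then TZ == 2 and D == 0b01 == 1; the
// remainder C - D == 12 (0b1100) keeps those two trailing zeros, so adding D
// (which is < 2^TZ) back on top cannot carry past them.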
// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
                                            const APInt &ConstantStart,
                                            const SCEV *Step) {
  const unsigned BitWidth = ConstantStart.getBitWidth();
  const uint32_t TZ = SE.GetMinTrailingZeros(Step);
  if (TZ)
    return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
                         : ConstantStart;
  return APInt(BitWidth, 0);
}
const SCEV *
ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
  Ty = getEffectiveSCEVType(Ty);

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));

  // zext(zext(x)) --> zext(x)
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);

  // Before doing any expensive analysis, check to see if we've already
  // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scZeroExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  if (Depth > MaxCastDepth) {
    SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                     Op, Ty);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }

  // zext(trunc(x)) --> zext(x) or x or trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
    // It's possible the bits taken off by the truncate were all zero bits. If
    // so, we should be able to simplify this further.
    const SCEV *X = ST->getOperand();
    ConstantRange CR = getUnsignedRange(X);
    unsigned TruncBits = getTypeSizeInBits(ST->getType());
    unsigned NewBits = getTypeSizeInBits(Ty);
    if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
            CR.zextOrTrunc(NewBits)))
      return getTruncateOrZeroExtend(X, Ty, Depth);
  }
  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can zero extend all of the
  // operands (often constants). This allows analysis of something like
  // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      const SCEV *Start = AR->getStart();
      const SCEV *Step = AR->getStepRecurrence(*this);
      unsigned BitWidth = getTypeSizeInBits(AR->getType());
      const Loop *L = AR->getLoop();

      if (!AR->hasNoUnsignedWrap()) {
        auto NewFlags = proveNoWrapViaConstantRanges(AR);
        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
      }

      // If we have special knowledge that this addrec won't overflow,
      // we don't need to do any further analysis.
      if (AR->hasNoUnsignedWrap()) {
        Start =
            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
        Step = getZeroExtendExpr(Step, Ty, Depth + 1);
        return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
      }
      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for overflow.

        // Check whether the backedge-taken count can be losslessly casted to
        // the addrec's type. The count is always unsigned.
        const SCEV *CastedMaxBECount =
            getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
        const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
            CastedMaxBECount, MaxBECount->getType(), Depth);
        if (MaxBECount == RecastedMaxBECount) {
          Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
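          // In a type twice as wide, Start + Step * MaxBECount can be
          // computed without wrapping. So if zero-extending the narrow sum
          // gives the same SCEV as summing the zero-extended operands, the
          // narrow computation never unsigned-overflows.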
          // Check whether Start+Step*MaxBECount has no unsigned overflow.
          const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
                                        SCEV::FlagAnyWrap, Depth + 1);
          const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
                                                          SCEV::FlagAnyWrap,
                                                          Depth + 1),
                                               WideTy, Depth + 1);
          const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
          const SCEV *WideMaxBECount =
              getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
          const SCEV *OperandExtendedAdd =
              getAddExpr(WideStart,
                         getMulExpr(WideMaxBECount,
                                    getZeroExtendExpr(Step, WideTy, Depth + 1),
                                    SCEV::FlagAnyWrap, Depth + 1),
                         SCEV::FlagAnyWrap, Depth + 1);
          if (ZAdd == OperandExtendedAdd) {
            // Cache knowledge of AR NUW, which is propagated to this AddRec.
            setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
            // Return the expression with the addrec on the outside.
            Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
                                                             Depth + 1);
            Step = getZeroExtendExpr(Step, Ty, Depth + 1);
            return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
          }
          // Similar to above, only this time treat the step value as signed.
          // This covers loops that count down.
          OperandExtendedAdd =
              getAddExpr(WideStart,
                         getMulExpr(WideMaxBECount,
                                    getSignExtendExpr(Step, WideTy, Depth + 1),
                                    SCEV::FlagAnyWrap, Depth + 1),
                         SCEV::FlagAnyWrap, Depth + 1);
          if (ZAdd == OperandExtendedAdd) {
            // Cache knowledge of AR NW, which is propagated to this AddRec.
            // Negative step causes unsigned wrap, but it still can't self-wrap.
            setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
            // Return the expression with the addrec on the outside.
            Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
                                                             Depth + 1);
            Step = getSignExtendExpr(Step, Ty, Depth + 1);
            return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
          }
        }
      }

      // Normally, in the cases we can prove no-overflow via a
      // backedge guarding condition, we can also compute a backedge
      // taken count for the loop. The exceptions are assumptions and
      // guards present in the loop -- SCEV is not great at exploiting
      // these to compute max backedge-taken counts, but can still use
      // these to prove lack of overflow. Use this fact to avoid
      // doing extra work that may not pay off.
      if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
          !AC.assumptions().empty()) {

        auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
        if (AR->hasNoUnsignedWrap()) {
          // Same as nuw case above - duplicated here to avoid a compile time
          // issue. It's not clear that the order of checks does matter, but
          // it's one of two possible causes for a change which was
          // reverted. Be conservative for the moment.
          Start =
              getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
          Step = getZeroExtendExpr(Step, Ty, Depth + 1);
          return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
        }

        // For a negative step, we can extend the operands iff doing so only
        // traverses values in the range zext([0,UINT_MAX]).
        if (isKnownNegative(Step)) {
          const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
                                      getSignedRangeMin(Step));
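          // In APInt arithmetic N equals |signed-minimum(Step)| - 1, so the
          // guard AR >u N means AR >= |Step| on every iteration: adding the
          // negative step can never wrap below zero.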
          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
              isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
            // Cache knowledge of AR NW, which is propagated to this
            // AddRec. Negative step causes unsigned wrap, but it
            // still can't self-wrap.
            setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
            // Return the expression with the addrec on the outside.
            Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
                                                             Depth + 1);
            Step = getSignExtendExpr(Step, Ty, Depth + 1);
            return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
          }
        }
      }

      // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
      // if D + (C - D + Step * n) could be proven to not unsigned wrap
      // where D maximizes the number of trailing zeros of (C - D + Step * n)
      if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
        const APInt &C = SC->getAPInt();
        const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
        if (D != 0) {
          const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
          const SCEV *SResidual =
              getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
          const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
          return getAddExpr(SZExtD, SZExtR,
                            (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
                            Depth + 1);
        }
      }

      if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
        Start =
            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
        Step = getZeroExtendExpr(Step, Ty, Depth + 1);
        return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
      }
    }

  // zext(A % B) --> zext(A) % zext(B)
  {
    const SCEV *LHS;
    const SCEV *RHS;
    if (matchURem(Op, LHS, RHS))
      return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
                         getZeroExtendExpr(RHS, Ty, Depth + 1));
  }

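  // Zero extension commutes with unsigned division: the zero-extended
  // quotient equals the quotient of the zero-extended operands, since udiv
  // introduces no new high bits.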
  // zext(A / B) --> zext(A) / zext(B).
  if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
    return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
                       getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));

  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
    // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
    if (SA->hasNoUnsignedWrap()) {
      // If the addition cannot overflow in the unsigned sense then we can,
      // by definition, commute the zero extension with the addition.
      SmallVector<const SCEV *, 4> Ops;
      for (const auto *Op : SA->operands())
        Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
      return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
    }

    // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
    // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
    // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
    //
    // Often address arithmetic contains expressions like
    // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
    // This transformation is useful while proving that such expressions are
    // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
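    // In the (zext (5 + (4 * X))) example, D = 1 (the low bits of 5 below
    // the two trailing zeros of 4 * X), so the fold yields
    // (1 + (zext (4 + 4 * X))).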
    if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
      const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
      if (D != 0) {
        const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
        const SCEV *SResidual =
            getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
        const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
        return getAddExpr(SZExtD, SZExtR,
                          (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
                          Depth + 1);
      }
    }
  }

  if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
    // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
    if (SM->hasNoUnsignedWrap()) {
      // If the multiply cannot overflow in the unsigned sense then we can,
      // by definition, commute the zero extension with the multiply.
      SmallVector<const SCEV *, 4> Ops;
      for (const auto *Op : SM->operands())
        Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
      return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
    }

    // zext(2^K * (trunc X to iN)) to iM ->
    // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
    //
    // Proof:
    //
    //     zext(2^K * (trunc X to iN)) to iM
    //   = zext((trunc X to iN) << K) to iM
    //   = zext((trunc X to i{N-K}) << K)<nuw> to iM
    //     (because shl removes the top K bits)
    //   = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
    //   = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
    //
    if (SM->getNumOperands() == 2)
      if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
        if (MulLHS->getAPInt().isPowerOf2())
          if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
            int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
                               MulLHS->getAPInt().logBase2();
            Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
            return getMulExpr(
                getZeroExtendExpr(MulLHS, Ty),
                getZeroExtendExpr(
                    getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
                SCEV::FlagNUW, Depth + 1);
          }
  }

  // The cast wasn't folded; create an explicit cast node.
  // Recompute the insert position, as it may have been invalidated.
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                   Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, Op);
  return S;
}

const SCEV *
ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
  Ty = getEffectiveSCEVType(Ty);

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));

  // sext(sext(x)) --> sext(x)
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);

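  // A zero-extended value has a clear sign bit in the wider type, so sign
  // extending it further is the same as zero extending it.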
  // sext(zext(x)) --> zext(x)
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);

  // Before doing any expensive analysis, check to see if we've already
  // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scSignExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  // Limit recursion depth.
  if (Depth > MaxCastDepth) {
    SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
                                                     Op, Ty);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }

  // sext(trunc(x)) --> sext(x) or x or trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
    // It's possible the bits taken off by the truncate were all sign bits. If
    // so, we should be able to simplify this further.
    const SCEV *X = ST->getOperand();
    ConstantRange CR = getSignedRange(X);
    unsigned TruncBits = getTypeSizeInBits(ST->getType());
    unsigned NewBits = getTypeSizeInBits(Ty);
    if (CR.truncate(TruncBits).signExtend(NewBits).contains(
            CR.sextOrTrunc(NewBits)))
      return getTruncateOrSignExtend(X, Ty, Depth);
  }

  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
    // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
    if (SA->hasNoSignedWrap()) {
      // If the addition cannot overflow in the signed sense then we can,
      // by definition, commute the sign extension with the addition.
      SmallVector<const SCEV *, 4> Ops;
      for (const auto *Op : SA->operands())
        Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
      return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
    }

    // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
    // if D + (C - D + x + y + ...) could be proven to not signed wrap
    // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
    //
    // For instance, this will bring two seemingly different expressions:
    //     1 + sext(5 + 20 * %x + 24 * %y)  and
    //         sext(6 + 20 * %x + 24 * %y)
    // to the same form:
    //     2 + sext(4 + 20 * %x + 24 * %y)
    if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
      const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
      if (D != 0) {
        const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
        const SCEV *SResidual =
            getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
        const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
        return getAddExpr(SSExtD, SSExtR,
                          (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
                          Depth + 1);
      }
    }
  }

  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can sign extend all of the
  // operands (often constants). This allows analysis of something like
  // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      const SCEV *Start = AR->getStart();
      const SCEV *Step = AR->getStepRecurrence(*this);
      unsigned BitWidth = getTypeSizeInBits(AR->getType());
      const Loop *L = AR->getLoop();

      if (!AR->hasNoSignedWrap()) {
        auto NewFlags = proveNoWrapViaConstantRanges(AR);
        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
      }

      // If we have special knowledge that this addrec won't overflow,
      // we don't need to do any further analysis.
      if (AR->hasNoSignedWrap()) {
        Start =
            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
        Step = getSignExtendExpr(Step, Ty, Depth + 1);
        return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
      }

      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for
        // overflow.

        // Check whether the backedge-taken count can be losslessly casted to
        // the addrec's type. The count is always unsigned.
        const SCEV *CastedMaxBECount =
            getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
        const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
            CastedMaxBECount, MaxBECount->getType(), Depth);
        if (MaxBECount == RecastedMaxBECount) {
          Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
          // Check whether Start+Step*MaxBECount has no signed overflow.
          const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
                                        SCEV::FlagAnyWrap, Depth + 1);
          const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
                                                          SCEV::FlagAnyWrap,
                                                          Depth + 1),
                                               WideTy, Depth + 1);
          const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
          const SCEV *WideMaxBECount =
              getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
          const SCEV *OperandExtendedAdd =
              getAddExpr(WideStart,
                         getMulExpr(WideMaxBECount,
                                    getSignExtendExpr(Step, WideTy, Depth + 1),
                                    SCEV::FlagAnyWrap, Depth + 1),
                         SCEV::FlagAnyWrap, Depth + 1);
          if (SAdd == OperandExtendedAdd) {
            // Cache knowledge of AR NSW, which is propagated to this AddRec.
            setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
            // Return the expression with the addrec on the outside.
            Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
                                                             Depth + 1);
            Step = getSignExtendExpr(Step, Ty, Depth + 1);
            return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
          }
          // Similar to above, only this time treat the step value as unsigned.
          // This covers loops that count up with an unsigned step.
          OperandExtendedAdd =
              getAddExpr(WideStart,
                         getMulExpr(WideMaxBECount,
                                    getZeroExtendExpr(Step, WideTy, Depth + 1),
                                    SCEV::FlagAnyWrap, Depth + 1),
                         SCEV::FlagAnyWrap, Depth + 1);
          if (SAdd == OperandExtendedAdd) {
            // If AR wraps around then
            //
            //    abs(Step) * MaxBECount > unsigned-max(AR->getType())
            //    => SAdd != OperandExtendedAdd
            //
            // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
            // (SAdd == OperandExtendedAdd => AR is NW)
            setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);

            // Return the expression with the addrec on the outside.
            Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
                                                             Depth + 1);
            Step = getZeroExtendExpr(Step, Ty, Depth + 1);
            return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
          }
        }
      }

      auto NewFlags = proveNoSignedWrapViaInduction(AR);
      setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
      if (AR->hasNoSignedWrap()) {
        // Same as nsw case above - duplicated here to avoid a compile time
        // issue. It's not clear that the order of checks does matter, but
        // it's one of two possible causes for a change which was
        // reverted. Be conservative for the moment.
        Start =
            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
        Step = getSignExtendExpr(Step, Ty, Depth + 1);
        return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
      }

      // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
      // if D + (C - D + Step * n) could be proven to not signed wrap
      // where D maximizes the number of trailing zeros of (C - D + Step * n)
      if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
        const APInt &C = SC->getAPInt();
        const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
        if (D != 0) {
          const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
          const SCEV *SResidual =
              getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
          const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
          return getAddExpr(SSExtD, SSExtR,
                            (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
                            Depth + 1);
        }
      }

      if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
        Start =
            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
        Step = getSignExtendExpr(Step, Ty, Depth + 1);
        return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
      }
    }

  // If the input value is provably non-negative and we could not simplify
  // away the sext, build a zext instead.
  if (isKnownNonNegative(Op))
    return getZeroExtendExpr(Op, Ty, Depth + 1);

  // The cast wasn't folded; create an explicit cast node.
  // Recompute the insert position, as it may have been invalidated.
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
                                                   Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, { Op });
  return S;
}

const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
                                         Type *Ty) {
  switch (Kind) {
  case scTruncate:
    return getTruncateExpr(Op, Ty);
  case scZeroExtend:
    return getZeroExtendExpr(Op, Ty);
  case scSignExtend:
    return getSignExtendExpr(Op, Ty);
  case scPtrToInt:
    return getPtrToIntExpr(Op, Ty);
  default:
    llvm_unreachable("Not a SCEV cast expression!");
  }
}

/// getAnyExtendExpr - Return a SCEV for the given operand extended with
/// unspecified bits out to the given type.
const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
                                              Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Sign-extend negative constants.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    if (SC->getAPInt().isNegative())
      return getSignExtendExpr(Op, Ty);

  // Peel off a truncate cast.
  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
    const SCEV *NewOp = T->getOperand();
    if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
      return getAnyExtendExpr(NewOp, Ty);
    return getTruncateOrNoop(NewOp, Ty);
  }

  // Next try a zext cast. If the cast is folded, use it.
  const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
  if (!isa<SCEVZeroExtendExpr>(ZExt))
    return ZExt;

  // Next try a sext cast. If the cast is folded, use it.
  const SCEV *SExt = getSignExtendExpr(Op, Ty);
  if (!isa<SCEVSignExtendExpr>(SExt))
    return SExt;

  // Force the cast to be folded into the operands of an addrec.
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Ops;
    for (const SCEV *Op : AR->operands())
      Ops.push_back(getAnyExtendExpr(Op, Ty));
    return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
  }

  // If the expression is obviously signed, use the sext cast value.
  if (isa<SCEVSMaxExpr>(Op))
    return SExt;

  // Absent any other information, use the zext cast value.
  return ZExt;
}

/// Process the given Ops list, which is a list of operands to be added under
/// the given scale, update the given map. This is a helper function for
/// getAddRecExpr. As an example of what it does, given a sequence of operands
/// that would form an add expression like this:
///
///    m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
///
/// where A and B are constants, update the map with these values:
///
///    (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
///
/// and add 13 + A*B*29 to AccumulatedConstant.
/// This will allow getAddRecExpr to produce this:
///
///    13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
///
/// This form often exposes folding opportunities that are hidden in
/// the original operand list.
///
/// Return true iff it appears that any interesting folding opportunities
/// may be exposed. This helps getAddRecExpr short-circuit extra work in
/// the common case where no interesting opportunities are present, and
/// is also used as a check to avoid infinite recursion.
static bool
CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
                             SmallVectorImpl<const SCEV *> &NewOps,
                             APInt &AccumulatedConstant,
                             const SCEV *const *Ops, size_t NumOperands,
                             const APInt &Scale,
                             ScalarEvolution &SE) {
  bool Interesting = false;

  // Iterate over the add operands. They are sorted, with constants first.
  unsigned i = 0;
  while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
    ++i;
    // Pull a buried constant out to the outside.
    if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
      Interesting = true;
    AccumulatedConstant += Scale * C->getAPInt();
  }

  // Next comes everything else. We're especially interested in multiplies
  // here, but they're in the middle, so just visit the rest with one loop.
  for (; i != NumOperands; ++i) {
    const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
    if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
      APInt NewScale =
          Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
      if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
        // A multiplication of a constant with another add; recurse.
        const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
        Interesting |=
            CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
                                         Add->op_begin(), Add->getNumOperands(),
                                         NewScale, SE);
      } else {
        // A multiplication of a constant with some other value. Update
        // the map.
        SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
        const SCEV *Key = SE.getMulExpr(MulOps);
        auto Pair = M.insert({Key, NewScale});
        if (Pair.second) {
          NewOps.push_back(Pair.first->first);
        } else {
          Pair.first->second += NewScale;
          // The map already had an entry for this value, which may indicate
          // a folding opportunity.
          Interesting = true;
        }
      }
    } else {
      // An ordinary operand. Update the map.
      std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
          M.insert({Ops[i], Scale});
      if (Pair.second) {
        NewOps.push_back(Pair.first->first);
      } else {
        Pair.first->second += Scale;
        // The map already had an entry for this value, which may indicate
        // a folding opportunity.
        Interesting = true;
      }
    }
  }

  return Interesting;
}

bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
                                      const SCEV *LHS, const SCEV *RHS) {
  const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
                                            SCEV::NoWrapFlags, unsigned);
  switch (BinOp) {
  default:
    llvm_unreachable("Unsupported binary op");
  case Instruction::Add:
    Operation = &ScalarEvolution::getAddExpr;
    break;
  case Instruction::Sub:
    Operation = &ScalarEvolution::getMinusSCEV;
    break;
  case Instruction::Mul:
    Operation = &ScalarEvolution::getMulExpr;
    break;
  }

  const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
      Signed ? &ScalarEvolution::getSignExtendExpr
             : &ScalarEvolution::getZeroExtendExpr;

  // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
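  // In the doubled bit width the operation cannot wrap, so if both sides
  // agree as (uniqued) SCEVs, the narrow operation provably does not
  // overflow in the requested signed or unsigned sense.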
  auto *NarrowTy = cast<IntegerType>(LHS->getType());
  auto *WideTy =
      IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);

  const SCEV *A = (this->*Extension)(
      (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
  const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
  const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
  const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
  return A == B;
}

std::pair<SCEV::NoWrapFlags, bool /*Deduced*/>
ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
    const OverflowingBinaryOperator *OBO) {
  SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;

  if (OBO->hasNoUnsignedWrap())
    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
  if (OBO->hasNoSignedWrap())
    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);

  bool Deduced = false;

  if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
    return {Flags, Deduced};

  if (OBO->getOpcode() != Instruction::Add &&
      OBO->getOpcode() != Instruction::Sub &&
      OBO->getOpcode() != Instruction::Mul)
    return {Flags, Deduced};

  const SCEV *LHS = getSCEV(OBO->getOperand(0));
  const SCEV *RHS = getSCEV(OBO->getOperand(1));

  if (!OBO->hasNoUnsignedWrap() &&
      willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
                      /* Signed */ false, LHS, RHS)) {
    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
    Deduced = true;
  }

  if (!OBO->hasNoSignedWrap() &&
      willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
                      /* Signed */ true, LHS, RHS)) {
    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
    Deduced = true;
  }

  return {Flags, Deduced};
}

// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
// can't-overflow flags for the operation if possible.
static SCEV::NoWrapFlags
StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
                      const ArrayRef<const SCEV *> Ops,
                      SCEV::NoWrapFlags Flags) {
  using namespace std::placeholders;

  using OBO = OverflowingBinaryOperator;

  bool CanAnalyze =
      Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
  (void)CanAnalyze;
  assert(CanAnalyze && "don't call from other places!");

  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
  SCEV::NoWrapFlags SignOrUnsignWrap =
      ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);

  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
  auto IsKnownNonNegative = [&](const SCEV *S) {
    return SE->isKnownNonNegative(S);
  };

  if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
    Flags =
        ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);

  SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);

  if (SignOrUnsignWrap != SignOrUnsignMask &&
      (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
      isa<SCEVConstant>(Ops[0])) {

    auto Opcode = [&] {
      switch (Type) {
      case scAddExpr:
        return Instruction::Add;
      case scMulExpr:
        return Instruction::Mul;
      default:
        llvm_unreachable("Unexpected SCEV op.");
      }
    }();

    const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();

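    // For instance, with Opcode == Add and C == 1 the no-signed-wrap region
    // is [SINT_MIN, SINT_MAX): A + 1 sign-overflows only for A == SINT_MAX.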
    // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
    if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
      auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
          Opcode, C, OBO::NoSignedWrap);
      if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
        Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
    }

    // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsigned
    // overflow.
    if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
      auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
          Opcode, C, OBO::NoUnsignedWrap);
      if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
        Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
    }
  }

  // <0,+,nonnegative><nw> is also nuw
  // TODO: Add corresponding nsw case
  if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
      !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
      Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);

  // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
  if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
      Ops.size() == 2) {
    if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
      if (UDiv->getOperand(1) == Ops[1])
        Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
    if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
      if (UDiv->getOperand(1) == Ops[0])
        Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
  }

  return Flags;
}

bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
  return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
}

/// Get a canonical add expression, or something simpler if possible.
const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
                                        SCEV::NoWrapFlags OrigFlags,
                                        unsigned Depth) {
  assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
         "only nuw or nsw allowed");
  assert(!Ops.empty() && "Cannot get empty add!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
           "SCEVAddExpr operand types don't match!");
  unsigned NumPtrs = count_if(
      Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
  assert(NumPtrs <= 1 && "add has at most one pointer operand");
#endif

  // Sort by complexity, this groups all similar expression types together.
  GroupByComplexity(Ops, &LI, DT);

  // If there are any constants, fold them together.
  unsigned Idx = 0;
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    ++Idx;
    assert(Idx < Ops.size());
    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
      // We found two constants, fold them together!
      Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
      if (Ops.size() == 2) return Ops[0];
      Ops.erase(Ops.begin()+1); // Erase the folded element
      LHSC = cast<SCEVConstant>(Ops[0]);
    }

    // If we are left with a constant zero being added, strip it off.
    if (LHSC->getValue()->isZero()) {
      Ops.erase(Ops.begin());
      --Idx;
    }

    if (Ops.size() == 1) return Ops[0];
  }

  // Delay expensive flag strengthening until necessary.
  auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
    return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
  };

  // Limit recursion call depth.
  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
    return getOrCreateAddExpr(Ops, ComputeFlags(Ops));

  if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
    // Don't strengthen flags if we have no new information.
    SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
    if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
      Add->setNoWrapFlags(ComputeFlags(Ops));
    return S;
  }

  // Okay, check to see if the same value occurs in the operand list more than
  // once. If so, merge them together into a multiply expression. Since we
  // sorted the list, these values are required to be adjacent.
  Type *Ty = Ops[0]->getType();
  bool FoundMatch = false;
  for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
    if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
      // Scan ahead to count how many equal operands there are.
      unsigned Count = 2;
      while (i+Count != e && Ops[i+Count] == Ops[i])
        ++Count;
      // Merge the values into a multiply.
      const SCEV *Scale = getConstant(Ty, Count);
      const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
      if (Ops.size() == Count)
        return Mul;
      Ops[i] = Mul;
      Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
      --i; e -= Count - 1;
      FoundMatch = true;
    }
  if (FoundMatch)
    return getAddExpr(Ops, OrigFlags, Depth + 1);

  // Check for truncates. If all the operands are truncated from the same
  // type, see if factoring out the truncate would permit the result to be
  // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
  // if the contents of the resulting outer trunc fold to something simple.
  auto FindTruncSrcType = [&]() -> Type * {
    // We're ultimately looking to fold an addrec of truncs and muls of only
    // constants and truncs, so if we find any other types of SCEV
    // as operands of the addrec then we bail and return nullptr here.
    // Otherwise, we return the type of the operand of a trunc that we find.
    if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
      return T->getOperand()->getType();
    if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
      const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
      if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
        return T->getOperand()->getType();
    }
    return nullptr;
  };
  if (auto *SrcType = FindTruncSrcType()) {
    SmallVector<const SCEV *, 8> LargeOps;
    bool Ok = true;
    // Check all the operands to see if they can be represented in the
    // source type of the truncate.
    for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
      if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
        if (T->getOperand()->getType() != SrcType) {
          Ok = false;
          break;
        }
        LargeOps.push_back(T->getOperand());
      } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
        LargeOps.push_back(getAnyExtendExpr(C, SrcType));
      } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
        SmallVector<const SCEV *, 8> LargeMulOps;
        for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
          if (const SCEVTruncateExpr *T =
                  dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
            if (T->getOperand()->getType() != SrcType) {
              Ok = false;
              break;
            }
            LargeMulOps.push_back(T->getOperand());
          } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
            LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
          } else {
            Ok = false;
            break;
          }
        }
        if (Ok)
          LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
      } else {
        Ok = false;
        break;
      }
    }
    if (Ok) {
      // Evaluate the expression in the larger type.
      const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
      // If it folds to something simple, use it. Otherwise, don't.
      if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
        return getTruncateExpr(Fold, Ty);
    }
  }

  if (Ops.size() == 2) {
    // Check if we have an expression of the form ((X + C1) - C2), where C1 and
    // C2 can be folded in a way that allows retaining wrapping flags of (X +
    // C1).
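    // For example, (X + 5)<nuw> plus the constant -3 folds to (X + 2)<nuw>:
    // the new constant 2 is ule 5, so the no-unsigned-wrap guarantee of the
    // original add still holds.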
    const SCEV *A = Ops[0];
    const SCEV *B = Ops[1];
    auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
    auto *C = dyn_cast<SCEVConstant>(A);
    if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
      auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
      auto C2 = C->getAPInt();
      SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;

      APInt ConstAdd = C1 + C2;
      auto AddFlags = AddExpr->getNoWrapFlags();
      // Adding a smaller constant is NUW if the original AddExpr was NUW.
      if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
          ConstAdd.ule(C1)) {
        PreservedFlags =
            ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
      }

      // Adding a constant with the same sign and small magnitude is NSW, if
      // the original AddExpr was NSW.
      if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
          C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
          ConstAdd.abs().ule(C1.abs())) {
        PreservedFlags =
            ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
      }

      if (PreservedFlags != SCEV::FlagAnyWrap) {
        SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
        NewOps[0] = getConstant(ConstAdd);
        return getAddExpr(NewOps, PreservedFlags);
      }
    }
  }

  // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
  if (Ops.size() == 2) {
    const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
    if (Mul && Mul->getNumOperands() == 2 &&
        Mul->getOperand(0)->isAllOnesValue()) {
      const SCEV *X;
      const SCEV *Y;
      if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
        return getMulExpr(Y, getUDivExpr(X, Y));
      }
    }
  }

  // Skip past any other cast SCEVs.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
    ++Idx;

  // If there are add operands they would be next.
  if (Idx < Ops.size()) {
    bool DeletedAdd = false;
    // If the original flags and all inlined SCEVAddExprs are NUW, use the
    // common NUW flag for expression after inlining. Other flags cannot be
    // preserved, because they may depend on the original order of operations.
    SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
    while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
      if (Ops.size() > AddOpsInlineThreshold ||
          Add->getNumOperands() > AddOpsInlineThreshold)
        break;
      // If we have an add, expand the add operands onto the end of the
      // operands list.
      Ops.erase(Ops.begin()+Idx);
      Ops.append(Add->op_begin(), Add->op_end());
      DeletedAdd = true;
      CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
    }

    // If we deleted at least one add, we added operands to the end of the
    // list, and they are not necessarily sorted. Recurse to resort and
    // resimplify any operands we just acquired.
    if (DeletedAdd)
      return getAddExpr(Ops, CommonFlags, Depth + 1);
  }

  // Skip over the add expression until we get to a multiply.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
    ++Idx;

  // Check to see if there are any folding opportunities present with
  // operands multiplied by constant values.
  if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
    uint64_t BitWidth = getTypeSizeInBits(Ty);
    DenseMap<const SCEV *, APInt> M;
    SmallVector<const SCEV *, 8> NewOps;
    APInt AccumulatedConstant(BitWidth, 0);
    if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
                                     Ops.data(), Ops.size(),
                                     APInt(BitWidth, 1), *this)) {
      struct APIntCompare {
        bool operator()(const APInt &LHS, const APInt &RHS) const {
          return LHS.ult(RHS);
        }
      };

      // Some interesting folding opportunity is present, so it's worthwhile to
      // re-generate the operands list. Group the operands by constant scale,
      // to avoid multiplying by the same constant scale multiple times.
      std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
      for (const SCEV *NewOp : NewOps)
        MulOpLists[M.find(NewOp)->second].push_back(NewOp);
      // Re-generate the operands list.
      Ops.clear();
      if (AccumulatedConstant != 0)
        Ops.push_back(getConstant(AccumulatedConstant));
      for (auto &MulOp : MulOpLists) {
        if (MulOp.first == 1) {
          Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
        } else if (MulOp.first != 0) {
          Ops.push_back(getMulExpr(
              getConstant(MulOp.first),
              getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
              SCEV::FlagAnyWrap, Depth + 1));
        }
      }
      if (Ops.empty())
        return getZero(Ty);
      if (Ops.size() == 1)
        return Ops[0];
      return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
    }
  }

  // If we are adding something to a multiply expression, make sure the
  // something is not already an operand of the multiply. If so, merge it into
  // the multiply.
  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
    const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
    for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
      const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
      if (isa<SCEVConstant>(MulOpSCEV))
        continue;
      for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
        if (MulOpSCEV == Ops[AddOp]) {
          // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
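          // Note: the boolean MulOp == 0 is used as an index, selecting the
          // "other" operand of a two-operand multiply (operand 1 when X is
          // operand 0, and operand 0 otherwise).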
          const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
          if (Mul->getNumOperands() != 2) {
            // If the multiply has more than two operands, we must get the
            // Y*Z term.
            SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
                                                Mul->op_begin()+MulOp);
            MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
            InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
          }
          SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
          const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
          const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
                                            SCEV::FlagAnyWrap, Depth + 1);
          if (Ops.size() == 2) return OuterMul;
          if (AddOp < Idx) {
            Ops.erase(Ops.begin()+AddOp);
            Ops.erase(Ops.begin()+Idx-1);
          } else {
            Ops.erase(Ops.begin()+Idx);
            Ops.erase(Ops.begin()+AddOp-1);
          }
          Ops.push_back(OuterMul);
          return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
        }

      // Check this multiply against other multiplies being added together.
      for (unsigned OtherMulIdx = Idx+1;
           OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
           ++OtherMulIdx) {
        const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
        // If MulOp occurs in OtherMul, we can fold the two multiplies
        // together.
        for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
             OMulOp != e; ++OMulOp)
          if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
            // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
            const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
            if (Mul->getNumOperands() != 2) {
              SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
                                                  Mul->op_begin()+MulOp);
              MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
              InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
            }
            const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
            if (OtherMul->getNumOperands() != 2) {
              SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
                                                  OtherMul->op_begin()+OMulOp);
              MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
              InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
            }
            SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
            const SCEV *InnerMulSum =
                getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
            const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
                                              SCEV::FlagAnyWrap, Depth + 1);
            if (Ops.size() == 2) return OuterMul;
            Ops.erase(Ops.begin()+Idx);
            Ops.erase(Ops.begin()+OtherMulIdx-1);
            Ops.push_back(OuterMul);
            return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
          }
      }
    }
  }

  // If there are any add recurrences in the operands list, see if any other
  // added values are loop invariant. If so, we can fold them into the
  // recurrence.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
    ++Idx;

  // Scan over all recurrences, trying to fold loop invariants into them.
  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
    // Scan all of the other operands to this add and add them to the vector if
    // they are loop invariant w.r.t. the recurrence.
    SmallVector<const SCEV *, 8> LIOps;
    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
    const Loop *AddRecLoop = AddRec->getLoop();
    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
      if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
        LIOps.push_back(Ops[i]);
        Ops.erase(Ops.begin()+i);
        --i; --e;
      }

    // If we found some loop invariants, fold them into the recurrence.
    if (!LIOps.empty()) {
      // Compute nowrap flags for the addition of the loop-invariant ops and
      // the addrec. Temporarily push it as an operand for that purpose. These
      // flags are valid in the scope of the addrec only.
      LIOps.push_back(AddRec);
      SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
      LIOps.pop_back();

      // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
      LIOps.push_back(AddRec->getStart());

      SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());

      // It is not in general safe to propagate flags valid on an add within
      // the addrec scope to one outside it. We must prove that the inner
      // scope is guaranteed to execute if the outer one does to be able to
      // safely propagate. We know the program is undefined if poison is
      // produced on the inner scoped addrec. We also know that *for this use*
      // the outer scoped add can't overflow (because of the flags we just
      // computed for the inner scoped add) without the program being undefined.
      // Proving that entry to the outer scope necessitates entry to the inner
      // scope, thus proves the program undefined if the flags would be violated
      // in the outer scope.
      SCEV::NoWrapFlags AddFlags = Flags;
      if (AddFlags != SCEV::FlagAnyWrap) {
        auto *DefI = getDefiningScopeBound(LIOps);
        auto *ReachI = &*AddRecLoop->getHeader()->begin();
        if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
          AddFlags = SCEV::FlagAnyWrap;
      }
      AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);

      // Build the new addrec. Propagate the NUW and NSW flags if both the
      // outer add and the inner addrec are guaranteed to have no overflow.
      // Always propagate NW.
      Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
      const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);

      // If all of the other operands were loop invariant, we are done.
      if (Ops.size() == 1) return NewRec;

      // Otherwise, combine the folded AddRec with the non-invariant parts.
      for (unsigned i = 0;; ++i)
        if (Ops[i] == AddRec) {
          Ops[i] = NewRec;
          break;
        }
      return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
    }

    // Okay, if there weren't any loop invariants to be folded, check to see if
    // there are multiple AddRec's with the same loop induction variable being
    // added together. If so, we can fold them.
    for (unsigned OtherIdx = Idx+1;
         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
         ++OtherIdx) {
      // We expect the AddRecExpr's to be sorted in reverse dominance order,
      // so that the 1st found AddRecExpr is dominated by all others.
      assert(DT.dominates(
                 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
                 AddRec->getLoop()->getHeader()) &&
             "AddRecExprs are not sorted in reverse dominance order?");
      if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
        // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
        SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
        for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
             ++OtherIdx) {
          const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
          if (OtherAddRec->getLoop() == AddRecLoop) {
            for (unsigned i = 0, e = OtherAddRec->getNumOperands();
                 i != e; ++i) {
              if (i >= AddRecOps.size()) {
                AddRecOps.append(OtherAddRec->op_begin()+i,
                                 OtherAddRec->op_end());
                break;
              }
              SmallVector<const SCEV *, 2> TwoOps = {
                  AddRecOps[i], OtherAddRec->getOperand(i)};
              AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
            }
            Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
          }
        }
        // Step size has changed, so we cannot guarantee no self-wraparound.
        Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
        return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
      }
    }

    // Otherwise couldn't fold anything into this recurrence. Move onto the
    // next one.
  }

  // Okay, it looks like we really DO need an add expr. Check to see if we
  // already have one, otherwise create a new one.
  return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
}

const SCEV *
ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
                                    SCEV::NoWrapFlags Flags) {
  FoldingSetNodeID ID;
  ID.AddInteger(scAddExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  SCEVAddExpr *S =
      static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
    std::uninitialized_copy(Ops.begin(), Ops.end(), O);
    S = new (SCEVAllocator)
        SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Ops);
  }
  S->setNoWrapFlags(Flags);
  return S;
}

const SCEV *
ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
                                       const Loop *L, SCEV::NoWrapFlags Flags) {
  FoldingSetNodeID ID;
  ID.AddInteger(scAddRecExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  ID.AddPointer(L);
  void *IP = nullptr;
  SCEVAddRecExpr *S =
      static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
    std::uninitialized_copy(Ops.begin(), Ops.end(), O);
    S = new (SCEVAllocator)
        SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
    UniqueSCEVs.InsertNode(S, IP);
    LoopUsers[L].push_back(S);
    registerUser(S, Ops);
  }
  setNoWrapFlags(S, Flags);
  return S;
}

const SCEV *
ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
                                    SCEV::NoWrapFlags Flags) {
  FoldingSetNodeID ID;
  ID.AddInteger(scMulExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  SCEVMulExpr *S =
      static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
    std::uninitialized_copy(Ops.begin(), Ops.end(), O);
    S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
                                        O, Ops.size());
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Ops);
  }
  S->setNoWrapFlags(Flags);
  return S;
}

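// Multiply i by j, setting Overflow if the product wraps the 64-bit range.
// The check divides back out: if j > 1 and k / j != i, the multiplication
// overflowed. Overflow is never cleared, matching Choose() below.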
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  uint64_t k = i*j;
  if (j > 1 && k / j != i) Overflow = true;
  return k;
}

/// Compute the result of "n choose k", the binomial coefficient. If an
/// intermediate computation overflows, Overflow will be set and the return
/// will be garbage. Overflow is not cleared on absence of overflow.
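/// For example, Choose(4, 2) computes r = 1*4/1 = 4, then r = 4*3/2 = 6.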
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
  // We use the multiplicative formula:
  //     n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
  // At each iteration, we take the i-th term of the numerator and divide by
  // the i-th term of the denominator. This division will always produce an
  // integral result, and helps reduce the chance of overflow in the
  // intermediate computations. However, we can still overflow even when the
  // final result would fit.

  if (n == 0 || n == k) return 1;
  if (k > n) return 0;

  if (k > n/2)
    k = n-k;

  uint64_t r = 1;
  for (uint64_t i = 1; i <= k; ++i) {
    r = umul_ov(r, n-(i-1), Overflow);
    r /= i;
  }
  return r;
}

/// Determine if any of the operands in this SCEV are a constant or if
/// any of the add or multiply expressions in this SCEV contain a constant.
static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
  struct FindConstantInAddMulChain {
    bool FoundConstant = false;

    bool follow(const SCEV *S) {
      FoundConstant |= isa<SCEVConstant>(S);
      return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
    }

    bool isDone() const {
      return FoundConstant;
    }
  };

  FindConstantInAddMulChain F;
  SCEVTraversal<FindConstantInAddMulChain> ST(F);
  ST.visitAll(StartExpr);
  return F.FoundConstant;
}

/// Get a canonical multiply expression, or something simpler if possible.
const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
                                        SCEV::NoWrapFlags OrigFlags,
                                        unsigned Depth) {
  assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
         "only nuw or nsw allowed");
  assert(!Ops.empty() && "Cannot get empty mul!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  Type *ETy = Ops[0]->getType();
  assert(!ETy->isPointerTy());
  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
    assert(Ops[i]->getType() == ETy &&
           "SCEVMulExpr operand types don't match!");
#endif

  // Sort by complexity, this groups all similar expression types together.
  GroupByComplexity(Ops, &LI, DT);

  // If there are any constants, fold them together.
  unsigned Idx = 0;
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    ++Idx;
    assert(Idx < Ops.size());
    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
      // We found two constants, fold them together!
      Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
      if (Ops.size() == 2) return Ops[0];
      Ops.erase(Ops.begin()+1); // Erase the folded element
      LHSC = cast<SCEVConstant>(Ops[0]);
    }

    // If we have a multiply of zero, it will always be zero.
    if (LHSC->getValue()->isZero())
      return LHSC;

    // If we are left with a constant one being multiplied, strip it off.
    if (LHSC->getValue()->isOne()) {
      Ops.erase(Ops.begin());
      --Idx;
    }

    if (Ops.size() == 1)
      return Ops[0];
  }

  // Delay expensive flag strengthening until necessary.
  auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
    return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
  };

  // Limit recursion call depth.
  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
    return getOrCreateMulExpr(Ops, ComputeFlags(Ops));

  if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
    // Don't strengthen flags if we have no new information.
    SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
    if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
      Mul->setNoWrapFlags(ComputeFlags(Ops));
    return S;
  }

  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    if (Ops.size() == 2) {
      // C1*(C2+V) -> C1*C2 + C1*V
      if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
        // If any of Add's ops are Adds or Muls with a constant, apply this
        // transformation as well.
        //
        // TODO: There are some cases where this transformation is not
        // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
        // this transformation should be narrowed down.
        if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
          const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
                                       SCEV::FlagAnyWrap, Depth + 1);
          const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
                                       SCEV::FlagAnyWrap, Depth + 1);
          return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
        }

      if (Ops[0]->isAllOnesValue()) {
        // If we have a mul by -1 of an add, try distributing the -1 among the
        // add operands.
        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
          SmallVector<const SCEV *, 4> NewOps;
          bool AnyFolded = false;
          for (const SCEV *AddOp : Add->operands()) {
            const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
                                         Depth + 1);
            if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
            NewOps.push_back(Mul);
          }
          if (AnyFolded)
            return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
        } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
          // Negation preserves a recurrence's no self-wrap property.
          SmallVector<const SCEV *, 4> Operands;
          for (const SCEV *AddRecOp : AddRec->operands())
            Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
                                          Depth + 1));
          return getAddRecExpr(Operands, AddRec->getLoop(),
                               AddRec->getNoWrapFlags(SCEV::FlagNW));
        }
      }
    }
  }

  // Skip over the add expression until we get to a multiply.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
    ++Idx;

  // If there are mul operands inline them all into this expression.
  if (Idx < Ops.size()) {
    bool DeletedMul = false;
    while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
      if (Ops.size() > MulOpsInlineThreshold)
        break;
      // If we have a mul, expand the mul operands onto the end of the
      // operands list.
      Ops.erase(Ops.begin()+Idx);
      Ops.append(Mul->op_begin(), Mul->op_end());
      DeletedMul = true;
    }

    // If we deleted at least one mul, we added operands to the end of the
    // list, and they are not necessarily sorted. Recurse to resort and
    // resimplify any operands we just acquired.
    if (DeletedMul)
      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
  }

  // If there are any add recurrences in the operands list, see if any other
  // added values are loop invariant. If so, we can fold them into the
  // recurrence.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
    ++Idx;

  // Scan over all recurrences, trying to fold loop invariants into them.
  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
    // Scan all of the other operands to this mul and add them to the vector
    // if they are loop invariant w.r.t. the recurrence.
    SmallVector<const SCEV *, 8> LIOps;
    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
    const Loop *AddRecLoop = AddRec->getLoop();
    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
      if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
        LIOps.push_back(Ops[i]);
        Ops.erase(Ops.begin()+i);
        --i; --e;
      }

3202 // If we found some loop invariants, fold them into the recurrence.
3203 if (!LIOps.empty()) {
3204 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3205 SmallVector<const SCEV *, 4> NewOps;
3206 NewOps.reserve(AddRec->getNumOperands());
3207 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3208 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
3209 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3210 SCEV::FlagAnyWrap, Depth + 1));
3212 // Build the new addrec. Propagate the NUW and NSW flags if both the
3213 // outer mul and the inner addrec are guaranteed to have no overflow.
3215 // The no-self-wrap property cannot be guaranteed after changing the step
3216 // size, but it will be inferred if either NUW or NSW is true.
3217 SCEV::NoWrapFlags Flags = ComputeFlags({Scale, AddRec});
3218 const SCEV *NewRec = getAddRecExpr(
3219 NewOps, AddRecLoop, AddRec->getNoWrapFlags(Flags));
3221 // If all of the other operands were loop invariant, we are done.
3222 if (Ops.size() == 1) return NewRec;
3224 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3225 for (unsigned i = 0;; ++i)
3226 if (Ops[i] == AddRec) {
3227 Ops[i] = NewRec;
3228 break;
3229 }
3230 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3233 // Okay, if there weren't any loop invariants to be folded, check to see
3234 // if there are multiple AddRecs with the same loop induction variable
3235 // being multiplied together. If so, we can fold them.
3237 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3238 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3239 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3240 // ]]],+,...up to x=2n}.
3241 // Note that the arguments to choose() are always integers with values
3242 // known at compile time, never SCEV objects.
3244 // The implementation avoids pointless extra computations when the two
3245 // addrecs are of different length (mathematically, it's equivalent to
3246 // an infinite stream of zeros on the right).
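// As a concrete affine instance of the formula above:
//   {A1,+,A2}<L> * {B1,+,B2}<L> = {A1*B1,+,A1*B2 + A2*B1 + A2*B2,+,2*A2*B2}<L>
// i.e. the chains-of-recurrences product of A1 + A2*i and B1 + B2*i.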
3247 bool OpsModified = false;
3248 for (unsigned OtherIdx = Idx+1;
3249 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3250 ++OtherIdx) {
3251 const SCEVAddRecExpr *OtherAddRec =
3252 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3253 if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
3254 continue;
3256 // Limit max number of arguments to avoid creation of unreasonably big
3257 // SCEVAddRecs with very complex operands.
3258 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3259 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3260 continue;
3262 bool Overflow = false;
3263 Type *Ty = AddRec->getType();
3264 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3265 SmallVector<const SCEV*, 7> AddRecOps;
3266 for (int x = 0, xe = AddRec->getNumOperands() +
3267 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3268 SmallVector <const SCEV *, 7> SumOps;
3269 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3270 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3271 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3272 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3273 z < ze && !Overflow; ++z) {
3274 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3275 uint64_t Coeff;
3276 if (LargerThan64Bits)
3277 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3278 else
3279 Coeff = Coeff1*Coeff2;
3280 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3281 const SCEV *Term1 = AddRec->getOperand(y-z);
3282 const SCEV *Term2 = OtherAddRec->getOperand(z);
3283 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3284 SCEV::FlagAnyWrap, Depth + 1));
3287 if (SumOps.empty())
3288 SumOps.push_back(getZero(Ty));
3289 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3291 if (!Overflow) {
3292 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRecLoop,
3293 SCEV::FlagAnyWrap);
3294 if (Ops.size() == 2) return NewAddRec;
3295 Ops[Idx] = NewAddRec;
3296 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3298 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3303 if (OpsModified)
3304 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3306 // Otherwise couldn't fold anything into this recurrence. Move onto the next one.
3310 // Okay, it looks like we really DO need a mul expr. Check to see if we
3311 // already have one, otherwise create a new one.
3312 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3315 /// Represents an unsigned remainder expression based on unsigned division.
3316 const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3317 const SCEV *RHS) {
3318 assert(getEffectiveSCEVType(LHS->getType()) ==
3319 getEffectiveSCEVType(RHS->getType()) &&
3320 "SCEVURemExpr operand types don't match!");
3322 // Short-circuit easy cases
3323 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3324 // If constant is one, the result is trivial
3325 if (RHSC->getValue()->isOne())
3326 return getZero(LHS->getType()); // X urem 1 --> 0
3328 // If constant is a power of two, fold into a zext(trunc(LHS)).
3329 if (RHSC->getAPInt().isPowerOf2()) {
3330 Type *FullTy = LHS->getType();
3331 Type *TruncTy = IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3333 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3337 // Fall back to %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3338 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3339 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3340 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
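// A sketch of the folds above, assuming i32 operands:
//   %x urem 1  ==>  0
//   %x urem 8  ==>  (zext i3 (trunc i32 %x to i3) to i32)
// and, for a non-constant %y, the fallback form %x - (%x /u %y) * %y.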
3343 /// Get a canonical unsigned division expression, or something simpler if possible.
3345 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3346 const SCEV *RHS) {
3347 assert(!LHS->getType()->isPointerTy() &&
3348 "SCEVUDivExpr operand can't be pointer!");
3349 assert(LHS->getType() == RHS->getType() &&
3350 "SCEVUDivExpr operand types don't match!");
3352 FoldingSetNodeID ID;
3353 ID.AddInteger(scUDivExpr);
3354 ID.AddPointer(LHS);
3355 ID.AddPointer(RHS);
3356 void *IP = nullptr;
3357 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3358 return S;
3361 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3362 if (LHSC->getValue()->isZero())
3363 return LHS; // 0 udiv Y --> 0
3365 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3366 if (RHSC->getValue()->isOne())
3367 return LHS; // X udiv 1 --> X
3368 // If the denominator is zero, the result of the udiv is undefined. Don't
3369 // try to analyze it, because the resolution chosen here may differ from
3370 // the resolution chosen in other parts of the compiler.
3371 if (!RHSC->getValue()->isZero()) {
3372 // Determine if the division can be folded into its operands.
3374 // TODO: Generalize this to non-constants by using known-bits information.
3375 Type *Ty = LHS->getType();
3376 unsigned LZ = RHSC->getAPInt().countLeadingZeros();
3377 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3378 // For non-power-of-two values, effectively round the value up to the
3379 // nearest power of two.
3380 if (!RHSC->getAPInt().isPowerOf2())
3381 ++MaxShiftAmt;
3382 IntegerType *ExtTy =
3383 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3384 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3385 if (const SCEVConstant *Step =
3386 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3387 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3388 const APInt &StepInt = Step->getAPInt();
3389 const APInt &DivInt = RHSC->getAPInt();
3390 if (!StepInt.urem(DivInt) &&
3391 getZeroExtendExpr(AR, ExtTy) ==
3392 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3393 getZeroExtendExpr(Step, ExtTy),
3394 AR->getLoop(), SCEV::FlagAnyWrap)) {
3395 SmallVector<const SCEV *, 4> Operands;
3396 for (const SCEV *Op : AR->operands())
3397 Operands.push_back(getUDivExpr(Op, RHS));
3398 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
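// For example, when {0,+,4}<L> is known not to wrap after zero-extension,
// the fold above rewrites {0,+,4}<L> /u 2 as {0,+,2}<L>.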
3400 // Get a canonical UDivExpr for a recurrence.
3401 // {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3402 // We can currently only fold X%N if X is constant.
3403 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3404 if (StartC && !DivInt.urem(StepInt) &&
3405 getZeroExtendExpr(AR, ExtTy) ==
3406 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3407 getZeroExtendExpr(Step, ExtTy),
3408 AR->getLoop(), SCEV::FlagAnyWrap)) {
3409 const APInt &StartInt = StartC->getAPInt();
3410 const APInt &StartRem = StartInt.urem(StepInt);
3411 if (StartRem != 0) {
3412 const SCEV *NewLHS =
3413 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3414 AR->getLoop(), SCEV::FlagNW);
3415 if (LHS != NewLHS) {
3416 LHS = NewLHS;
3418 // Reset the ID to include the new LHS, and check if it is already in the cache.
3420 ID.clear();
3421 ID.AddInteger(scUDivExpr);
3422 ID.AddPointer(LHS);
3423 ID.AddPointer(RHS);
3424 IP = nullptr;
3425 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3426 return S;
3431 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3432 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3433 SmallVector<const SCEV *, 4> Operands;
3434 for (const SCEV *Op : M->operands())
3435 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3436 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3437 // Find an operand that's safely divisible.
3438 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3439 const SCEV *Op = M->getOperand(i);
3440 const SCEV *Div = getUDivExpr(Op, RHSC);
3441 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3442 Operands = SmallVector<const SCEV *, 4>(M->operands());
3443 Operands[i] = Div;
3444 return getMulExpr(Operands);
3449 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3450 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3451 if (auto *DivisorConstant =
3452 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3453 bool Overflow = false;
3454 APInt NewRHS =
3455 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3456 if (Overflow)
3457 return getConstant(RHSC->getType(), 0, false);
3459 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3463 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3464 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3465 SmallVector<const SCEV *, 4> Operands;
3466 for (const SCEV *Op : A->operands())
3467 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3468 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3469 Operands.clear();
3470 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3471 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3472 if (isa<SCEVUDivExpr>(Op) ||
3473 getMulExpr(Op, RHS) != A->getOperand(i))
3474 break;
3475 Operands.push_back(Op);
3477 if (Operands.size() == A->getNumOperands())
3478 return getAddExpr(Operands);
3482 // Fold if both operands are constant.
3483 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3484 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3488 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3489 // changes). Make sure we get a new one.
3491 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3492 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3493 LHS, RHS);
3494 UniqueSCEVs.InsertNode(S, IP);
3495 registerUser(S, {LHS, RHS});
3496 return S;
3499 APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3500 APInt A = C1->getAPInt().abs();
3501 APInt B = C2->getAPInt().abs();
3502 uint32_t ABW = A.getBitWidth();
3503 uint32_t BBW = B.getBitWidth();
3504 if (ABW > BBW)
3505 B = B.zext(ABW);
3506 else if (ABW < BBW)
3507 A = A.zext(BBW);
3510 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
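// For example, for the constants -6 and 4 this computes the GCD of
// |-6| = 6 and |4| = 4, namely 2, which the exact-udiv logic below uses to
// cancel a common factor from both sides.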
3513 /// Get a canonical unsigned division expression, or something simpler if
3514 /// possible. There is no representation for an exact udiv in SCEV IR, but we
3515 /// can attempt to remove factors from the LHS and RHS. We can't do this when
3516 /// it's not exact because the udiv may be clearing bits.
3517 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3518 const SCEV *RHS) {
3519 // TODO: we could try to find factors in all sorts of things, but for now we
3520 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3521 // end of this file for inspiration.
3523 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3524 if (!Mul || !Mul->hasNoUnsignedWrap())
3525 return getUDivExpr(LHS, RHS);
3527 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3528 // If the mulexpr multiplies by a constant, then that constant must be the
3529 // first element of the mulexpr.
3530 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3531 if (LHSCst == RHSCst) {
3532 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3533 return getMulExpr(Operands);
3536 // We can't just assume that LHSCst divides RHSCst cleanly; it could be
3537 // that there's a factor provided by one of the other terms. We need to check.
3539 APInt Factor = gcd(LHSCst, RHSCst);
3540 if (!Factor.isIntN(1)) {
3541 LHSCst =
3542 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3543 RHSCst =
3544 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3545 SmallVector<const SCEV *, 2> Operands;
3546 Operands.push_back(LHSCst);
3547 Operands.append(Mul->op_begin() + 1, Mul->op_end());
3548 LHS = getMulExpr(Operands);
3549 RHS = RHSCst;
3550 Mul = dyn_cast<SCEVMulExpr>(LHS);
3551 if (!Mul)
3552 return getUDivExactExpr(LHS, RHS);
3557 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3558 if (Mul->getOperand(i) == RHS) {
3559 SmallVector<const SCEV *, 2> Operands;
3560 Operands.append(Mul->op_begin(), Mul->op_begin() + i);
3561 Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
3562 return getMulExpr(Operands);
3566 return getUDivExpr(LHS, RHS);
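// A worked trace of the constant path above: (6 * %x)<nuw> /u-exact 4 has
// gcd(6, 4) = 2, so it is reduced to (3 * %x) /u-exact 2; since 3 and 2
// share no further factor, the result ends up as the plain (3 * %x) /u 2.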
3569 /// Get an add recurrence expression for the specified loop. Simplify the
3570 /// expression as much as possible.
3571 const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3572 const Loop *L,
3573 SCEV::NoWrapFlags Flags) {
3574 SmallVector<const SCEV *, 4> Operands;
3575 Operands.push_back(Start);
3576 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3577 if (StepChrec->getLoop() == L) {
3578 Operands.append(StepChrec->op_begin(), StepChrec->op_end());
3579 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3582 Operands.push_back(Step);
3583 return getAddRecExpr(Operands, L, Flags);
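// For example, getAddRecExpr(X, {Y,+,Z}<L>, L) flattens the step
// recurrence into the single recurrence {X,+,Y,+,Z}<L>.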
3586 /// Get an add recurrence expression for the specified loop. Simplify the
3587 /// expression as much as possible.
3588 const SCEV *
3589 ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3590 const Loop *L, SCEV::NoWrapFlags Flags) {
3591 if (Operands.size() == 1) return Operands[0];
3593 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3594 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
3595 assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
3596 "SCEVAddRecExpr operand types don't match!");
3597 assert(!Operands[i]->getType()->isPointerTy() && "Step must be integer");
3599 for (unsigned i = 0, e = Operands.size(); i != e; ++i)
3600 assert(isLoopInvariant(Operands[i], L) &&
3601 "SCEVAddRecExpr operand is not loop-invariant!");
3604 if (Operands.back()->isZero()) {
3605 Operands.pop_back();
3606 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3609 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3610 // use that information to infer NUW and NSW flags. However, computing a
3611 // BE count requires calling getAddRecExpr, so we may not yet have a
3612 // meaningful BE count at this point (and if we don't, we'd be stuck
3613 // with a SCEVCouldNotCompute as the cached BE count).
3615 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3617 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3618 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3619 const Loop *NestedLoop = NestedAR->getLoop();
3620 if (L->contains(NestedLoop)
3621 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3622 : (!NestedLoop->contains(L) &&
3623 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3624 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3625 Operands[0] = NestedAR->getStart();
3626 // AddRecs require their operands be loop-invariant with respect to their
3627 // loops. Don't perform this transformation if it would break this requirement.
3629 bool AllInvariant = all_of(
3630 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3632 if (AllInvariant) {
3633 // Create a recurrence for the outer loop with the same step size.
3635 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3636 // inner recurrence has the same property.
3637 SCEV::NoWrapFlags OuterFlags =
3638 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3640 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3641 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3642 return isLoopInvariant(Op, NestedLoop);
3644 if (AllInvariant) {
3646 // Ok, both add recurrences are valid after the transformation.
3648 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3649 // the outer recurrence has the same property.
3650 SCEV::NoWrapFlags InnerFlags =
3651 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3652 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3655 // Reset Operands to its original state.
3656 Operands[0] = NestedAR;
3660 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3661 // already have one, otherwise create a new one.
3662 return getOrCreateAddRecExpr(Operands, L, Flags);
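// As an illustration of the nesting canonicalization above, when the
// operands are suitably invariant:
//   {{X,+,Y}<Inner>,+,Z}<Outer>  ==>  {{X,+,Z}<Outer>,+,Y}<Inner>
// i.e. the outer recurrence becomes the start of the inner one.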
3665 const SCEV *
3666 ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3667 const SmallVectorImpl<const SCEV *> &IndexExprs) {
3668 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3669 // getSCEV(Base)->getType() has the same address space as Base->getType()
3670 // because SCEV::getType() preserves the address space.
3671 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3672 const bool AssumeInBoundsFlags = [&]() {
3673 if (!GEP->isInBounds())
3676 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3677 // but to do that, we have to ensure that said flag is valid in the entire
3678 // defined scope of the SCEV.
3679 auto *GEPI = dyn_cast<Instruction>(GEP);
3680 // TODO: non-instructions have global scope. We might be able to prove
3681 // some global scope cases
3682 return GEPI && isSCEVExprNeverPoison(GEPI);
3685 SCEV::NoWrapFlags OffsetWrap =
3686 AssumeInBoundsFlags ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3688 Type *CurTy = GEP->getType();
3689 bool FirstIter = true;
3690 SmallVector<const SCEV *, 4> Offsets;
3691 for (const SCEV *IndexExpr : IndexExprs) {
3692 // Compute the (potentially symbolic) offset in bytes for this index.
3693 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3694 // For a struct, add the member offset.
3695 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3696 unsigned FieldNo = Index->getZExtValue();
3697 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3698 Offsets.push_back(FieldOffset);
3700 // Update CurTy to the type of the field at Index.
3701 CurTy = STy->getTypeAtIndex(Index);
3703 // Update CurTy to its element type.
3705 assert(isa<PointerType>(CurTy) &&
3706 "The first index of a GEP indexes a pointer");
3707 CurTy = GEP->getSourceElementType();
3710 CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3712 // For an array, add the element offset, explicitly scaled.
3713 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3714 // Getelementptr indices are signed.
3715 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3717 // Multiply the index by the element size to compute the element offset.
3718 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3719 Offsets.push_back(LocalOffset);
3723 // Handle degenerate case of GEP without offsets.
3724 if (Offsets.empty())
3725 return BaseExpr;
3727 // Add the offsets together, assuming nsw if inbounds.
3728 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3729 // Add the base address and the offset. We cannot use the nsw flag, as the
3730 // base address is unsigned. However, if we know that the offset is
3731 // non-negative, we can use nuw.
3732 SCEV::NoWrapFlags BaseWrap = AssumeInBoundsFlags && isKnownNonNegative(Offset)
3733 ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3734 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3735 assert(BaseExpr->getType() == GEPExpr->getType() &&
3736 "GEP should not change type mid-flight.");
3737 return GEPExpr;
3740 SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3741 ArrayRef<const SCEV *> Ops) {
3742 FoldingSetNodeID ID;
3743 ID.AddInteger(SCEVType);
3744 for (const SCEV *Op : Ops)
3745 ID.AddPointer(Op);
3746 void *IP = nullptr;
3747 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3750 const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3751 SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3752 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3755 const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3756 SmallVectorImpl<const SCEV *> &Ops) {
3757 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3758 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3759 if (Ops.size() == 1) return Ops[0];
3761 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3762 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3763 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3764 "Operand types don't match!");
3765 assert(Ops[0]->getType()->isPointerTy() ==
3766 Ops[i]->getType()->isPointerTy() &&
3767 "min/max should be consistently pointerish");
3771 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3772 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3774 // Sort by complexity, this groups all similar expression types together.
3775 GroupByComplexity(Ops, &LI, DT);
3777 // Check if we have created the same expression before.
3778 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3779 return S;
3780 }
3782 // If there are any constants, fold them together.
3783 unsigned Idx = 0;
3784 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3785 ++Idx;
3786 assert(Idx < Ops.size());
3787 auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3788 if (Kind == scSMaxExpr)
3789 return APIntOps::smax(LHS, RHS);
3790 else if (Kind == scSMinExpr)
3791 return APIntOps::smin(LHS, RHS);
3792 else if (Kind == scUMaxExpr)
3793 return APIntOps::umax(LHS, RHS);
3794 else if (Kind == scUMinExpr)
3795 return APIntOps::umin(LHS, RHS);
3796 llvm_unreachable("Unknown SCEV min/max opcode");
3799 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3800 // We found two constants, fold them together!
3801 ConstantInt *Fold = ConstantInt::get(
3802 getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3803 Ops[0] = getConstant(Fold);
3804 Ops.erase(Ops.begin()+1); // Erase the folded element
3805 if (Ops.size() == 1) return Ops[0];
3806 LHSC = cast<SCEVConstant>(Ops[0]);
3809 bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3810 bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3812 if (IsMax ? IsMinV : IsMaxV) {
3813 // If we are left with a constant minimum(/maximum)-int, strip it off.
3814 Ops.erase(Ops.begin());
3815 --Idx;
3816 } else if (IsMax ? IsMaxV : IsMinV) {
3817 // If we have a max(/min) with a constant maximum(/minimum)-int,
3818 // it will always be the extremum.
3819 return LHSC;
3822 if (Ops.size() == 1) return Ops[0];
3825 // Find the first operation of the same kind.
3826 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3827 ++Idx;
3829 // Check to see if one of the operands is of the same kind. If so, expand its
3830 // operands onto our operand list, and recurse to simplify.
3831 if (Idx < Ops.size()) {
3832 bool DeletedAny = false;
3833 while (Ops[Idx]->getSCEVType() == Kind) {
3834 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3835 Ops.erase(Ops.begin()+Idx);
3836 Ops.append(SMME->op_begin(), SMME->op_end());
3837 DeletedAny = true;
3838 }
3840 if (DeletedAny)
3841 return getMinMaxExpr(Kind, Ops);
3844 // Okay, check to see if the same value occurs in the operand list twice. If
3845 // so, delete one. Since we sorted the list, these values are required to be adjacent.
3847 llvm::CmpInst::Predicate GEPred =
3848 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3849 llvm::CmpInst::Predicate LEPred =
3850 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3851 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3852 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3853 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3854 if (Ops[i] == Ops[i + 1] ||
3855 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3856 // X op Y op Y --> X op Y
3857 // X op Y --> X, if we know X, Y are ordered appropriately
3858 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3859 --i; --e;
3861 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3862 Ops[i + 1])) {
3863 // X op Y --> Y, if we know X, Y are ordered appropriately
3864 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3865 --i; --e;
3870 if (Ops.size() == 1) return Ops[0];
3872 assert(!Ops.empty() && "Reduced min/max down to nothing!");
3874 // Okay, it looks like we really DO need an expr. Check to see if we
3875 // already have one, otherwise create a new one.
3876 FoldingSetNodeID ID;
3877 ID.AddInteger(Kind);
3878 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3879 ID.AddPointer(Ops[i]);
3880 void *IP = nullptr;
3881 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3882 if (ExistingSCEV)
3883 return ExistingSCEV;
3884 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3885 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3886 SCEV *S = new (SCEVAllocator)
3887 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3889 UniqueSCEVs.InsertNode(S, IP);
3890 registerUser(S, Ops);
3891 return S;
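// A usage sketch (hypothetical caller; SE is a ScalarEvolution instance):
//   const SCEV *M = SE.getUMaxExpr(X, SE.getUMaxExpr(X, Y));
// deduplicates to umax(X, Y), while pure constants fold outright,
// e.g. smax(2, 3) ==> 3.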
3896 class SCEVSequentialMinMaxDeduplicatingVisitor final
3897 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3898 Optional<const SCEV *>> {
3899 using RetVal = Optional<const SCEV *>;
3900 using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
3902 ScalarEvolution &SE;
3903 const SCEVTypes RootKind; // Must be a sequential min/max expression.
3904 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3905 SmallPtrSet<const SCEV *, 16> SeenOps;
3907 bool canRecurseInto(SCEVTypes Kind) const {
3908 // We can only recurse into the SCEV expression of the same effective type
3909 // as the type of our root SCEV expression.
3910 return RootKind == Kind || NonSequentialRootKind == Kind;
3913 RetVal visitAnyMinMaxExpr(const SCEV *S) {
3914 assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
3915 "Only for min/max expressions.");
3916 SCEVTypes Kind = S->getSCEVType();
3918 if (!canRecurseInto(Kind))
3921 auto *NAry = cast<SCEVNAryExpr>(S);
3922 SmallVector<const SCEV *> NewOps;
3924 visit(Kind, makeArrayRef(NAry->op_begin(), NAry->op_end()), NewOps);
3931 return isa<SCEVSequentialMinMaxExpr>(S)
3932 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
3933 : SE.getMinMaxExpr(Kind, NewOps);
3936 RetVal visit(const SCEV *S) {
3937 // Has the whole operand been seen already?
3938 if (!SeenOps.insert(S).second)
3939 return None;
3940 return Base::visit(S);
3944 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
3946 : SE(SE), RootKind(RootKind),
3947 NonSequentialRootKind(
3948 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
3951 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
3952 SmallVectorImpl<const SCEV *> &NewOps) {
3953 bool Changed = false;
3954 SmallVector<const SCEV *> Ops;
3955 Ops.reserve(OrigOps.size());
3957 for (const SCEV *Op : OrigOps) {
3958 RetVal NewOp = visit(Op);
3962 Ops.emplace_back(*NewOp);
3966 NewOps = std::move(Ops);
3967 return Changed;
3970 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
3972 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
3974 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
3976 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
3978 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
3980 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
3982 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
3984 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
3986 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
3988 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
3989 return visitAnyMinMaxExpr(Expr);
3992 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
3993 return visitAnyMinMaxExpr(Expr);
3996 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
3997 return visitAnyMinMaxExpr(Expr);
4000 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4001 return visitAnyMinMaxExpr(Expr);
4004 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4005 return visitAnyMinMaxExpr(Expr);
4008 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4010 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4015 /// Return true if V is poison given that AssumedPoison is already poison.
4016 static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4017 // The only way poison may be introduced in a SCEV expression is from a
4018 // poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4019 // not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4020 // introduce poison -- they encode guaranteed, non-speculated knowledge.
4022 // Additionally, all SCEV nodes propagate poison from inputs to outputs,
4023 // with the notable exception of umin_seq, where only poison from the first
4024 // operand is (unconditionally) propagated.
4025 struct SCEVPoisonCollector {
4026 bool LookThroughSeq;
4027 SmallPtrSet<const SCEV *, 4> MaybePoison;
4028 SCEVPoisonCollector(bool LookThroughSeq) : LookThroughSeq(LookThroughSeq) {}
4030 bool follow(const SCEV *S) {
4031 // TODO: We can always follow the first operand, but the SCEVTraversal
4032 // API doesn't support this.
4033 if (!LookThroughSeq && isa<SCEVSequentialMinMaxExpr>(S))
4034 return false;
4036 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4037 if (!isGuaranteedNotToBePoison(SU->getValue()))
4038 MaybePoison.insert(S);
4039 }
4040 return true;
4042 bool isDone() const { return false; }
4045 // First collect all SCEVs that might result in AssumedPoison to be poison.
4046 // We need to look through umin_seq here, because we want to find all SCEVs
4047 // that *might* result in poison, not only those that are *required* to.
4048 SCEVPoisonCollector PC1(/* LookThroughSeq */ true);
4049 visitAll(AssumedPoison, PC1);
4051 // AssumedPoison is never poison. As the assumption is false, the implication
4052 // is true. Don't bother walking the other SCEV in this case.
4053 if (PC1.MaybePoison.empty())
4054 return true;
4056 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4057 // as well. We cannot look through umin_seq here, as its argument only *may*
4058 // make the result poison.
4059 SCEVPoisonCollector PC2(/* LookThroughSeq */ false);
4060 visitAll(S, PC2);
4062 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4063 // it will also make S poison by being part of PC2.MaybePoison.
4064 return all_of(PC1.MaybePoison,
4065 [&](const SCEV *S) { return PC2.MaybePoison.contains(S); });
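// For example, impliesPoison(%y, %x + %y) is true because an add
// propagates poison from either operand; getSequentialMinMaxExpr uses this
// below to relax (%x + %y) umin_seq %y into the plain (%x + %y) umin %y.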
4068 const SCEV *
4069 ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4070 SmallVectorImpl<const SCEV *> &Ops) {
4071 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4072 "Not a SCEVSequentialMinMaxExpr!");
4073 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4074 if (Ops.size() == 1)
4075 return Ops[0];
4077 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4078 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4079 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4080 "Operand types don't match!");
4081 assert(Ops[0]->getType()->isPointerTy() ==
4082 Ops[i]->getType()->isPointerTy() &&
4083 "min/max should be consistently pointerish");
4087 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4088 // so we can *NOT* do any kind of sorting of the expressions!
4090 // Check if we have created the same expression before.
4091 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4092 return S;
4094 // FIXME: there are *some* simplifications that we can do here.
4096 // Keep only the first instance of an operand.
4098 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4099 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4100 if (Changed)
4101 return getSequentialMinMaxExpr(Kind, Ops);
4104 // Check to see if one of the operands is of the same kind. If so, expand its
4105 // operands onto our operand list, and recurse to simplify.
4107 unsigned Idx = 0;
4108 bool DeletedAny = false;
4109 while (Idx < Ops.size()) {
4110 if (Ops[Idx]->getSCEVType() != Kind) {
4111 ++Idx;
4112 continue;
4113 }
4114 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4115 Ops.erase(Ops.begin() + Idx);
4116 Ops.insert(Ops.begin() + Idx, SMME->op_begin(), SMME->op_end());
4117 DeletedAny = true;
4118 }
4120 if (DeletedAny)
4121 return getSequentialMinMaxExpr(Kind, Ops);
4124 const SCEV *SaturationPoint;
4125 ICmpInst::Predicate Pred;
4126 switch (Kind) {
4127 case scSequentialUMinExpr:
4128 SaturationPoint = getZero(Ops[0]->getType());
4129 Pred = ICmpInst::ICMP_ULE;
4130 break;
4131 default:
4132 llvm_unreachable("Not a sequential min/max type.");
4133 }
4135 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4136 // We can replace %x umin_seq %y with %x umin %y if either:
4137 // * %y being poison implies %x is also poison.
4138 // * %x cannot be the saturating value (e.g. zero for umin).
4139 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4140 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4141 SaturationPoint)) {
4142 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4143 Ops[i - 1] = getMinMaxExpr(
4144 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4145 SeqOps);
4146 Ops.erase(Ops.begin() + i);
4147 return getSequentialMinMaxExpr(Kind, Ops);
4149 // Fold %x umin_seq %y to %x if %x ule %y.
4150 // TODO: We might be able to prove the predicate for a later operand.
4151 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4152 Ops.erase(Ops.begin() + i);
4153 return getSequentialMinMaxExpr(Kind, Ops);
4157 // Okay, it looks like we really DO need an expr. Check to see if we
4158 // already have one, otherwise create a new one.
4159 FoldingSetNodeID ID;
4160 ID.AddInteger(Kind);
4161 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
4162 ID.AddPointer(Ops[i]);
4163 void *IP = nullptr;
4164 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4165 if (ExistingSCEV)
4166 return ExistingSCEV;
4168 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4169 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4170 SCEV *S = new (SCEVAllocator)
4171 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4173 UniqueSCEVs.InsertNode(S, IP);
4174 registerUser(S, Ops);
4175 return S;
4178 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4179 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4180 return getSMaxExpr(Ops);
4183 const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4184 return getMinMaxExpr(scSMaxExpr, Ops);
4187 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4188 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4189 return getUMaxExpr(Ops);
4192 const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4193 return getMinMaxExpr(scUMaxExpr, Ops);
4196 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4197 const SCEV *RHS) {
4198 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4199 return getSMinExpr(Ops);
4202 const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4203 return getMinMaxExpr(scSMinExpr, Ops);
4206 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4207 bool Sequential) {
4208 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4209 return getUMinExpr(Ops, Sequential);
4212 const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4213 bool Sequential) {
4214 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4215 : getMinMaxExpr(scUMinExpr, Ops);
4218 const SCEV *
4219 ScalarEvolution::getSizeOfScalableVectorExpr(Type *IntTy,
4220 ScalableVectorType *ScalableTy) {
4221 Constant *NullPtr = Constant::getNullValue(ScalableTy->getPointerTo());
4222 Constant *One = ConstantInt::get(IntTy, 1);
4223 Constant *GEP = ConstantExpr::getGetElementPtr(ScalableTy, NullPtr, One);
4224 // Note that the expression we created is the final expression; we don't
4225 // want to simplify it any further. Also, if we call a normal getSCEV(),
4226 // we'll end up in an endless recursion. So just create an SCEVUnknown.
4227 return getUnknown(ConstantExpr::getPtrToInt(GEP, IntTy));
4230 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4231 if (auto *ScalableAllocTy = dyn_cast<ScalableVectorType>(AllocTy))
4232 return getSizeOfScalableVectorExpr(IntTy, ScalableAllocTy);
4233 // We can bypass creating a target-independent constant expression and then
4234 // folding it back into a ConstantInt. This is just a compile-time optimization.
4236 return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4239 const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4240 if (auto *ScalableStoreTy = dyn_cast<ScalableVectorType>(StoreTy))
4241 return getSizeOfScalableVectorExpr(IntTy, ScalableStoreTy);
4242 // We can bypass creating a target-independent constant expression and then
4243 // folding it back into a ConstantInt. This is just a compile-time optimization.
4245 return getConstant(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4248 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4249 StructType *STy,
4250 unsigned FieldNo) {
4251 // We can bypass creating a target-independent constant expression and then
4252 // folding it back into a ConstantInt. This is just a compile-time optimization.
4254 return getConstant(
4255 IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
4258 const SCEV *ScalarEvolution::getUnknown(Value *V) {
4259 // Don't attempt to do anything other than create a SCEVUnknown object
4260 // here. createSCEV only calls getUnknown after checking for all other
4261 // interesting possibilities, and any other code that calls getUnknown
4262 // is doing so in order to hide a value from SCEV canonicalization.
4264 FoldingSetNodeID ID;
4265 ID.AddInteger(scUnknown);
4266 ID.AddPointer(V);
4267 void *IP = nullptr;
4268 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4269 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4270 "Stale SCEVUnknown in uniquing map!");
4273 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4274 FirstUnknown);
4275 FirstUnknown = cast<SCEVUnknown>(S);
4276 UniqueSCEVs.InsertNode(S, IP);
4277 return S;
4280 //===----------------------------------------------------------------------===//
4281 // Basic SCEV Analysis and PHI Idiom Recognition Code
4284 /// Test if values of the given type are analyzable within the SCEV
4285 /// framework. This primarily includes integer types, and it can optionally
4286 /// include pointer types if the ScalarEvolution class has access to
4287 /// target-specific information.
4288 bool ScalarEvolution::isSCEVable(Type *Ty) const {
4289 // Integers and pointers are always SCEVable.
4290 return Ty->isIntOrPtrTy();
4293 /// Return the size in bits of the specified type, for which isSCEVable must return true.
4295 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4296 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4297 if (Ty->isPointerTy())
4298 return getDataLayout().getIndexTypeSizeInBits(Ty);
4299 return getDataLayout().getTypeSizeInBits(Ty);
4302 /// Return a type with the same bitwidth as the given type and which represents
4303 /// how SCEV will treat the given type, for which isSCEVable must return
4304 /// true. For pointer types, this is the pointer index sized integer type.
4305 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
4306 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4308 if (Ty->isIntegerTy())
4309 return Ty;
4311 // The only other supported type is pointer.
4312 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4313 return getDataLayout().getIndexType(Ty);
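// For example, with a typical 64-bit DataLayout, i32 maps to itself while
// any pointer type maps to the index type i64, so pointer SCEVs are sized
// by the pointer index width.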
4316 Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4317 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4320 bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
4321 const SCEV *B) {
4322 // For a valid use point to exist, the defining scope of one operand
4323 // must dominate the other.
4324 bool PreciseA, PreciseB;
4325 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4326 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4327 if (!PreciseA || !PreciseB)
4328 return false;
4330 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4331 DT.dominates(ScopeB, ScopeA);
4335 const SCEV *ScalarEvolution::getCouldNotCompute() {
4336 return CouldNotCompute.get();
4339 bool ScalarEvolution::checkValidity(const SCEV *S) const {
4340 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4341 auto *SU = dyn_cast<SCEVUnknown>(S);
4342 return SU && SU->getValue() == nullptr;
4343 });
4345 return !ContainsNulls;
4348 bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4349 HasRecMapType::iterator I = HasRecMap.find(S);
4350 if (I != HasRecMap.end())
4351 return I->second;
4353 bool FoundAddRec =
4354 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4355 HasRecMap.insert({S, FoundAddRec});
4356 return FoundAddRec;
4359 /// Return the ValueOffsetPair set for \p S. \p S can be represented
4360 /// by the value and offset from any ValueOffsetPair in the set.
4361 ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4362 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4363 if (SI == ExprValueMap.end())
4364 return None;
4366 if (VerifySCEVMap) {
4367 // Check there is no dangling Value in the set returned.
4368 for (Value *V : SI->second)
4369 assert(ValueExprMap.count(V));
4372 return SI->second.getArrayRef();
4375 /// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4376 /// cannot be used separately. eraseValueFromMap should be used to remove
4377 /// V from ValueExprMap and ExprValueMap at the same time.
4378 void ScalarEvolution::eraseValueFromMap(Value *V) {
4379 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4380 if (I != ValueExprMap.end()) {
4381 auto EVIt = ExprValueMap.find(I->second);
4382 bool Removed = EVIt->second.remove(V);
4384 assert(Removed && "Value not in ExprValueMap?");
4385 ValueExprMap.erase(I);
4389 void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4390 // A recursive query may have already computed the SCEV. It should be
4391 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4392 // inferred nowrap flags.
4393 auto It = ValueExprMap.find_as(V);
4394 if (It == ValueExprMap.end()) {
4395 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4396 ExprValueMap[S].insert(V);
4400 /// Return an existing SCEV if it exists, otherwise analyze the expression and
4401 /// create a new one.
4402 const SCEV *ScalarEvolution::getSCEV(Value *V) {
4403 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4405 if (const SCEV *S = getExistingSCEV(V))
4406 return S;
4407 return createSCEVIter(V);
4410 const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4411 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4413 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4414 if (I != ValueExprMap.end()) {
4415 const SCEV *S = I->second;
4416 assert(checkValidity(S) &&
4417 "existing SCEV has not been properly invalidated");
4423 /// Return a SCEV corresponding to -V = -1*V
4424 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4425 SCEV::NoWrapFlags Flags) {
4426 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4427 return getConstant(
4428 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4430 Type *Ty = V->getType();
4431 Ty = getEffectiveSCEVType(Ty);
4432 return getMulExpr(V, getMinusOne(Ty), Flags);
4435 /// If Expr computes ~A, return A; otherwise return nullptr.
4436 static const SCEV *MatchNotExpr(const SCEV *Expr) {
4437 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4438 if (!Add || Add->getNumOperands() != 2 ||
4439 !Add->getOperand(0)->isAllOnesValue())
4440 return nullptr;
4442 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4443 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4444 !AddRHS->getOperand(0)->isAllOnesValue())
4445 return nullptr;
4447 return AddRHS->getOperand(1);
4450 /// Return a SCEV corresponding to ~V = -1-V
4451 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4452 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4454 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4455 return getConstant(
4456 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4458 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4459 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4460 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4461 SmallVector<const SCEV *, 2> MatchedOperands;
4462 for (const SCEV *Operand : MME->operands()) {
4463 const SCEV *Matched = MatchNotExpr(Operand);
4464 if (!Matched)
4465 return (const SCEV *)nullptr;
4466 MatchedOperands.push_back(Matched);
4468 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4469 MatchedOperands);
4471 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4472 return Replaced;
4475 Type *Ty = V->getType();
4476 Ty = getEffectiveSCEVType(Ty);
4477 return getMinusSCEV(getMinusOne(Ty), V);
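// For instance, since ~smax(~%x, ~%y) == smin(%x, %y), the matcher above
// rewrites the negated min/max form directly; anything else falls through
// to the generic (-1 - V) form.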
4480 const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4481 assert(P->getType()->isPointerTy());
4483 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4484 // The base of an AddRec is the first operand.
4485 SmallVector<const SCEV *> Ops{AddRec->operands()};
4486 Ops[0] = removePointerBase(Ops[0]);
4487 // Don't try to transfer nowrap flags for now. We could in some cases
4488 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4489 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4491 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4492 // The base of an Add is the pointer operand.
4493 SmallVector<const SCEV *> Ops{Add->operands()};
4494 const SCEV **PtrOp = nullptr;
4495 for (const SCEV *&AddOp : Ops) {
4496 if (AddOp->getType()->isPointerTy()) {
4497 assert(!PtrOp && "Cannot have multiple pointer ops");
4498 PtrOp = &AddOp;
4501 *PtrOp = removePointerBase(*PtrOp);
4502 // Don't try to transfer nowrap flags for now. We could in some cases
4503 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4504 return getAddExpr(Ops);
4506 // Any other expression must be a pointer base.
4507 return getZero(P->getType());
4510 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4511 SCEV::NoWrapFlags Flags,
4512 unsigned Depth) {
4513 // Fast path: X - X --> 0.
4514 if (LHS == RHS)
4515 return getZero(LHS->getType());
4517 // If we subtract two pointers with different pointer bases, bail.
4518 // Eventually, we're going to add an assertion to getMulExpr that we
4519 // can't multiply by a pointer.
4520 if (RHS->getType()->isPointerTy()) {
4521 if (!LHS->getType()->isPointerTy() ||
4522 getPointerBase(LHS) != getPointerBase(RHS))
4523 return getCouldNotCompute();
4524 LHS = removePointerBase(LHS);
4525 RHS = removePointerBase(RHS);
4528 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4529 // makes it so that we cannot make much use of NUW.
4530 auto AddFlags = SCEV::FlagAnyWrap;
4531 const bool RHSIsNotMinSigned =
4532 !getSignedRangeMin(RHS).isMinSignedValue();
4533 if (hasFlags(Flags, SCEV::FlagNSW)) {
4534 // Let M be the minimum representable signed value. Then (-1)*RHS
4535 // signed-wraps if and only if RHS is M. That can happen even for
4536 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4537 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4538 // (-1)*RHS, we need to prove that RHS != M.
4540 // If LHS is non-negative and we know that LHS - RHS does not
4541 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4542 // either by proving that RHS > M or that LHS >= 0.
4543 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4544 AddFlags = SCEV::FlagNSW;
4548 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4549 // RHS is NSW and LHS >= 0.
4551 // The difficulty here is that the NSW flag may have been proven
4552 // relative to a loop that is to be found in a recurrence in LHS and
4553 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4554 // larger scope than intended.
4555 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4557 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4560 const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4562 Type *SrcTy = V->getType();
4563 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4564 "Cannot truncate or zero extend with non-integer arguments!");
4565 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4566 return V; // No conversion
4567 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4568 return getTruncateExpr(V, Ty, Depth);
4569 return getZeroExtendExpr(V, Ty, Depth);
4572 const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4574 Type *SrcTy = V->getType();
4575 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4576 "Cannot truncate or zero extend with non-integer arguments!");
4577 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4578 return V; // No conversion
4579 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4580 return getTruncateExpr(V, Ty, Depth);
4581 return getSignExtendExpr(V, Ty, Depth);
4584 const SCEV *
4585 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4586 Type *SrcTy = V->getType();
4587 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4588 "Cannot noop or zero extend with non-integer arguments!");
4589 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4590 "getNoopOrZeroExtend cannot truncate!");
4591 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4592 return V; // No conversion
4593 return getZeroExtendExpr(V, Ty);
4596 const SCEV *
4597 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4598 Type *SrcTy = V->getType();
4599 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4600 "Cannot noop or sign extend with non-integer arguments!");
4601 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4602 "getNoopOrSignExtend cannot truncate!");
4603 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4604 return V; // No conversion
4605 return getSignExtendExpr(V, Ty);
4608 const SCEV *
4609 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4610 Type *SrcTy = V->getType();
4611 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4612 "Cannot noop or any extend with non-integer arguments!");
4613 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4614 "getNoopOrAnyExtend cannot truncate!");
4615 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4616 return V; // No conversion
4617 return getAnyExtendExpr(V, Ty);
4620 const SCEV *
4621 ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4622 Type *SrcTy = V->getType();
4623 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4624 "Cannot truncate or noop with non-integer arguments!");
4625 assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
4626 "getTruncateOrNoop cannot extend!");
4627 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4628 return V; // No conversion
4629 return getTruncateExpr(V, Ty);
4632 const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4633 const SCEV *RHS) {
4634 const SCEV *PromotedLHS = LHS;
4635 const SCEV *PromotedRHS = RHS;
4637 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4638 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4639 else
4640 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4642 return getUMaxExpr(PromotedLHS, PromotedRHS);
4645 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4646 const SCEV *RHS,
4647 bool Sequential) {
4648 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4649 return getUMinFromMismatchedTypes(Ops, Sequential);
4652 const SCEV *
4653 ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4654 bool Sequential) {
4655 assert(!Ops.empty() && "At least one operand must be!");
4657 if (Ops.size() == 1)
4658 return Ops[0];
4660 // Find the max type first.
4661 Type *MaxType = nullptr;
4662 for (const auto *S : Ops)
4663 if (MaxType)
4664 MaxType = getWiderType(MaxType, S->getType());
4665 else
4666 MaxType = S->getType();
4667 assert(MaxType && "Failed to find maximum type!");
4669 // Extend all ops to max type.
4670 SmallVector<const SCEV *, 2> PromotedOps;
4671 for (const auto *S : Ops)
4672 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4675 return getUMinExpr(PromotedOps, Sequential);
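// For example, the umin of an i32 %a and an i64 %b is formed by first
// zero-extending %a to i64 and then taking umin(zext(%a), %b), or the
// umin_seq variant when Sequential is set.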
4678 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4679 // A pointer operand may evaluate to a nonpointer expression, such as null.
4680 if (!V->getType()->isPointerTy())
4681 return V;
4683 while (true) {
4684 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4685 V = AddRec->getStart();
4686 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4687 const SCEV *PtrOp = nullptr;
4688 for (const SCEV *AddOp : Add->operands()) {
4689 if (AddOp->getType()->isPointerTy()) {
4690 assert(!PtrOp && "Cannot have multiple pointer ops");
4691 PtrOp = AddOp;
4694 assert(PtrOp && "Must have pointer op");
4695 V = PtrOp;
4696 } else // Not something we can look further into.
4697 return V;
4698 }
4701 /// Push users of the given Instruction onto the given Worklist.
4702 static void PushDefUseChildren(Instruction *I,
4703 SmallVectorImpl<Instruction *> &Worklist,
4704 SmallPtrSetImpl<Instruction *> &Visited) {
4705 // Push the def-use children onto the Worklist stack.
4706 for (User *U : I->users()) {
4707 auto *UserInsn = cast<Instruction>(U);
4708 if (Visited.insert(UserInsn).second)
4709 Worklist.push_back(UserInsn);
4715 /// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4716 /// expression in case its Loop is L. If the AddRec's loop is not L, then:
4717 /// if IgnoreOtherLoops is true, the AddRec itself is used;
4718 /// otherwise the rewrite cannot be done.
4719 /// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
4720 class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4722 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4723 bool IgnoreOtherLoops = true) {
4724 SCEVInitRewriter Rewriter(L, SE);
4725 const SCEV *Result = Rewriter.visit(S);
4726 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4727 return SE.getCouldNotCompute();
4728 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4729 ? SE.getCouldNotCompute()
4733 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4734 if (!SE.isLoopInvariant(Expr, L))
4735 SeenLoopVariantSCEVUnknown = true;
4736 return Expr;
4739 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4740 // Only re-write AddRecExprs for this loop.
4741 if (Expr->getLoop() == L)
4742 return Expr->getStart();
4743 SeenOtherLoops = true;
4744 return Expr;
4747 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4749 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4752 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4753 : SCEVRewriteVisitor(SE), L(L) {}
4756 bool SeenLoopVariantSCEVUnknown = false;
4757 bool SeenOtherLoops = false;
4760 /// Takes SCEV S and Loop L. For each AddRec sub-expression, use its
4761 /// post-increment expression in case its Loop is L. If its loop is not L then
4762 /// use AddRec itself.
4763 /// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
4764 class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4766 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4767 SCEVPostIncRewriter Rewriter(L, SE);
4768 const SCEV *Result = Rewriter.visit(S);
4769 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4770 ? SE.getCouldNotCompute()
4774 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4775 if (!SE.isLoopInvariant(Expr, L))
4776 SeenLoopVariantSCEVUnknown = true;
4777 return Expr;
4780 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4781 // Only re-write AddRecExprs for this loop.
4782 if (Expr->getLoop() == L)
4783 return Expr->getPostIncExpr(SE);
4784 SeenOtherLoops = true;
4785 return Expr;
4788 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4790 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4793 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4794 : SCEVRewriteVisitor(SE), L(L) {}
4797 bool SeenLoopVariantSCEVUnknown = false;
4798 bool SeenOtherLoops = false;
4801 /// This class evaluates the compare condition by matching it against the
4802 /// condition of the loop latch. If there is a match, we assume a true value
4803 /// for the condition while building SCEV nodes.
4804 class SCEVBackedgeConditionFolder
4805 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4807 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4808 ScalarEvolution &SE) {
4809 bool IsPosBECond = false;
4810 Value *BECond = nullptr;
4811 if (BasicBlock *Latch = L->getLoopLatch()) {
4812 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4813 if (BI && BI->isConditional()) {
4814 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4815 "Both outgoing branches should not target same header!");
4816 BECond = BI->getCondition();
4817 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4822 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4823 return Rewriter.visit(S);
4826 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4827 const SCEV *Result = Expr;
4828 bool InvariantF = SE.isLoopInvariant(Expr, L);
4830 if (!InvariantF) {
4831 Instruction *I = cast<Instruction>(Expr->getValue());
4832 switch (I->getOpcode()) {
4833 case Instruction::Select: {
4834 SelectInst *SI = cast<SelectInst>(I);
4835 Optional<const SCEV *> Res =
4836 compareWithBackedgeCondition(SI->getCondition());
4837 if (Res.hasValue()) {
4838 bool IsOne = cast<SCEVConstant>(Res.getValue())->getValue()->isOne();
4839 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
4844 Optional<const SCEV *> Res = compareWithBackedgeCondition(I);
4845 if (Res.hasValue())
4846 Result = Res.getValue();
4855 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
4856 bool IsPosBECond, ScalarEvolution &SE)
4857 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
4858 IsPositiveBECond(IsPosBECond) {}
4860 Optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
4862 const Loop *L;
4863 /// Loop back condition.
4864 Value *BackedgeCond = nullptr;
4865 /// Set to true if loop back is on positive branch condition.
4866 bool IsPositiveBECond;
4867 };
4869 Optional<const SCEV *>
4870 SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
4872 // If value matches the backedge condition for the loop latch,
4873 // then return a constant evolution node based on the loopback
4874 // branch taken.
4875 if (BackedgeCond == IC)
4876 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
4877 : SE.getZero(Type::getInt1Ty(SE.getContext()));
4878 return None;
4879 }
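// A worked example of the folding (hypothetical IR): if the latch of L ends
// in "br i1 %c, label %header, label %exit", then any select inside the loop
// of the form "%x = select i1 %c, %a, %b" folds to %a while building SCEVs,
// because %c is known true on every latch-to-header traversal. With the
// branch polarity reversed (IsPositiveBECond == false), the same select
// folds to %b instead.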
4881 class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
4882 public:
4883 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4884 ScalarEvolution &SE) {
4885 SCEVShiftRewriter Rewriter(L, SE);
4886 const SCEV *Result = Rewriter.visit(S);
4887 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
4890 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4891 // Only allow AddRecExprs for this loop.
4892 if (!SE.isLoopInvariant(Expr, L))
4893 Valid = false;
4894 return Expr;
4895 }
4897 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4898 if (Expr->getLoop() == L && Expr->isAffine())
4899 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
4900 Valid = false;
4901 return Expr;
4902 }
4904 bool isValid() { return Valid; }
4906 private:
4907 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
4908 : SCEVRewriteVisitor(SE), L(L) {}
4910 const Loop *L;
4911 bool Valid = true;
4912 };
4914 } // end anonymous namespace
4916 SCEV::NoWrapFlags
4917 ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
4918 if (!AR->isAffine())
4919 return SCEV::FlagAnyWrap;
4921 using OBO = OverflowingBinaryOperator;
4923 SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
4925 if (!AR->hasNoSignedWrap()) {
4926 ConstantRange AddRecRange = getSignedRange(AR);
4927 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
4929 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
4930 Instruction::Add, IncRange, OBO::NoSignedWrap);
4931 if (NSWRegion.contains(AddRecRange))
4932 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
4935 if (!AR->hasNoUnsignedWrap()) {
4936 ConstantRange AddRecRange = getUnsignedRange(AR);
4937 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
4939 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
4940 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
4941 if (NUWRegion.contains(AddRecRange))
4942 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
4943 }
4945 return Result;
4946 }
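// For intuition (illustrative numbers): take an i8 addrec {0,+,1}<L> whose
// signed range is known to be [0, 10) and whose step has signed range [1, 2).
// The guaranteed no-wrap region for "add nsw" with that increment is
// [-128, 127), which contains [0, 10), so FlagNSW can be set here without
// consulting the loop's guarding conditions at all.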
4948 SCEV::NoWrapFlags
4949 ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
4950 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
4952 if (AR->hasNoSignedWrap())
4953 return Result;
4955 if (!AR->isAffine())
4956 return Result;
4958 const SCEV *Step = AR->getStepRecurrence(*this);
4959 const Loop *L = AR->getLoop();
4961 // Check whether the backedge-taken count is SCEVCouldNotCompute.
4962 // Note that this serves two purposes: It filters out loops that are
4963 // simply not analyzable, and it covers the case where this code is
4964 // being called from within backedge-taken count analysis, such that
4965 // attempting to ask for the backedge-taken count would likely result
4966 // in infinite recursion. In the latter case, the analysis code will
4967 // cope with a conservative value, and it will take care to purge
4968 // that value once it has finished.
4969 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
4971 // Normally, in the cases we can prove no-overflow via a
4972 // backedge guarding condition, we can also compute a backedge
4973 // taken count for the loop. The exceptions are assumptions and
4974 // guards present in the loop -- SCEV is not great at exploiting
4975 // these to compute max backedge taken counts, but can still use
4976 // these to prove lack of overflow. Use this fact to avoid
4977 // doing extra work that may not pay off.
4979 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
4980 AC.assumptions().empty())
4981 return Result;
4983 // If the backedge is guarded by a comparison with the pre-inc value the
4984 // addrec is safe. Also, if the entry is guarded by a comparison with the
4985 // start value and the backedge is guarded by a comparison with the post-inc
4986 // value, the addrec is safe.
4987 ICmpInst::Predicate Pred;
4988 const SCEV *OverflowLimit =
4989 getSignedOverflowLimitForStep(Step, &Pred, this);
4990 if (OverflowLimit &&
4991 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
4992 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
4993 Result = setFlags(Result, SCEV::FlagNSW);
4994 }
4995 return Result;
4996 }
4997 SCEV::NoWrapFlags
4998 ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
4999 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5001 if (AR->hasNoUnsignedWrap())
5002 return Result;
5004 if (!AR->isAffine())
5005 return Result;
5007 const SCEV *Step = AR->getStepRecurrence(*this);
5008 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5009 const Loop *L = AR->getLoop();
5011 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5012 // Note that this serves two purposes: It filters out loops that are
5013 // simply not analyzable, and it covers the case where this code is
5014 // being called from within backedge-taken count analysis, such that
5015 // attempting to ask for the backedge-taken count would likely result
5016 // in infinite recursion. In the latter case, the analysis code will
5017 // cope with a conservative value, and it will take care to purge
5018 // that value once it has finished.
5019 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5021 // Normally, in the cases we can prove no-overflow via a
5022 // backedge guarding condition, we can also compute a backedge
5023 // taken count for the loop. The exceptions are assumptions and
5024 // guards present in the loop -- SCEV is not great at exploiting
5025 // these to compute max backedge taken counts, but can still use
5026 // these to prove lack of overflow. Use this fact to avoid
5027 // doing extra work that may not pay off.
5029 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5030 AC.assumptions().empty())
5031 return Result;
5033 // If the backedge is guarded by a comparison with the pre-inc value the
5034 // addrec is safe. Also, if the entry is guarded by a comparison with the
5035 // start value and the backedge is guarded by a comparison with the post-inc
5036 // value, the addrec is safe.
5037 if (isKnownPositive(Step)) {
5038 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5039 getUnsignedRangeMax(Step));
5040 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5041 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5042 Result = setFlags(Result, SCEV::FlagNUW);
5043 }
5044 }
5046 return Result;
5047 }
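// A concrete instance of the bound used above (illustrative): for an i8
// addrec whose step is known to lie in [1, 5), getUnsignedRangeMax(Step) is
// 4 and N = 0 - 4 = 252, so a backedge guard of the form "AR u< 252"
// guarantees that even the largest step cannot take the value past 255,
// which is exactly the no-unsigned-wrap condition.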
5049 namespace {
5051 /// Represents an abstract binary operation. This may exist as a
5052 /// normal instruction or constant expression, or may have been
5053 /// derived from an expression tree.
5054 struct BinaryOp {
5055 unsigned Opcode;
5056 Value *LHS;
5057 Value *RHS;
5058 bool IsNSW = false;
5059 bool IsNUW = false;
5061 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5062 /// constant expression.
5063 Operator *Op = nullptr;
5065 explicit BinaryOp(Operator *Op)
5066 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5067 Op(Op) {
5068 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5069 IsNSW = OBO->hasNoSignedWrap();
5070 IsNUW = OBO->hasNoUnsignedWrap();
5071 }
5072 }
5074 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5075 bool IsNUW = false)
5076 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5077 };
5079 } // end anonymous namespace
5081 /// Try to map \p V into a BinaryOp, and return \c None on failure.
5082 static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
5083 auto *Op = dyn_cast<Operator>(V);
5084 if (!Op)
5085 return None;
5087 // Implementation detail: all the cleverness here should happen without
5088 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5089 // SCEV expressions when possible, and we should not break that.
5091 switch (Op->getOpcode()) {
5092 case Instruction::Add:
5093 case Instruction::Sub:
5094 case Instruction::Mul:
5095 case Instruction::UDiv:
5096 case Instruction::URem:
5097 case Instruction::And:
5098 case Instruction::Or:
5099 case Instruction::AShr:
5100 case Instruction::Shl:
5101 return BinaryOp(Op);
5103 case Instruction::Xor:
5104 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5105 // If the RHS of the xor is a signmask, then this is just an add.
5106 // Instcombine turns add of signmask into xor as a strength reduction step.
5107 if (RHSC->getValue().isSignMask())
5108 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5109 // Binary `xor` is a bit-wise `add`.
5110 if (V->getType()->isIntegerTy(1))
5111 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5112 return BinaryOp(Op);
5114 case Instruction::LShr:
5115 // Turn a logical shift right by a constant into an unsigned divide.
5116 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5117 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5119 // If the shift count is not less than the bitwidth, the result of
5120 // the shift is undefined. Don't try to analyze it, because the
5121 // resolution chosen here may differ from the resolution chosen in
5122 // other parts of the compiler.
5123 if (SA->getValue().ult(BitWidth)) {
5124 Constant *X =
5125 ConstantInt::get(SA->getContext(),
5126 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5127 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5128 }
5129 }
5130 return BinaryOp(Op);
5132 case Instruction::ExtractValue: {
5133 auto *EVI = cast<ExtractValueInst>(Op);
5134 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5135 break;
5137 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5138 if (!WO)
5139 break;
5141 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5142 bool Signed = WO->isSigned();
5143 // TODO: Should add nuw/nsw flags for mul as well.
5144 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5145 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5147 // Now that we know that all uses of the arithmetic-result component of
5148 // CI are guarded by the overflow check, we can go ahead and pretend
5149 // that the arithmetic is non-overflowing.
5150 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5151 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5152 }
5153 default:
5154 break;
5155 }
5158 // Recognise the intrinsic loop.decrement.reg; as it has exactly the same
5159 // semantics as a Sub, return a binary sub expression.
5160 if (auto *II = dyn_cast<IntrinsicInst>(V))
5161 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5162 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5164 return None;
5165 }
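// Two illustrative mappings performed above (hypothetical i32 values):
// "lshr %x, 3" becomes the BinaryOp (udiv %x, 8), since shifting right by 3
// divides by 2^3; and "xor %x, -2147483648" becomes (add %x, -2147483648),
// because flipping only the sign bit equals adding the sign mask modulo
// 2^32 -- no carry can propagate out of the top bit.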
5167 /// Helper function to createAddRecFromPHIWithCasts. We have a phi
5168 /// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5169 /// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5170 /// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5171 /// follows one of the following patterns:
5172 /// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5173 /// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5174 /// If the SCEV expression of \p Op conforms with one of the expected patterns
5175 /// we return the type of the truncation operation, and indicate whether the
5176 /// truncated type should be treated as signed/unsigned by setting
5177 /// \p Signed to true/false, respectively.
5178 static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5179 bool &Signed, ScalarEvolution &SE) {
5180 // The case where Op == SymbolicPHI (that is, with no type conversions on
5181 // the way) is handled by the regular add recurrence creating logic and
5182 // would have already been triggered in createAddRecForPHI. Reaching it here
5183 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5184 // because one of the other operands of the SCEVAddExpr updating this PHI is
5185 // not invariant).
5187 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5188 // this case predicates that allow us to prove that Op == SymbolicPHI will
5189 // be added.
5190 if (Op == SymbolicPHI)
5191 return nullptr;
5193 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5194 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5195 if (SourceBits != NewBits)
5196 return nullptr;
5198 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5199 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5200 if (!SExt && !ZExt)
5201 return nullptr;
5202 const SCEVTruncateExpr *Trunc =
5203 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5204 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5205 if (!Trunc)
5206 return nullptr;
5207 const SCEV *X = Trunc->getOperand();
5208 if (X != SymbolicPHI)
5209 return nullptr;
5210 Signed = SExt != nullptr;
5211 return Trunc->getType();
5212 }
5214 static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5215 if (!PN->getType()->isIntegerTy())
5216 return nullptr;
5217 const Loop *L = LI.getLoopFor(PN->getParent());
5218 if (!L || L->getHeader() != PN->getParent())
5219 return nullptr;
5220 return L;
5221 }
5223 // Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5224 // computation that updates the phi follows the following pattern:
5225 // (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5226 // which correspond to a phi->trunc->sext/zext->add->phi update chain.
5227 // If so, try to see if it can be rewritten as an AddRecExpr under some
5228 // Predicates. If successful, return them as a pair. Also cache the results
5229 // of the analysis.
5231 // Example usage scenario:
5232 // Say the Rewriter is called for the following SCEV:
5233 // 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5235 // %X = phi i64 (%Start, %BEValue)
5236 // It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5237 // and call this function with %SymbolicPHI = %X.
5239 // The analysis will find that the value coming around the backedge has
5240 // the following SCEV:
5241 // BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5242 // Upon concluding that this matches the desired pattern, the function
5243 // will return the pair {NewAddRec, SmallPredsVec} where:
5244 // NewAddRec = {%Start,+,%Step}
5245 // SmallPredsVec = {P1, P2, P3} as follows:
5246 // P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5247 // P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5248 // P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5249 // The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5250 // under the predicates {P1,P2,P3}.
5251 // This predicated rewrite will be cached in PredicatedSCEVRewrites:
5252 // PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
5255 // TODO's:
5256 // 1) Extend the Induction descriptor to also support inductions that involve
5257 // casts: When needed (namely, when we are called in the context of the
5258 // vectorizer induction analysis), a Set of cast instructions will be
5259 // populated by this method, and provided back to isInductionPHI. This is
5260 // needed to allow the vectorizer to properly record them to be ignored by
5261 // the cost model and to avoid vectorizing them (otherwise these casts,
5262 // which are redundant under the runtime overflow checks, will be
5263 // vectorized, which can be costly).
5265 // 2) Support additional induction/PHISCEV patterns: We also want to support
5266 // inductions where the sext-trunc / zext-trunc operations (partly) occur
5267 // after the induction update operation (the induction increment):
5269 // (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5270 // which correspond to a phi->add->trunc->sext/zext->phi update chain.
5272 // (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5273 // which correspond to a phi->trunc->add->sext/zext->phi update chain.
5275 // 3) Outline common code with createAddRecFromPHI to avoid duplication.
5276 Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5277 ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5278 SmallVector<const SCEVPredicate *, 3> Predicates;
5280 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5281 // return an AddRec expression under some predicate.
5283 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5284 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5285 assert(L && "Expecting an integer loop header phi");
5287 // The loop may have multiple entrances or multiple exits; we can analyze
5288 // this phi as an addrec if it has a unique entry value and a unique
5289 // backedge value.
5290 Value *BEValueV = nullptr, *StartValueV = nullptr;
5291 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5292 Value *V = PN->getIncomingValue(i);
5293 if (L->contains(PN->getIncomingBlock(i))) {
5294 if (!BEValueV) {
5295 BEValueV = V;
5296 } else if (BEValueV != V) {
5297 BEValueV = nullptr;
5298 break;
5299 }
5300 } else if (!StartValueV) {
5301 StartValueV = V;
5302 } else if (StartValueV != V) {
5303 StartValueV = nullptr;
5304 break;
5305 }
5306 }
5307 if (!BEValueV || !StartValueV)
5308 return None;
5310 const SCEV *BEValue = getSCEV(BEValueV);
5312 // If the value coming around the backedge is an add with the symbolic
5313 // value we just inserted, possibly with casts that we can ignore under
5314 // an appropriate runtime guard, then we found a simple induction variable!
5315 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5316 if (!Add)
5317 return None;
5319 // If there is a single occurrence of the symbolic value, possibly
5320 // casted, replace it with a recurrence.
5321 unsigned FoundIndex = Add->getNumOperands();
5322 Type *TruncTy = nullptr;
5323 bool Signed;
5324 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5325 if ((TruncTy =
5326 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5327 if (FoundIndex == e) {
5328 FoundIndex = i;
5329 break;
5330 }
5332 if (FoundIndex == Add->getNumOperands())
5333 return None;
5335 // Create an add with everything but the specified operand.
5336 SmallVector<const SCEV *, 8> Ops;
5337 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5338 if (i != FoundIndex)
5339 Ops.push_back(Add->getOperand(i));
5340 const SCEV *Accum = getAddExpr(Ops);
5342 // The runtime checks will not be valid if the step amount is
5343 // varying inside the loop.
5344 if (!isLoopInvariant(Accum, L))
5345 return None;
5347 // *** Part2: Create the predicates
5349 // Analysis was successful: we have a phi-with-cast pattern for which we
5350 // can return an AddRec expression under the following predicates:
5352 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5353 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5354 // P2: An Equal predicate that guarantees that
5355 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5356 // P3: An Equal predicate that guarantees that
5357 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5359 // As we next prove, the above predicates guarantee that:
5360 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5363 // More formally, we want to prove that:
5364 // Expr(i+1) = Start + (i+1) * Accum
5365 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5368 // 1) Expr(0) = Start
5369 // 2) Expr(1) = Start + Accum
5370 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5371 // 3) Induction hypothesis (step i):
5372 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5376 // = Start + (i+1)*Accum
5377 // = (Start + i*Accum) + Accum
5378 // = Expr(i) + Accum
5379 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5382 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5384 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5385 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5386 // + Accum :: from P3
5388 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5389 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5391 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5392 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5394 // By induction, the same applies to all iterations 1<=i<n:
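// A small concrete instance of the pattern (illustrative types and values):
//
//   %x = phi i64 [ 0, %entry ], [ %bevalue, %loop ]
//   %trunc = trunc i64 %x to i8
//   %ext = sext i8 %trunc to i64
//   %bevalue = add i64 %ext, %step        ; %step loop-invariant
//
// Here TruncTy = i8 and Signed = true, so the analysis returns
// {0,+,%step}<L> guarded by P1 = nssw on {trunc(0),+,trunc(%step)} plus,
// where they are not compile-time known, the equality predicates P2/P3 on
// the start and step values.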
5397 // Create a truncated addrec for which we will add a no overflow check (P1).
5398 const SCEV *StartVal = getSCEV(StartValueV);
5399 const SCEV *PHISCEV =
5400 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5401 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5403 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5404 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5405 // will be constant.
5407 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5408 // need to add P1.
5409 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5410 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5411 Signed ? SCEVWrapPredicate::IncrementNSSW
5412 : SCEVWrapPredicate::IncrementNUSW;
5413 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5414 Predicates.push_back(AddRecPred);
5417 // Create the Equal Predicates P2,P3:
5419 // It is possible that the predicates P2 and/or P3 are computable at
5420 // compile time due to StartVal and/or Accum being constants.
5421 // If either one is, then we can check that now and escape if either P2
5422 // or P3 is false.
5424 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5425 // for each of StartVal and Accum
5426 auto getExtendedExpr = [&](const SCEV *Expr,
5427 bool CreateSignExtend) -> const SCEV * {
5428 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5429 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5430 const SCEV *ExtendedExpr =
5431 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5432 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5433 return ExtendedExpr;
5434 };
5437 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy)
5438 // = getExtendedExpr(Expr)
5439 // Determine whether the predicate P: Expr == ExtendedExpr
5440 // is known to be false at compile time
5441 auto PredIsKnownFalse = [&](const SCEV *Expr,
5442 const SCEV *ExtendedExpr) -> bool {
5443 return Expr != ExtendedExpr &&
5444 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5447 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5448 if (PredIsKnownFalse(StartVal, StartExtended)) {
5449 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5450 return None;
5451 }
5453 // The Step is always Signed (because the overflow checks are either
5454 // NSSW or NUSW).
5455 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5456 if (PredIsKnownFalse(Accum, AccumExtended)) {
5457 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5458 return None;
5459 }
5461 auto AppendPredicate = [&](const SCEV *Expr,
5462 const SCEV *ExtendedExpr) -> void {
5463 if (Expr != ExtendedExpr &&
5464 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5465 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5466 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5467 Predicates.push_back(Pred);
5468 }
5469 };
5471 AppendPredicate(StartVal, StartExtended);
5472 AppendPredicate(Accum, AccumExtended);
5474 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5475 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5476 // into NewAR if it will also add the runtime overflow checks specified in
5477 // \p Predicates.
5478 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5480 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5481 std::make_pair(NewAR, Predicates);
5482 // Remember the result of the analysis for this SCEV at this location.
5483 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5484 return PredRewrite;
5485 }
5487 Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5488 ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5489 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5490 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5491 if (!L)
5492 return None;
5494 // Check to see if we already analyzed this PHI.
5495 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5496 if (I != PredicatedSCEVRewrites.end()) {
5497 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5499 // Analysis was done before and failed to create an AddRec:
5500 if (Rewrite.first == SymbolicPHI)
5501 return None;
5502 // Analysis was done before and succeeded in creating an AddRec under
5503 // the Predicates:
5504 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5505 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5506 return Rewrite;
5507 }
5509 Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5510 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5512 // Record in the cache that the analysis failed.
5513 if (!Rewrite) {
5514 SmallVector<const SCEVPredicate *, 3> Predicates;
5515 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5516 return None;
5517 }
5519 return Rewrite;
5520 }
5522 // FIXME: This utility is currently required because the Rewriter currently
5523 // does not rewrite this expression:
5524 // {0, +, (sext ix (trunc iy to ix) to iy)}
5525 // into {0, +, %step},
5526 // even when the following Equal predicate exists:
5527 // "%step == (sext ix (trunc iy to ix) to iy)".
5528 bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5529 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5530 if (AR1 == AR2)
5531 return true;
5533 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5534 if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5535 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5536 return false;
5537 return true;
5538 };
5540 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5541 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5542 return false;
5543 return true;
5544 }
5546 /// A helper function for createAddRecFromPHI to handle simple cases.
5548 /// This function tries to find an AddRec expression for the simplest (yet most
5549 /// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5550 /// If it fails, createAddRecFromPHI will use a more general, but slow,
5551 /// technique for finding the AddRec expression.
5552 const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5553 Value *BEValueV,
5554 Value *StartValueV) {
5555 const Loop *L = LI.getLoopFor(PN->getParent());
5556 assert(L && L->getHeader() == PN->getParent());
5557 assert(BEValueV && StartValueV);
5559 auto BO = MatchBinaryOp(BEValueV, DT);
5560 if (!BO)
5561 return nullptr;
5563 if (BO->Opcode != Instruction::Add)
5564 return nullptr;
5566 const SCEV *Accum = nullptr;
5567 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5568 Accum = getSCEV(BO->RHS);
5569 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5570 Accum = getSCEV(BO->LHS);
5572 if (!Accum)
5573 return nullptr;
5575 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5576 if (BO->IsNUW)
5577 Flags = setFlags(Flags, SCEV::FlagNUW);
5578 if (BO->IsNSW)
5579 Flags = setFlags(Flags, SCEV::FlagNSW);
5581 const SCEV *StartVal = getSCEV(StartValueV);
5582 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5583 insertValueToMap(PN, PHISCEV);
5585 // We can add Flags to the post-inc expression only if we
5586 // know that it is *undefined behavior* for BEValueV to
5587 // overflow.
5588 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5589 assert(isLoopInvariant(Accum, L) &&
5590 "Accum is defined outside L, but is not invariant?");
5591 if (isAddRecNeverPoison(BEInst, L))
5592 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5593 }
5595 return PHISCEV;
5596 }
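// The shape handled above, in IR terms (hypothetical example):
//
//   loop:
//     %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
//     %iv.next = add nuw nsw i64 %iv, %step   ; %step loop-invariant
//
// yields {0,+,%step}<nuw><nsw><L> directly, without going through the
// symbolic-PHI machinery of createAddRecFromPHI below.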
5598 const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5599 const Loop *L = LI.getLoopFor(PN->getParent());
5600 if (!L || L->getHeader() != PN->getParent())
5601 return nullptr;
5603 // The loop may have multiple entrances or multiple exits; we can analyze
5604 // this phi as an addrec if it has a unique entry value and a unique
5605 // backedge value.
5606 Value *BEValueV = nullptr, *StartValueV = nullptr;
5607 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5608 Value *V = PN->getIncomingValue(i);
5609 if (L->contains(PN->getIncomingBlock(i))) {
5610 if (!BEValueV) {
5611 BEValueV = V;
5612 } else if (BEValueV != V) {
5613 BEValueV = nullptr;
5614 break;
5615 }
5616 } else if (!StartValueV) {
5617 StartValueV = V;
5618 } else if (StartValueV != V) {
5619 StartValueV = nullptr;
5620 break;
5621 }
5622 }
5623 if (!BEValueV || !StartValueV)
5624 return nullptr;
5626 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5627 "PHI node already processed?");
5629 // First, try to find an AddRec expression without creating a fictitious
5630 // symbolic value for PN.
5631 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5632 return S;
5634 // Handle PHI node value symbolically.
5635 const SCEV *SymbolicName = getUnknown(PN);
5636 insertValueToMap(PN, SymbolicName);
5638 // Using this symbolic name for the PHI, analyze the value coming around
5639 // the backedge.
5640 const SCEV *BEValue = getSCEV(BEValueV);
5642 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5643 // has a special value for the first iteration of the loop.
5645 // If the value coming around the backedge is an add with the symbolic
5646 // value we just inserted, then we found a simple induction variable!
5647 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5648 // If there is a single occurrence of the symbolic value, replace it
5649 // with a recurrence.
5650 unsigned FoundIndex = Add->getNumOperands();
5651 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5652 if (Add->getOperand(i) == SymbolicName)
5653 if (FoundIndex == e) {
5654 FoundIndex = i;
5655 break;
5656 }
5658 if (FoundIndex != Add->getNumOperands()) {
5659 // Create an add with everything but the specified operand.
5660 SmallVector<const SCEV *, 8> Ops;
5661 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5662 if (i != FoundIndex)
5663 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5664 L, *this));
5665 const SCEV *Accum = getAddExpr(Ops);
5667 // This is not a valid addrec if the step amount is varying each
5668 // loop iteration, but is not itself an addrec in this loop.
5669 if (isLoopInvariant(Accum, L) ||
5670 (isa<SCEVAddRecExpr>(Accum) &&
5671 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5672 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5674 if (auto BO = MatchBinaryOp(BEValueV, DT)) {
5675 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5676 if (BO->IsNUW)
5677 Flags = setFlags(Flags, SCEV::FlagNUW);
5678 if (BO->IsNSW)
5679 Flags = setFlags(Flags, SCEV::FlagNSW);
5680 }
5681 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5682 // If the increment is an inbounds GEP, then we know the address
5683 // space cannot be wrapped around. We cannot make any guarantee
5684 // about signed or unsigned overflow because pointers are
5685 // unsigned but we may have a negative index from the base
5686 // pointer. We can guarantee that no unsigned wrap occurs if the
5687 // indices form a positive value.
5688 if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
5689 Flags = setFlags(Flags, SCEV::FlagNW);
5691 const SCEV *Ptr = getSCEV(GEP->getPointerOperand());
5692 if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr)))
5693 Flags = setFlags(Flags, SCEV::FlagNUW);
5694 }
5696 // We cannot transfer nuw and nsw flags from subtraction
5697 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5698 // for instance.
5699 }
5701 const SCEV *StartVal = getSCEV(StartValueV);
5702 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5704 // Okay, for the entire analysis of this edge we assumed the PHI
5705 // to be symbolic. We now need to go back and purge all of the
5706 // entries for the scalars that use the symbolic expression.
5707 forgetMemoizedResults(SymbolicName);
5708 insertValueToMap(PN, PHISCEV);
5710 // We can add Flags to the post-inc expression only if we
5711 // know that it is *undefined behavior* for BEValueV to
5713 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5714 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5715 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5717 return PHISCEV;
5718 }
5719 }
5720 } else {
5721 // Otherwise, this could be a loop like this:
5722 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5723 // In this case, j = {1,+,1} and BEValue is j.
5724 // Because the other in-value of i (0) fits the evolution of BEValue,
5725 // i really is an addrec evolution.
5727 // We can generalize this, saying that i is the shifted value of BEValue
5728 // by one iteration:
5729 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5730 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5731 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5732 if (Shifted != getCouldNotCompute() &&
5733 Start != getCouldNotCompute()) {
5734 const SCEV *StartVal = getSCEV(StartValueV);
5735 if (Start == StartVal) {
5736 // Okay, for the entire analysis of this edge we assumed the PHI
5737 // to be symbolic. We now need to go back and purge all of the
5738 // entries for the scalars that use the symbolic expression.
5739 forgetMemoizedResults(SymbolicName);
5740 insertValueToMap(PN, Shifted);
5741 return Shifted;
5742 }
5743 }
5744 }
5746 // Remove the temporary PHI node SCEV that has been inserted while intending
5747 // to create an AddRecExpr for this PHI node. We can not keep this temporary
5748 // as it will prevent later (possibly simpler) SCEV expressions to be added
5749 // to the ValueExprMap.
5750 eraseValueFromMap(PN);
5751 return nullptr;
5752 }
5755 // Checks if the SCEV S is available at BB. S is considered available at BB
5756 // if S can be materialized at BB without introducing a fault.
5757 static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
5758 BasicBlock *BB) {
5759 struct CheckAvailable {
5760 bool TraversalDone = false;
5761 bool Available = true;
5763 const Loop *L = nullptr; // The loop BB is in (can be nullptr)
5764 BasicBlock *BB = nullptr;
5767 CheckAvailable(const Loop *L, BasicBlock *BB, DominatorTree &DT)
5768 : L(L), BB(BB), DT(DT) {}
5770 bool setUnavailable() {
5771 TraversalDone = true;
5772 Available = false;
5773 return false;
5774 }
5776 bool follow(const SCEV *S) {
5777 switch (S->getSCEVType()) {
5778 case scConstant:
5779 case scPtrToInt:
5780 case scTruncate:
5781 case scZeroExtend:
5782 case scSignExtend:
5783 case scAddExpr:
5784 case scMulExpr:
5785 case scUMaxExpr:
5786 case scSMaxExpr:
5787 case scUMinExpr:
5788 case scSMinExpr:
5789 case scSequentialUMinExpr:
5790 // These expressions are available if their operand(s) is/are.
5791 return true;
5793 case scAddRecExpr: {
5794 // We allow add recurrences that are on the loop BB is in, or some
5795 // outer loop. This guarantees availability because the value of the
5796 // add recurrence at BB is simply the "current" value of the induction
5797 // variable. We can relax this in the future; for instance an add
5798 // recurrence on a sibling dominating loop is also available at BB.
5799 const auto *ARLoop = cast<SCEVAddRecExpr>(S)->getLoop();
5800 if (L && (ARLoop == L || ARLoop->contains(L)))
5801 return true;
5803 return setUnavailable();
5804 }
5806 case scUnknown: {
5807 // For SCEVUnknown, we check for simple dominance.
5808 const auto *SU = cast<SCEVUnknown>(S);
5809 Value *V = SU->getValue();
5811 if (isa<Argument>(V))
5812 return false;
5814 if (isa<Instruction>(V) && DT.dominates(cast<Instruction>(V), BB))
5815 return false;
5817 return setUnavailable();
5818 }
5821 case scCouldNotCompute:
5822 // We do not try to be smart about these at all.
5823 return setUnavailable();
5824 }
5825 llvm_unreachable("Unknown SCEV kind!");
5828 bool isDone() { return TraversalDone; }
5829 };
5831 CheckAvailable CA(L, BB, DT);
5832 SCEVTraversal<CheckAvailable> ST(CA);
5833 ST.visitAll(S);
5835 return CA.Available;
5836 }
5838 // Try to match a control flow sequence that branches out at BI and merges back
5839 // at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5840 // match.
5841 static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5842 Value *&C, Value *&LHS, Value *&RHS) {
5843 C = BI->getCondition();
5845 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5846 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5848 if (!LeftEdge.isSingleEdge())
5849 return false;
5851 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5853 Use &LeftUse = Merge->getOperandUse(0);
5854 Use &RightUse = Merge->getOperandUse(1);
5856 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5857 LHS = LeftUse.get();
5858 RHS = RightUse.get();
5859 return true;
5860 }
5862 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5863 LHS = RightUse.get();
5864 RHS = LeftUse.get();
5865 return true;
5866 }
5868 return false;
5869 }
5871 const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
5872 auto IsReachable =
5873 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
5874 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
5875 const Loop *L = LI.getLoopFor(PN->getParent());
5877 // We don't want to break LCSSA, even in a SCEV expression tree.
5878 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
5879 if (LI.getLoopFor(PN->getIncomingBlock(i)) != L)
5880 return nullptr;
5882 // Try to match
5884 // br %cond, label %left, label %right
5885 // left:
5886 // br label %merge
5887 // right:
5888 // br label %merge
5889 // merge:
5890 // V = phi [ %x, %left ], [ %y, %right ]
5892 // as "select %cond, %x, %y"
5894 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
5895 assert(IDom && "At least the entry block should dominate PN");
5897 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
5898 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
5900 if (BI && BI->isConditional() &&
5901 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
5902 IsAvailableOnEntry(L, DT, getSCEV(LHS), PN->getParent()) &&
5903 IsAvailableOnEntry(L, DT, getSCEV(RHS), PN->getParent()))
5904 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
5905 }
5907 return nullptr;
5908 }
5910 const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
5911 if (const SCEV *S = createAddRecFromPHI(PN))
5912 return S;
5914 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
5915 return S;
5917 if (Value *V = simplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC}))
5918 if (LI.replacementPreservesLCSSAForm(PN, V))
5919 return getSCEV(V);
5920 // If it's not a loop phi, we can't handle it yet.
5921 return getUnknown(PN);
5924 bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
5925 SCEVTypes RootKind) {
5926 struct FindClosure {
5927 const SCEV *OperandToFind;
5928 const SCEVTypes RootKind; // Must be a sequential min/max expression.
5929 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
5931 bool Found = false;
5933 bool canRecurseInto(SCEVTypes Kind) const {
5934 // We can only recurse into the SCEV expression of the same effective type
5935 // as the type of our root SCEV expression, and into zero-extensions.
5936 return RootKind == Kind || NonSequentialRootKind == Kind ||
5937 scZeroExtend == Kind;
5940 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
5941 : OperandToFind(OperandToFind), RootKind(RootKind),
5942 NonSequentialRootKind(
5943 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
5944 RootKind)) {}
5946 bool follow(const SCEV *S) {
5947 Found = S == OperandToFind;
5949 return !isDone() && canRecurseInto(S->getSCEVType());
5950 }
5952 bool isDone() const { return Found; }
5955 FindClosure FC(OperandToFind, RootKind);
5956 visitAll(Root, FC);
5957 return FC.Found;
5958 }
5960 const SCEV *ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(
5961 Instruction *I, ICmpInst *Cond, Value *TrueVal, Value *FalseVal) {
5962 // Try to match some simple smax or umax patterns.
5963 auto *ICI = Cond;
5965 Value *LHS = ICI->getOperand(0);
5966 Value *RHS = ICI->getOperand(1);
5968 switch (ICI->getPredicate()) {
5969 case ICmpInst::ICMP_SLT:
5970 case ICmpInst::ICMP_SLE:
5971 case ICmpInst::ICMP_ULT:
5972 case ICmpInst::ICMP_ULE:
5973 std::swap(LHS, RHS);
5974 [[fallthrough]];
5975 case ICmpInst::ICMP_SGT:
5976 case ICmpInst::ICMP_SGE:
5977 case ICmpInst::ICMP_UGT:
5978 case ICmpInst::ICMP_UGE:
5979 // a > b ? a+x : b+x -> max(a, b)+x
5980 // a > b ? b+x : a+x -> min(a, b)+x
5981 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
5982 bool Signed = ICI->isSigned();
5983 const SCEV *LA = getSCEV(TrueVal);
5984 const SCEV *RA = getSCEV(FalseVal);
5985 const SCEV *LS = getSCEV(LHS);
5986 const SCEV *RS = getSCEV(RHS);
5987 if (LA->getType()->isPointerTy()) {
5988 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
5989 // Need to make sure we can't produce weird expressions involving
5990 // negated pointers.
5991 if (LA == LS && RA == RS)
5992 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
5993 if (LA == RS && RA == LS)
5994 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
5996 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
5997 if (Op->getType()->isPointerTy()) {
5998 Op = getLosslessPtrToIntExpr(Op);
5999 if (isa<SCEVCouldNotCompute>(Op))
6000 return Op;
6001 }
6002 if (Signed)
6003 Op = getNoopOrSignExtend(Op, I->getType());
6004 else
6005 Op = getNoopOrZeroExtend(Op, I->getType());
6006 return Op;
6007 };
6008 LS = CoerceOperand(LS);
6009 RS = CoerceOperand(RS);
6010 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6012 const SCEV *LDiff = getMinusSCEV(LA, LS);
6013 const SCEV *RDiff = getMinusSCEV(RA, RS);
6014 if (LDiff == RDiff)
6015 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6016 LDiff);
6017 LDiff = getMinusSCEV(LA, RS);
6018 RDiff = getMinusSCEV(RA, LS);
6019 if (LDiff == RDiff)
6020 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6021 LDiff);
6022 }
6023 break;
6024 case ICmpInst::ICMP_NE:
6025 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6026 std::swap(TrueVal, FalseVal);
6027 [[fallthrough]];
6028 case ICmpInst::ICMP_EQ:
6029 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6030 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
6031 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6032 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
6033 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6034 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6035 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6036 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6037 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6038 return getAddExpr(getUMaxExpr(X, C), Y);
6039 }
6040 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6041 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6042 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6043 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6044 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6045 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6046 const SCEV *X = getSCEV(LHS);
6047 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6048 X = ZExt->getOperand();
6049 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(I->getType())) {
6050 const SCEV *FalseValExpr = getSCEV(FalseVal);
6051 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6052 return getUMinExpr(getNoopOrZeroExtend(X, I->getType()), FalseValExpr,
6053 /*Sequential=*/true);
6054 }
6055 }
6056 break;
6057 default:
6058 break;
6059 }
6061 return getUnknown(I);
6064 static Optional<const SCEV *>
6065 createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6066 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6067 assert(CondExpr->getType()->isIntegerTy(1) &&
6068 TrueExpr->getType() == FalseExpr->getType() &&
6069 TrueExpr->getType()->isIntegerTy(1) &&
6070 "Unexpected operands of a select.");
6072 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6073 // --> C + (umin_seq cond, x - C)
6075 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6076 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6077 // --> C + (umin_seq ~cond, x - C)
6079 // FIXME: while we can't legally model the case where both of the hands
6080 // are fully variable, we only require that the *difference* is constant.
6081 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6082 return None;
6083 const SCEV *C, *X;
6085 if (isa<SCEVConstant>(TrueExpr)) {
6086 CondExpr = SE->getNotSCEV(CondExpr);
6087 C = TrueExpr;
6088 X = FalseExpr;
6089 } else {
6090 C = FalseExpr;
6091 X = TrueExpr;
6092 }
6093 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6094 /*Sequential=*/true));
6095 }
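// Sanity check on concrete i1 values (illustrative): for "cond ? x : C" the
// rewrite gives C + umin_seq(cond, x - C). In i1 arithmetic both "+" and "-"
// are xor, so cond == true yields C ^ (x ^ C) == x and cond == false yields
// C ^ 0 == C, matching the select exactly, while the sequential umin keeps
// poison in x from leaking out when cond is false.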
6097 static Optional<const SCEV *> createNodeForSelectViaUMinSeq(ScalarEvolution *SE,
6098 Value *Cond,
6099 Value *TrueVal,
6100 Value *FalseVal) {
6101 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6102 return None;
6104 const auto *SECond = SE->getSCEV(Cond);
6105 const auto *SETrue = SE->getSCEV(TrueVal);
6106 const auto *SEFalse = SE->getSCEV(FalseVal);
6107 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6110 const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6111 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6112 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6113 assert(TrueVal->getType() == FalseVal->getType() &&
6114 V->getType() == TrueVal->getType() &&
6115 "Types of select hands and of the result must match.");
6117 // For now, only deal with i1-typed `select`s.
6118 if (!V->getType()->isIntegerTy(1))
6119 return getUnknown(V);
6121 if (Optional<const SCEV *> S =
6122 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6123 return *S;
6125 return getUnknown(V);
6128 const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6129 Value *TrueVal,
6130 Value *FalseVal) {
6131 // Handle "constant" branch or select. This can occur for instance when a
6132 // loop pass transforms an inner loop and moves on to process the outer loop.
6133 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6134 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6136 if (auto *I = dyn_cast<Instruction>(V)) {
6137 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6138 const SCEV *S = createNodeForSelectOrPHIInstWithICmpInstCond(
6139 I, ICI, TrueVal, FalseVal);
6140 if (!isa<SCEVUnknown>(S))
6141 return S;
6142 }
6143 }
6145 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6148 /// Expand GEP instructions into add and multiply operations. This allows them
6149 /// to be analyzed by regular SCEV code.
6150 const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6151 assert(GEP->getSourceElementType()->isSized() &&
6152 "GEP source element type must be sized");
6154 SmallVector<const SCEV *, 4> IndexExprs;
6155 for (Value *Index : GEP->indices())
6156 IndexExprs.push_back(getSCEV(Index));
6157 return getGEPExpr(GEP, IndexExprs);
6158 }
6160 uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
6161 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6162 return C->getAPInt().countTrailingZeros();
6164 if (const SCEVPtrToIntExpr *I = dyn_cast<SCEVPtrToIntExpr>(S))
6165 return GetMinTrailingZeros(I->getOperand());
6167 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
6168 return std::min(GetMinTrailingZeros(T->getOperand()),
6169 (uint32_t)getTypeSizeInBits(T->getType()));
6171 if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
6172 uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
6173 return OpRes == getTypeSizeInBits(E->getOperand()->getType())
6174 ? getTypeSizeInBits(E->getType())
6175 : OpRes;
6176 }
6178 if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
6179 uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
6180 return OpRes == getTypeSizeInBits(E->getOperand()->getType())
6181 ? getTypeSizeInBits(E->getType())
6182 : OpRes;
6183 }
6185 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
6186 // The result is the min of all operands results.
6187 uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
6188 for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
6189 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
6190 return MinOpRes;
6191 }
6193 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
6194 // The result is the sum of all operands results.
6195 uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
6196 uint32_t BitWidth = getTypeSizeInBits(M->getType());
6197 for (unsigned i = 1, e = M->getNumOperands();
6198 SumOpRes != BitWidth && i != e; ++i)
6199 SumOpRes =
6200 std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth);
6201 return SumOpRes;
6202 }
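// Numeric illustration of the rule above: with operands 12 (tz = 2) and
// 10 (tz = 1), the product 120 has tz = 3 = 2 + 1; trailing zeros add under
// multiplication, capped at the bit width.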
6204 if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
6205 // The result is the min of all operands results.
6206 uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
6207 for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
6208 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
6209 return MinOpRes;
6210 }
6212 if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
6213 // The result is the min of all operands results.
6214 uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
6215 for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
6216 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
6217 return MinOpRes;
6218 }
6220 if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
6221 // The result is the min of all operands results.
6222 uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
6223 for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
6224 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
6225 return MinOpRes;
6226 }
6228 if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
6229 // For a SCEVUnknown, ask ValueTracking.
6230 KnownBits Known = computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT);
6231 return Known.countMinTrailingZeros();
6232 }
6234 // SCEVUDivExpr
6235 return 0;
6236 }
6238 uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
6239 auto I = MinTrailingZerosCache.find(S);
6240 if (I != MinTrailingZerosCache.end())
6241 return I->second;
6243 uint32_t Result = GetMinTrailingZerosImpl(S);
6244 auto InsertPair = MinTrailingZerosCache.insert({S, Result});
6245 assert(InsertPair.second && "Should insert a new key");
6246 return InsertPair.first->second;
6247 }
6249 /// Helper method to assign a range to V from metadata present in the IR.
6250 static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6251 if (Instruction *I = dyn_cast<Instruction>(V))
6252 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6253 return getConstantRangeFromMetadata(*MD);
6255 return None;
6256 }
6258 void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6259 SCEV::NoWrapFlags Flags) {
6260 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6261 AddRec->setNoWrapFlags(Flags);
6262 UnsignedRanges.erase(AddRec);
6263 SignedRanges.erase(AddRec);
6264 }
6265 }
6267 ConstantRange ScalarEvolution::
6268 getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6269 const DataLayout &DL = getDataLayout();
6271 unsigned BitWidth = getTypeSizeInBits(U->getType());
6272 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6274 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6275 // use information about the trip count to improve our available range. Note
6276 // that the trip count independent cases are already handled by known bits.
6277 // WARNING: The definition of recurrence used here is subtly different than
6278 // the one used by AddRec (and thus most of this file). Step is allowed to
6279 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6280 // and other addrecs in the same loop (for non-affine addrecs). The code
6281 // below intentionally handles the case where step is not loop invariant.
6282 auto *P = dyn_cast<PHINode>(U->getValue());
6283 if (!P)
6284 return FullSet;
6286 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6287 // even the values that are not available in these blocks may come from them,
6288 // and this leads to false-positive recurrence test.
6289 for (auto *Pred : predecessors(P->getParent()))
6290 if (!DT.isReachableFromEntry(Pred))
6291 return FullSet;
6293 BinaryOperator *BO;
6294 Value *Start, *Step;
6295 if (!matchSimpleRecurrence(P, BO, Start, Step))
6296 return FullSet;
6298 // If we found a recurrence in reachable code, we must be in a loop. Note
6299 // that BO might be in some subloop of L, and that's completely okay.
6300 auto *L = LI.getLoopFor(P->getParent());
6301 assert(L && L->getHeader() == P->getParent());
6302 if (!L->contains(BO->getParent()))
6303 // NOTE: This bailout should be an assert instead. However, asserting
6304 // the condition here exposes a case where LoopFusion is querying SCEV
6305 // with malformed loop information during the midst of the transform.
6306 // There doesn't appear to be an obvious fix, so for the moment bailout
6307 // until the caller issue can be fixed. PR49566 tracks the bug.
6308 return FullSet;
6310 // TODO: Extend to other opcodes such as mul, and div
6311 switch (BO->getOpcode()) {
6312 default:
6313 return FullSet;
6314 case Instruction::AShr:
6315 case Instruction::LShr:
6316 case Instruction::Shl:
6317 break;
6318 };
6320 if (BO->getOperand(0) != P)
6321 // TODO: Handle the power function forms some day.
6322 return FullSet;
6324 unsigned TC = getSmallConstantMaxTripCount(L);
6325 if (!TC || TC >= BitWidth)
6326 return FullSet;
6328 auto KnownStart = computeKnownBits(Start, DL, 0, &AC, nullptr, &DT);
6329 auto KnownStep = computeKnownBits(Step, DL, 0, &AC, nullptr, &DT);
6330 assert(KnownStart.getBitWidth() == BitWidth &&
6331 KnownStep.getBitWidth() == BitWidth);
6333 // Compute total shift amount, being careful of overflow and bitwidths.
6334 auto MaxShiftAmt = KnownStep.getMaxValue();
6335 APInt TCAP(BitWidth, TC-1);
6336 bool Overflow = false;
6337 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6338 if (Overflow)
6339 return FullSet;
6341 switch (BO->getOpcode()) {
6342 default:
6343 llvm_unreachable("filtered out above");
6344 case Instruction::AShr: {
6345 // For each ashr, three cases:
6346 // shift = 0 => unchanged value
6347 // saturation => 0 or -1
6348 // other => a value closer to zero (of the same sign)
6349 // Thus, the end value is closer to zero than the start.
6350 auto KnownEnd = KnownBits::ashr(KnownStart,
6351 KnownBits::makeConstant(TotalShift));
6352 if (KnownStart.isNonNegative())
6353 // Analogous to lshr (simply not yet canonicalized)
6354 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6355 KnownStart.getMaxValue() + 1);
6356 if (KnownStart.isNegative())
6357 // End >=u Start && End <=s Start
6358 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6359 KnownEnd.getMaxValue() + 1);
6360 break;
6361 }
6362 case Instruction::LShr: {
6363 // For each lshr, three cases:
6364 // shift = 0 => unchanged value
6365 // saturation => 0
6366 // other => a smaller positive number
6367 // Thus, the low end of the unsigned range is the last value produced.
6368 auto KnownEnd = KnownBits::lshr(KnownStart,
6369 KnownBits::makeConstant(TotalShift));
6370 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6371 KnownStart.getMaxValue() + 1);
6372 }
6373 case Instruction::Shl: {
6374 // Iff no bits are shifted out, value increases on every shift.
6375 auto KnownEnd = KnownBits::shl(KnownStart,
6376 KnownBits::makeConstant(TotalShift));
6377 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6378 return ConstantRange(KnownStart.getMinValue(),
6379 KnownEnd.getMaxValue() + 1);
6380 break;
6381 }
6382 }
6384 return FullSet;
6385 }
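// Worked illustration (made-up numbers): an i8 phi that starts in [8, 16)
// and is lshr'd each iteration by a step with known maximum 1, in a loop
// whose constant max trip count is 3, shifts right by at most 2 in total,
// so the range [2, 16) is sound. Plain known-bits reasoning cannot see the
// trip-count bound that makes this tighter than the full set.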
6386 /// Determine the range for a particular SCEV. If SignHint is
6387 /// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6388 /// with a "cleaner" unsigned (resp. signed) representation.
6389 const ConstantRange &
6390 ScalarEvolution::getRangeRef(const SCEV *S,
6391 ScalarEvolution::RangeSignHint SignHint) {
6392 DenseMap<const SCEV *, ConstantRange> &Cache =
6393 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6394 : SignedRanges;
6395 ConstantRange::PreferredRangeType RangeType =
6396 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED
6397 ? ConstantRange::Unsigned : ConstantRange::Signed;
6399 // See if we've computed this range already.
6400 DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6401 if (I != Cache.end())
6402 return I->second;
6404 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6405 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6407 unsigned BitWidth = getTypeSizeInBits(S->getType());
6408 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6409 using OBO = OverflowingBinaryOperator;
6411 // If the value has known zeros, the maximum value will have those known zeros
6412 // as well.
6413 uint32_t TZ = GetMinTrailingZeros(S);
6414 if (TZ != 0) {
6415 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED)
6416 ConservativeResult =
6417 ConstantRange(APInt::getMinValue(BitWidth),
6418 APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
6419 else
6420 ConservativeResult = ConstantRange(
6421 APInt::getSignedMinValue(BitWidth),
6422 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6423 }
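// For example (illustrative): with BitWidth = 8 and TZ = 2 every value is a
// multiple of 4, so the unsigned hint tightens the conservative range to
// [0, 253), i.e. a maximum of 252, the largest 8-bit multiple of 4.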
6425 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
6426 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint);
6427 unsigned WrapType = OBO::AnyWrap;
6428 if (Add->hasNoSignedWrap())
6429 WrapType |= OBO::NoSignedWrap;
6430 if (Add->hasNoUnsignedWrap())
6431 WrapType |= OBO::NoUnsignedWrap;
6432 for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
6433 X = X.addWithNoWrap(getRangeRef(Add->getOperand(i), SignHint),
6434 WrapType, RangeType);
6435 return setRange(Add, SignHint,
6436 ConservativeResult.intersectWith(X, RangeType));
6437 }
6439 if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
6440 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint);
6441 for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
6442 X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint));
6443 return setRange(Mul, SignHint,
6444 ConservativeResult.intersectWith(X, RangeType));
6445 }
6447 if (isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) {
6448 Intrinsic::ID ID;
6449 switch (S->getSCEVType()) {
6450 case scUMaxExpr:
6451 ID = Intrinsic::umax;
6452 break;
6453 case scSMaxExpr:
6454 ID = Intrinsic::smax;
6455 break;
6456 case scUMinExpr:
6457 case scSequentialUMinExpr:
6458 ID = Intrinsic::umin;
6459 break;
6460 case scSMinExpr:
6461 ID = Intrinsic::smin;
6462 break;
6463 default:
6464 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6465 }
6467 const auto *NAry = cast<SCEVNAryExpr>(S);
6468 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint);
6469 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6470 X = X.intrinsic(ID, {X, getRangeRef(NAry->getOperand(i), SignHint)});
6471 return setRange(S, SignHint,
6472 ConservativeResult.intersectWith(X, RangeType));
6473 }
6475 if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
6476 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint);
6477 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint);
6478 return setRange(UDiv, SignHint,
6479 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6480 }
6482 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
6483 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint);
6484 return setRange(ZExt, SignHint,
6485 ConservativeResult.intersectWith(X.zeroExtend(BitWidth),
6486 RangeType));
6487 }
6489 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
6490 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint);
6491 return setRange(SExt, SignHint,
6492 ConservativeResult.intersectWith(X.signExtend(BitWidth),
6493 RangeType));
6494 }
6496 if (const SCEVPtrToIntExpr *PtrToInt = dyn_cast<SCEVPtrToIntExpr>(S)) {
6497 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint);
6498 return setRange(PtrToInt, SignHint, X);
6499 }
6501 if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
6502 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint);
6503 return setRange(Trunc, SignHint,
6504 ConservativeResult.intersectWith(X.truncate(BitWidth),
6505 RangeType));
6506 }
6508 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
6509 // If there's no unsigned wrap, the value will never be less than its
6510 // initial value.
6511 if (AddRec->hasNoUnsignedWrap()) {
6512 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6513 if (!UnsignedMinValue.isZero())
6514 ConservativeResult = ConservativeResult.intersectWith(
6515 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6516 }
6518 // If there's no signed wrap, and all the operands except initial value have
6519 // the same sign or zero, the value won't ever be:
6520 // 1: smaller than initial value if operands are non negative,
6521 // 2: bigger than initial value if operands are non positive.
6522 // For both cases, value can not cross signed min/max boundary.
6523 if (AddRec->hasNoSignedWrap()) {
6524 bool AllNonNeg = true;
6525 bool AllNonPos = true;
6526 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6527 if (!isKnownNonNegative(AddRec->getOperand(i)))
6528 AllNonNeg = false;
6529 if (!isKnownNonPositive(AddRec->getOperand(i)))
6530 AllNonPos = false;
6531 }
6532 if (AllNonNeg)
6533 ConservativeResult = ConservativeResult.intersectWith(
6534 ConstantRange::getNonEmpty(getSignedRangeMin(AddRec->getStart()),
6535 APInt::getSignedMinValue(BitWidth)),
6536 RangeType);
6537 else if (AllNonPos)
6538 ConservativeResult = ConservativeResult.intersectWith(
6539 ConstantRange::getNonEmpty(
6540 APInt::getSignedMinValue(BitWidth),
6541 getSignedRangeMax(AddRec->getStart()) + 1),
6542 RangeType);
6543 }
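// Worked example (illustrative, not from the source): {10,+,2}<nsw> has a
// non-negative step, so AllNonNeg holds and the range is intersected with
// the wrapped set [10, SINT_MIN), i.e. the value can never drop below its
// initial value 10 or cross the signed maximum.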
6545 // TODO: non-affine addrec
6546 if (AddRec->isAffine()) {
6547 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6548 if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
6549 getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
6550 auto RangeFromAffine = getRangeForAffineAR(
6551 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
6552 BitWidth);
6553 ConservativeResult =
6554 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6556 auto RangeFromFactoring = getRangeViaFactoring(
6557 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
6558 BitWidth);
6559 ConservativeResult =
6560 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6561 }
6563 // Now try symbolic BE count and more powerful methods.
6564 if (UseExpensiveRangeSharpening) {
6565 const SCEV *SymbolicMaxBECount =
6566 getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6567 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6568 getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
6569 AddRec->hasNoSelfWrap()) {
6570 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6571 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6572 ConservativeResult =
6573 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6574 }
6575 }
6576 }
6578 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6579 }
6581 if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
6583 // Check if the IR explicitly contains !range metadata.
6584 Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue());
6585 if (MDRange)
6586 ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue(),
6587 RangeType);
6589 // Use facts about recurrences in the underlying IR. Note that add
6590 // recurrences are AddRecExprs and thus don't hit this path. This
6591 // primarily handles shift recurrences.
6592 auto CR = getRangeForUnknownRecurrence(U);
6593 ConservativeResult = ConservativeResult.intersectWith(CR);
6595 // See if ValueTracking can give us a useful range.
6596 const DataLayout &DL = getDataLayout();
6597 KnownBits Known = computeKnownBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
6598 if (Known.getBitWidth() != BitWidth)
6599 Known = Known.zextOrTrunc(BitWidth);
6601 // ValueTracking may be able to compute a tighter result for the number of
6602 // sign bits than for the value of those sign bits.
6603 unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
6604 if (U->getType()->isPointerTy()) {
6605 // If the pointer size is larger than the index size type, this can cause
6606 // NS to be larger than BitWidth. So compensate for this.
6607 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6608 int ptrIdxDiff = ptrSize - BitWidth;
6609 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6610 NS -= ptrIdxDiff;
6611 }
6613 if (NS > 1) {
6614 // If we know any of the sign bits, we know all of the sign bits.
6615 if (!Known.Zero.getHiBits(NS).isZero())
6616 Known.Zero.setHighBits(NS);
6617 if (!Known.One.getHiBits(NS).isZero())
6618 Known.One.setHighBits(NS);
6619 }
6621 if (Known.getMinValue() != Known.getMaxValue() + 1)
6622 ConservativeResult = ConservativeResult.intersectWith(
6623 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6624 RangeType);
6625 if (NS > 1)
6626 ConservativeResult = ConservativeResult.intersectWith(
6627 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6628 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6629 RangeType);
6631 // The range of a Phi is a subset of the union of the ranges of its inputs.
6632 if (const PHINode *Phi = dyn_cast<PHINode>(U->getValue())) {
6633 // Make sure that we do not run over cycled Phis.
6634 if (PendingPhiRanges.insert(Phi).second) {
6635 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6636 for (auto &Op : Phi->operands()) {
6637 auto OpRange = getRangeRef(getSCEV(Op), SignHint);
6638 RangeFromOps = RangeFromOps.unionWith(OpRange);
6639 // No point to continue if we already have a full set.
6640 if (RangeFromOps.isFullSet())
6641 break;
6642 }
6643 ConservativeResult =
6644 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6645 bool Erased = PendingPhiRanges.erase(Phi);
6646 assert(Erased && "Failed to erase Phi properly?");
6647 (void)Erased;
6648 }
6649 }
6651 return setRange(U, SignHint, std::move(ConservativeResult));
6652 }
6654 return setRange(S, SignHint, std::move(ConservativeResult));
6655 }
6657 // Given a StartRange, Step and MaxBECount for an expression compute a range of
6658 // values that the expression can take. Initially, the expression has a value
6659 // from StartRange and then is changed by Step up to MaxBECount times. Signed
6660 // argument defines if we treat Step as signed or unsigned.
6661 static ConstantRange getRangeForAffineARHelper(APInt Step,
6662 const ConstantRange &StartRange,
6663 const APInt &MaxBECount,
6664 unsigned BitWidth, bool Signed) {
6665 // If either Step or MaxBECount is 0, then the expression won't change, and we
6666 // just need to return the initial range.
6667 if (Step == 0 || MaxBECount == 0)
6668 return StartRange;
6670 // If we don't know anything about the initial value (i.e. StartRange is
6671 // FullRange), then we don't know anything about the final range either.
6672 // Return FullRange.
6673 if (StartRange.isFullSet())
6674 return ConstantRange::getFull(BitWidth);
6676 // If Step is signed and negative, then we use its absolute value, but we also
6677 // note that we're moving in the opposite direction.
6678 bool Descending = Signed && Step.isNegative();
6680 if (Signed)
6681 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
6682 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
6683 // These equations hold true due to the well-defined wrap-around behavior of
6684 // APInt.
6685 Step = Step.abs();
6687 // Check if Offset is more than full span of BitWidth. If it is, the
6688 // expression is guaranteed to overflow.
6689 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
6690 return ConstantRange::getFull(BitWidth);
6692 // Offset is by how much the expression can change. Checks above guarantee no
6693 // overflow here.
6694 APInt Offset = Step * MaxBECount;
6696 // Minimum value of the final range will match the minimal value of StartRange
6697 // if the expression is increasing and will be decreased by Offset otherwise.
6698 // Maximum value of the final range will match the maximal value of StartRange
6699 // if the expression is decreasing and will be increased by Offset otherwise.
6700 APInt StartLower = StartRange.getLower();
6701 APInt StartUpper = StartRange.getUpper() - 1;
6702 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
6703 : (StartUpper + std::move(Offset));
6705 // It's possible that the new minimum/maximum value will fall into the initial
6706 // range (due to wrap around). This means that the expression can take any
6707 // value in this bitwidth, and we have to return full range.
6708 if (StartRange.contains(MovedBoundary))
6709 return ConstantRange::getFull(BitWidth);
6711 APInt NewLower =
6712 Descending ? std::move(MovedBoundary) : std::move(StartLower);
6713 APInt NewUpper =
6714 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
6715 NewUpper += 1;
6717 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
6718 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
6719 }
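// Worked example for getRangeForAffineARHelper (illustrative, not from the
// source): i8, StartRange = [10, 11), Step = 2, MaxBECount = 5, Signed.
// Offset = 2 * 5 = 10, MovedBoundary = 10 + 10 = 20, and since 20 is not in
// [10, 11) the result is [10, 21), covering every value 10, 12, ..., 20 the
// recurrence can reach.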
6721 ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
6722 const SCEV *Step,
6723 const SCEV *MaxBECount,
6724 unsigned BitWidth) {
6725 assert(!isa<SCEVCouldNotCompute>(MaxBECount) &&
6726 getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
6727 "Precondition!");
6729 MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType());
6730 APInt MaxBECountValue = getUnsignedRangeMax(MaxBECount);
6732 // First, consider step signed.
6733 ConstantRange StartSRange = getSignedRange(Start);
6734 ConstantRange StepSRange = getSignedRange(Step);
6736 // If Step can be both positive and negative, we need to find ranges for the
6737 // maximum absolute step values in both directions and union them.
6738 ConstantRange SR =
6739 getRangeForAffineARHelper(StepSRange.getSignedMin(), StartSRange,
6740 MaxBECountValue, BitWidth, /* Signed = */ true);
6741 SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
6742 StartSRange, MaxBECountValue,
6743 BitWidth, /* Signed = */ true));
6745 // Next, consider step unsigned.
6746 ConstantRange UR = getRangeForAffineARHelper(
6747 getUnsignedRangeMax(Step), getUnsignedRange(Start),
6748 MaxBECountValue, BitWidth, /* Signed = */ false);
6750 // Finally, intersect signed and unsigned ranges.
6751 return SR.intersectWith(UR, ConstantRange::Smallest);
6752 }
6754 ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
6755 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
6756 ScalarEvolution::RangeSignHint SignHint) {
6757 assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
6758 assert(AddRec->hasNoSelfWrap() &&
6759 "This only works for non-self-wrapping AddRecs!");
6760 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
6761 const SCEV *Step = AddRec->getStepRecurrence(*this);
6762 // Only deal with constant step to save compile time.
6763 if (!isa<SCEVConstant>(Step))
6764 return ConstantRange::getFull(BitWidth);
6765 // Let's make sure that we can prove that we do not self-wrap during
6766 // MaxBECount iterations. We need this because MaxBECount is a maximum
6767 // iteration count estimate, and we might infer nw from some exit for which we
6768 // do not know max exit count (or any other side reasoning).
6769 // TODO: Turn into assert at some point.
6770 if (getTypeSizeInBits(MaxBECount->getType()) >
6771 getTypeSizeInBits(AddRec->getType()))
6772 return ConstantRange::getFull(BitWidth);
6773 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
6774 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
6775 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
6776 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
6777 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
6778 MaxItersWithoutWrap))
6779 return ConstantRange::getFull(BitWidth);
6781 ICmpInst::Predicate LEPred =
6782 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
6783 ICmpInst::Predicate GEPred =
6784 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
6785 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
6787 // We know that there is no self-wrap. Let's take Start and End values and
6788 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
6789 // the iteration. They either lie inside the range [Min(Start, End),
6790 // Max(Start, End)] or outside it:
6792 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
6793 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
6795 // No self wrap flag guarantees that the intermediate values cannot be BOTH
6796 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
6797 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
6798 // Start <= End and step is positive, or Start >= End and step is negative.
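// For instance (illustrative, not from the source): for i8 {0,+,1}<nw> with
// Start = 0 and End = 100, a positive step together with Start <= End proves
// Case 1, so every intermediate value lies in [0, 101), the union of the
// start and end ranges.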
6799 const SCEV *Start = AddRec->getStart();
6800 ConstantRange StartRange = getRangeRef(Start, SignHint);
6801 ConstantRange EndRange = getRangeRef(End, SignHint);
6802 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
6803 // If they already cover full iteration space, we will know nothing useful
6804 // even if we prove what we want to prove.
6805 if (RangeBetween.isFullSet())
6806 return RangeBetween;
6807 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
6808 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
6809 : RangeBetween.isWrappedSet();
6810 if (IsWrappedSet)
6811 return ConstantRange::getFull(BitWidth);
6813 if (isKnownPositive(Step) &&
6814 isKnownPredicateViaConstantRanges(LEPred, Start, End))
6815 return RangeBetween;
6816 else if (isKnownNegative(Step) &&
6817 isKnownPredicateViaConstantRanges(GEPred, Start, End))
6818 return RangeBetween;
6819 return ConstantRange::getFull(BitWidth);
6820 }
6822 ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
6823 const SCEV *Step,
6824 const SCEV *MaxBECount,
6825 unsigned BitWidth) {
6826 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
6827 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
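// e.g. (illustrative) {C ? 0 : 10, +, C ? 1 : 2} is treated as
// C ? {0,+,1} : {10,+,2}, so its range is the union of the ranges of
// {0,+,1} and {10,+,2}, each computed by getRangeForAffineAR below.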
6829 struct SelectPattern {
6830 Value *Condition = nullptr;
6831 APInt TrueValue;
6832 APInt FalseValue;
6834 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
6835 const SCEV *S) {
6836 Optional<unsigned> CastOp;
6837 APInt Offset(BitWidth, 0);
6839 assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
6840 "Should be!");
6842 // Peel off a constant offset:
6843 if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
6844 // In the future we could consider being smarter here and handle
6845 // {Start+Step,+,Step} too.
6846 if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
6847 return;
6849 Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
6850 S = SA->getOperand(1);
6851 }
6853 // Peel off a cast operation
6854 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
6855 CastOp = SCast->getSCEVType();
6856 S = SCast->getOperand();
6857 }
6859 using namespace llvm::PatternMatch;
6861 auto *SU = dyn_cast<SCEVUnknown>(S);
6862 const APInt *TrueVal, *FalseVal;
6863 if (!SU ||
6864 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
6865 m_APInt(FalseVal)))) {
6866 Condition = nullptr;
6867 return;
6868 }
6870 TrueValue = *TrueVal;
6871 FalseValue = *FalseVal;
6873 // Re-apply the cast we peeled off earlier
6874 if (CastOp)
6875 switch (*CastOp) {
6876 default:
6877 llvm_unreachable("Unknown SCEV cast type!");
6879 case scTruncate:
6880 TrueValue = TrueValue.trunc(BitWidth);
6881 FalseValue = FalseValue.trunc(BitWidth);
6882 break;
6883 case scZeroExtend:
6884 TrueValue = TrueValue.zext(BitWidth);
6885 FalseValue = FalseValue.zext(BitWidth);
6886 break;
6887 case scSignExtend:
6888 TrueValue = TrueValue.sext(BitWidth);
6889 FalseValue = FalseValue.sext(BitWidth);
6890 break;
6891 }
6893 // Re-apply the constant offset we peeled off earlier
6894 TrueValue += Offset;
6895 FalseValue += Offset;
6896 }
6898 bool isRecognized() { return Condition != nullptr; }
6899 };
6901 SelectPattern StartPattern(*this, BitWidth, Start);
6902 if (!StartPattern.isRecognized())
6903 return ConstantRange::getFull(BitWidth);
6905 SelectPattern StepPattern(*this, BitWidth, Step);
6906 if (!StepPattern.isRecognized())
6907 return ConstantRange::getFull(BitWidth);
6909 if (StartPattern.Condition != StepPattern.Condition) {
6910 // We don't handle this case today; but we could, by considering four
6911 // possibilities below instead of two. I'm not sure if there are cases where
6912 // that will help over what getRange already does, though.
6913 return ConstantRange::getFull(BitWidth);
6916 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
6917 // construct arbitrary general SCEV expressions here. This function is called
6918 // from deep in the call stack, and calling getSCEV (on a sext instruction,
6919 // say) can end up caching a suboptimal value.
6921 // FIXME: without the explicit `this` receiver below, MSVC errors out with
6922 // C2352 and C2512 (otherwise it isn't needed).
6924 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
6925 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
6926 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
6927 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
6929 ConstantRange TrueRange =
6930 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth);
6931 ConstantRange FalseRange =
6932 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth);
6934 return TrueRange.unionWith(FalseRange);
6935 }
6937 SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
6938 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
6939 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
6941 // Return early if there are no flags to propagate to the SCEV.
6942 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
6943 if (BinOp->hasNoUnsignedWrap())
6944 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
6945 if (BinOp->hasNoSignedWrap())
6946 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
6947 if (Flags == SCEV::FlagAnyWrap)
6948 return SCEV::FlagAnyWrap;
6950 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
6951 }
6953 const Instruction *
6954 ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
6955 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
6956 return &*AddRec->getLoop()->getHeader()->begin();
6957 if (auto *U = dyn_cast<SCEVUnknown>(S))
6958 if (auto *I = dyn_cast<Instruction>(U->getValue()))
6959 return I;
6960 return nullptr;
6961 }
6963 /// Fills \p Ops with unique operands of \p S, if it has operands. If not,
6964 /// \p Ops remains unmodified.
6965 static void collectUniqueOps(const SCEV *S,
6966 SmallVectorImpl<const SCEV *> &Ops) {
6967 SmallPtrSet<const SCEV *, 4> Unique;
6968 auto InsertUnique = [&](const SCEV *S) {
6969 if (Unique.insert(S).second)
6970 Ops.push_back(S);
6971 };
6972 if (auto *S2 = dyn_cast<SCEVCastExpr>(S))
6973 for (auto *Op : S2->operands())
6974 InsertUnique(Op);
6975 else if (auto *S2 = dyn_cast<SCEVNAryExpr>(S))
6976 for (auto *Op : S2->operands())
6977 InsertUnique(Op);
6978 else if (auto *S2 = dyn_cast<SCEVUDivExpr>(S))
6979 for (auto *Op : S2->operands())
6980 InsertUnique(Op);
6981 }
6983 const Instruction *
6984 ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
6985 bool &Precise) {
6986 Precise = true;
6987 // Do a bounded search of the def relation of the requested SCEVs.
6988 SmallSet<const SCEV *, 16> Visited;
6989 SmallVector<const SCEV *> Worklist;
6990 auto pushOp = [&](const SCEV *S) {
6991 if (!Visited.insert(S).second)
6992 return;
6993 // Threshold of 30 here is arbitrary.
6994 if (Visited.size() > 30) {
6995 Precise = false;
6996 return;
6997 }
6998 Worklist.push_back(S);
6999 };
7001 for (auto *S : Ops)
7002 pushOp(S);
7004 const Instruction *Bound = nullptr;
7005 while (!Worklist.empty()) {
7006 auto *S = Worklist.pop_back_val();
7007 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7008 if (!Bound || DT.dominates(Bound, DefI))
7009 Bound = DefI;
7010 } else {
7011 SmallVector<const SCEV *, 4> Ops;
7012 collectUniqueOps(S, Ops);
7013 for (auto *Op : Ops)
7014 pushOp(Op);
7015 }
7016 }
7017 return Bound ? Bound : &*F.getEntryBlock().begin();
7018 }
7020 const Instruction *
7021 ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7022 bool Discard;
7023 return getDefiningScopeBound(Ops, Discard);
7024 }
7026 bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7027 const Instruction *B) {
7028 if (A->getParent() == B->getParent() &&
7029 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7030 B->getIterator()))
7031 return true;
7033 auto *BLoop = LI.getLoopFor(B->getParent());
7034 if (BLoop && BLoop->getHeader() == B->getParent() &&
7035 BLoop->getLoopPreheader() == A->getParent() &&
7036 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7037 A->getParent()->end()) &&
7038 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7039 B->getIterator()))
7040 return true;
7041 return false;
7042 }
7045 bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7046 // Only proceed if we can prove that I does not yield poison.
7047 if (!programUndefinedIfPoison(I))
7048 return false;
7050 // At this point we know that if I is executed, then it does not wrap
7051 // according to at least one of NSW or NUW. If I is not executed, then we do
7052 // not know if the calculation that I represents would wrap. Multiple
7053 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7054 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7055 // derived from other instructions that map to the same SCEV. We cannot make
7056 // that guarantee for cases where I is not executed. So we need to find an
7057 // upper bound on the defining scope for the SCEV, and prove that I is
7058 // executed every time we enter that scope. When the bounding scope is a
7059 // loop (the common case), this is equivalent to proving I executes on every
7060 // iteration of that loop.
7061 SmallVector<const SCEV *> SCEVOps;
7062 for (const Use &Op : I->operands()) {
7063 // I could be an extractvalue from a call to an overflow intrinsic.
7064 // TODO: We can do better here in some cases.
7065 if (isSCEVable(Op->getType()))
7066 SCEVOps.push_back(getSCEV(Op));
7067 }
7068 auto *DefI = getDefiningScopeBound(SCEVOps);
7069 return isGuaranteedToTransferExecutionTo(DefI, I);
7070 }
7072 bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7073 // If we know that \c I can never be poison period, then that's enough.
7074 if (isSCEVExprNeverPoison(I))
7075 return true;
7077 // For an add recurrence specifically, we assume that infinite loops without
7078 // side effects are undefined behavior, and then reason as follows:
7080 // If the add recurrence is poison in any iteration, it is poison on all
7081 // future iterations (since incrementing poison yields poison). If the result
7082 // of the add recurrence is fed into the loop latch condition and the loop
7083 // does not contain any throws or exiting blocks other than the latch, we now
7084 // have the ability to "choose" whether the backedge is taken or not (by
7085 // choosing a sufficiently evil value for the poison feeding into the branch)
7086 // for every iteration including and after the one in which \p I first became
7087 // poison. There are two possibilities (let's call the iteration in which \p
7088 // I first became poison as K):
7090 // 1. In the set of iterations including and after K, the loop body executes
7091 // no side effects. In this case executing the backedge an infinite number
7092 // of times will yield undefined behavior.
7094 // 2. In the set of iterations including and after K, the loop body executes
7095 // at least one side effect. In this case, that specific instance of side
7096 // effect is control dependent on poison, which also yields undefined
7097 // behavior.
7099 auto *ExitingBB = L->getExitingBlock();
7100 auto *LatchBB = L->getLoopLatch();
7101 if (!ExitingBB || !LatchBB || ExitingBB != LatchBB)
7102 return false;
7104 SmallPtrSet<const Instruction *, 16> Pushed;
7105 SmallVector<const Instruction *, 8> PoisonStack;
7107 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7108 // things that are known to be poison under that assumption go on the
7109 // PoisonStack.
7110 Pushed.insert(I);
7111 PoisonStack.push_back(I);
7113 bool LatchControlDependentOnPoison = false;
7114 while (!PoisonStack.empty() && !LatchControlDependentOnPoison) {
7115 const Instruction *Poison = PoisonStack.pop_back_val();
7117 for (auto *PoisonUser : Poison->users()) {
7118 if (propagatesPoison(cast<Operator>(PoisonUser))) {
7119 if (Pushed.insert(cast<Instruction>(PoisonUser)).second)
7120 PoisonStack.push_back(cast<Instruction>(PoisonUser));
7121 } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) {
7122 assert(BI->isConditional() && "Only possibility!");
7123 if (BI->getParent() == LatchBB) {
7124 LatchControlDependentOnPoison = true;
7125 break;
7126 }
7127 }
7128 }
7129 }
7131 return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L);
7132 }
7134 ScalarEvolution::LoopProperties
7135 ScalarEvolution::getLoopProperties(const Loop *L) {
7136 using LoopProperties = ScalarEvolution::LoopProperties;
7138 auto Itr = LoopPropertiesCache.find(L);
7139 if (Itr == LoopPropertiesCache.end()) {
7140 auto HasSideEffects = [](Instruction *I) {
7141 if (auto *SI = dyn_cast<StoreInst>(I))
7142 return !SI->isSimple();
7144 return I->mayThrow() || I->mayWriteToMemory();
7145 };
7147 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7148 /*HasNoSideEffects*/ true};
7150 for (auto *BB : L->getBlocks())
7151 for (auto &I : *BB) {
7152 if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7153 LP.HasNoAbnormalExits = false;
7154 if (HasSideEffects(&I))
7155 LP.HasNoSideEffects = false;
7156 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7157 break; // We're already as pessimistic as we can get.
7158 }
7160 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7161 assert(InsertPair.second && "We just checked!");
7162 Itr = InsertPair.first;
7163 }
7165 return Itr->second;
7166 }
7168 bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7169 // A mustprogress loop without side effects must be finite.
7170 // TODO: The check used here is very conservative. It's only *specific*
7171 // side effects which are well defined in infinite loops.
7172 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7173 }
7175 const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7176 // Worklist item with a Value and a bool indicating whether all operands have
7177 // been visited already.
7178 using PointerTy = PointerIntPair<Value *, 1, bool>;
7179 SmallVector<PointerTy> Stack;
7181 Stack.emplace_back(V, true);
7182 Stack.emplace_back(V, false);
7183 while (!Stack.empty()) {
7184 auto E = Stack.pop_back_val();
7185 Value *CurV = E.getPointer();
7187 if (getExistingSCEV(CurV))
7188 continue;
7190 SmallVector<Value *> Ops;
7191 const SCEV *CreatedSCEV = nullptr;
7192 // If all operands have been visited already, create the SCEV.
7193 if (E.getInt()) {
7194 CreatedSCEV = createSCEV(CurV);
7195 } else {
7196 // Otherwise get the operands we need to create SCEV's for before creating
7197 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7198 // return it.
7199 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7200 }
7202 if (CreatedSCEV) {
7203 insertValueToMap(CurV, CreatedSCEV);
7204 } else {
7205 // Queue CurV for SCEV creation, followed by its operands which need to
7206 // be constructed first.
7207 Stack.emplace_back(CurV, true);
7208 for (Value *Op : Ops)
7209 Stack.emplace_back(Op, false);
7210 }
7211 }
7213 return getExistingSCEV(V);
7214 }
7216 const SCEV *
7217 ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7218 if (!isSCEVable(V->getType()))
7219 return getUnknown(V);
7221 if (Instruction *I = dyn_cast<Instruction>(V)) {
7222 // Don't attempt to analyze instructions in blocks that aren't
7223 // reachable. Such instructions don't matter, and they aren't required
7224 // to obey basic rules for definitions dominating uses which this
7225 // analysis depends on.
7226 if (!DT.isReachableFromEntry(I->getParent()))
7227 return getUnknown(PoisonValue::get(V->getType()));
7228 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7229 return getConstant(CI);
7230 else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) {
7231 if (!GA->isInterposable()) {
7232 Ops.push_back(GA->getAliasee());
7233 return nullptr;
7234 }
7235 return getUnknown(V);
7236 } else if (!isa<ConstantExpr>(V))
7237 return getUnknown(V);
7239 Operator *U = cast<Operator>(V);
7240 if (auto BO = MatchBinaryOp(U, DT)) {
7241 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7242 switch (U->getOpcode()) {
7243 case Instruction::Add: {
7244 // For additions and multiplications, traverse add/mul chains for which we
7245 // can potentially create a single SCEV, to reduce the number of
7246 // get{Add,Mul}Expr calls.
7248 do {
7249 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7250 Ops.push_back(BO->Op);
7251 break;
7252 }
7254 Ops.push_back(BO->RHS);
7255 auto NewBO = MatchBinaryOp(BO->LHS, DT);
7256 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7257 NewBO->Opcode != Instruction::Sub)) {
7258 Ops.push_back(BO->LHS);
7259 break;
7260 }
7261 BO = NewBO;
7262 } while (true);
7263 return nullptr;
7264 }
7266 case Instruction::Mul: {
7268 do {
7269 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7270 Ops.push_back(BO->Op);
7271 break;
7272 }
7274 Ops.push_back(BO->RHS);
7275 auto NewBO = MatchBinaryOp(BO->LHS, DT);
7276 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7277 Ops.push_back(BO->LHS);
7278 break;
7279 }
7280 BO = NewBO;
7281 } while (true);
7282 return nullptr;
7283 }
7285 case Instruction::AShr:
7286 case Instruction::Shl:
7287 case Instruction::Xor:
7288 if (!IsConstArg)
7289 return getUnknown(V);
7290 break;
7291 case Instruction::And:
7292 case Instruction::Or:
7293 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7294 return getUnknown(V);
7295 break;
7296 default:
7297 break;
7298 }
7300 Ops.push_back(BO->LHS);
7301 Ops.push_back(BO->RHS);
7302 return nullptr;
7303 }
7305 switch (U->getOpcode()) {
7306 case Instruction::Trunc:
7307 case Instruction::ZExt:
7308 case Instruction::SExt:
7309 case Instruction::PtrToInt:
7310 Ops.push_back(U->getOperand(0));
7311 return nullptr;
7313 case Instruction::BitCast:
7314 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7315 Ops.push_back(U->getOperand(0));
7316 return nullptr;
7317 }
7318 return getUnknown(V);
7320 case Instruction::SDiv:
7321 case Instruction::SRem:
7322 Ops.push_back(U->getOperand(0));
7323 Ops.push_back(U->getOperand(1));
7324 return nullptr;
7326 case Instruction::GetElementPtr:
7327 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7328 "GEP source element type must be sized");
7329 for (Value *Index : U->operands())
7330 Ops.push_back(Index);
7331 return nullptr;
7333 case Instruction::IntToPtr:
7334 return getUnknown(V);
7336 case Instruction::PHI:
7337 // Keep constructing SCEVs for phis recursively for now.
7338 return nullptr;
7340 case Instruction::Select:
7341 for (Value *Inc : U->operands())
7342 Ops.push_back(Inc);
7343 return nullptr;
7346 case Instruction::Call:
7347 case Instruction::Invoke:
7348 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7349 Ops.push_back(RV);
7350 return nullptr;
7351 }
7353 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7354 switch (II->getIntrinsicID()) {
7355 case Intrinsic::abs:
7356 Ops.push_back(II->getArgOperand(0));
7357 return nullptr;
7358 case Intrinsic::umax:
7359 case Intrinsic::umin:
7360 case Intrinsic::smax:
7361 case Intrinsic::smin:
7362 case Intrinsic::usub_sat:
7363 case Intrinsic::uadd_sat:
7364 Ops.push_back(II->getArgOperand(0));
7365 Ops.push_back(II->getArgOperand(1));
7366 return nullptr;
7367 case Intrinsic::start_loop_iterations:
7368 Ops.push_back(II->getArgOperand(0));
7369 return nullptr;
7370 default:
7371 break;
7372 }
7373 }
7374 break;
7375 }
7377 return nullptr;
7378 }
7380 const SCEV *ScalarEvolution::createSCEV(Value *V) {
7381 if (!isSCEVable(V->getType()))
7382 return getUnknown(V);
7384 if (Instruction *I = dyn_cast<Instruction>(V)) {
7385 // Don't attempt to analyze instructions in blocks that aren't
7386 // reachable. Such instructions don't matter, and they aren't required
7387 // to obey basic rules for definitions dominating uses which this
7388 // analysis depends on.
7389 if (!DT.isReachableFromEntry(I->getParent()))
7390 return getUnknown(PoisonValue::get(V->getType()));
7391 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7392 return getConstant(CI);
7393 else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
7394 return GA->isInterposable() ? getUnknown(V) : getSCEV(GA->getAliasee());
7395 else if (!isa<ConstantExpr>(V))
7396 return getUnknown(V);
7398 const SCEV *LHS;
7399 const SCEV *RHS;
7401 Operator *U = cast<Operator>(V);
7402 if (auto BO = MatchBinaryOp(U, DT)) {
7403 switch (BO->Opcode) {
7404 case Instruction::Add: {
7405 // The simple thing to do would be to just call getSCEV on both operands
7406 // and call getAddExpr with the result. However if we're looking at a
7407 // bunch of things all added together, this can be quite inefficient,
7408 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7409 // Instead, gather up all the operands and make a single getAddExpr call.
7410 // LLVM IR canonical form means we need only traverse the left operands.
7411 SmallVector<const SCEV *, 4> AddOps;
7412 do {
7413 if (BO->Op) {
7414 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7415 AddOps.push_back(OpSCEV);
7416 break;
7417 }
7419 // If a NUW or NSW flag can be applied to the SCEV for this
7420 // addition, then compute the SCEV for this addition by itself
7421 // with a separate call to getAddExpr. We need to do that
7422 // instead of pushing the operands of the addition onto AddOps,
7423 // since the flags are only known to apply to this particular
7424 // addition - they may not apply to other additions that can be
7425 // formed with operands from AddOps.
7426 const SCEV *RHS = getSCEV(BO->RHS);
7427 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7428 if (Flags != SCEV::FlagAnyWrap) {
7429 const SCEV *LHS = getSCEV(BO->LHS);
7430 if (BO->Opcode == Instruction::Sub)
7431 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7432 else
7433 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7434 break;
7435 }
7436 }
7438 if (BO->Opcode == Instruction::Sub)
7439 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7440 else
7441 AddOps.push_back(getSCEV(BO->RHS));
7443 auto NewBO = MatchBinaryOp(BO->LHS, DT);
7444 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7445 NewBO->Opcode != Instruction::Sub)) {
7446 AddOps.push_back(getSCEV(BO->LHS));
7447 break;
7448 }
7449 BO = NewBO;
7450 } while (true);
7452 return getAddExpr(AddOps);
7453 }
7455 case Instruction::Mul: {
7456 SmallVector<const SCEV *, 4> MulOps;
7457 do {
7458 if (BO->Op) {
7459 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7460 MulOps.push_back(OpSCEV);
7461 break;
7462 }
7464 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7465 if (Flags != SCEV::FlagAnyWrap) {
7466 LHS = getSCEV(BO->LHS);
7467 RHS = getSCEV(BO->RHS);
7468 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7469 break;
7470 }
7471 }
7473 MulOps.push_back(getSCEV(BO->RHS));
7474 auto NewBO = MatchBinaryOp(BO->LHS, DT);
7475 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7476 MulOps.push_back(getSCEV(BO->LHS));
7477 break;
7478 }
7479 BO = NewBO;
7480 } while (true);
7482 return getMulExpr(MulOps);
7483 }
7484 case Instruction::UDiv:
7485 LHS = getSCEV(BO->LHS);
7486 RHS = getSCEV(BO->RHS);
7487 return getUDivExpr(LHS, RHS);
7488 case Instruction::URem:
7489 LHS = getSCEV(BO->LHS);
7490 RHS = getSCEV(BO->RHS);
7491 return getURemExpr(LHS, RHS);
7492 case Instruction::Sub: {
7493 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7494 if (BO->Op)
7495 Flags = getNoWrapFlagsFromUB(BO->Op);
7496 LHS = getSCEV(BO->LHS);
7497 RHS = getSCEV(BO->RHS);
7498 return getMinusSCEV(LHS, RHS, Flags);
7499 }
7500 case Instruction::And:
7501 // For an expression like x&255 that merely masks off the high bits,
7502 // use zext(trunc(x)) as the SCEV expression.
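// Illustrative instance (ours, not from the source): for i32 `x & 0xF0`,
// LZ = 24 and TZ = 4, so the SCEV built below is
// (zext i4 (trunc (x /u 16) to i4) to i32) * 16, which reproduces the
// masked value.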
7503 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7504 if (CI->isZero())
7505 return getSCEV(BO->RHS);
7506 if (CI->isMinusOne())
7507 return getSCEV(BO->LHS);
7508 const APInt &A = CI->getValue();
7510 // Instcombine's ShrinkDemandedConstant may strip bits out of
7511 // constants, obscuring what would otherwise be a low-bits mask.
7512 // Use computeKnownBits to compute what ShrinkDemandedConstant
7513 // knew about to reconstruct a low-bits mask value.
7514 unsigned LZ = A.countLeadingZeros();
7515 unsigned TZ = A.countTrailingZeros();
7516 unsigned BitWidth = A.getBitWidth();
7517 KnownBits Known(BitWidth);
7518 computeKnownBits(BO->LHS, Known, getDataLayout(),
7519 0, &AC, nullptr, &DT);
7521 APInt EffectiveMask =
7522 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7523 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7524 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7525 const SCEV *LHS = getSCEV(BO->LHS);
7526 const SCEV *ShiftedLHS = nullptr;
7527 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7528 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7529 // For an expression like (x * 8) & 8, simplify the multiply.
7530 unsigned MulZeros = OpC->getAPInt().countTrailingZeros();
7531 unsigned GCD = std::min(MulZeros, TZ);
7532 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7533 SmallVector<const SCEV*, 4> MulOps;
7534 MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD)));
7535 MulOps.append(LHSMul->op_begin() + 1, LHSMul->op_end());
7536 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7537 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7538 }
7539 }
7540 if (!ShiftedLHS)
7541 ShiftedLHS = getUDivExpr(LHS, MulCount);
7543 return getMulExpr(getZeroExtendExpr(
7544 getTruncateExpr(ShiftedLHS,
7545 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7546 BO->LHS->getType()),
7547 MulCount);
7548 }
7549 }
7550 // Binary `and` is a bit-wise `umin`.
7551 if (BO->LHS->getType()->isIntegerTy(1)) {
7552 LHS = getSCEV(BO->LHS);
7553 RHS = getSCEV(BO->RHS);
7554 return getUMinExpr(LHS, RHS);
7555 }
7556 break;
7558 case Instruction::Or:
7559 // If the RHS of the Or is a constant, we may have something like:
7560 // X*4+1 which got turned into X*4|1. Handle this as an Add so loop
7561 // optimizations will transparently handle this case.
7563 // In order for this transformation to be safe, the LHS must be of the
7564 // form X*(2^n) and the Or constant must be less than 2^n.
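// e.g. (illustrative) for %x * 4 | 1: the product has two known trailing
// zero bits and the constant 1 < 4, so the `or` cannot carry and the
// expression is modeled as %x * 4 + 1 with both no-wrap flags.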
7565 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7566 const SCEV *LHS = getSCEV(BO->LHS);
7567 const APInt &CIVal = CI->getValue();
7568 if (GetMinTrailingZeros(LHS) >=
7569 (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
7570 // Build a plain add SCEV.
7571 return getAddExpr(LHS, getSCEV(CI),
7572 (SCEV::NoWrapFlags)(SCEV::FlagNUW | SCEV::FlagNSW));
7573 }
7574 }
7575 // Binary `or` is a bit-wise `umax`.
7576 if (BO->LHS->getType()->isIntegerTy(1)) {
7577 LHS = getSCEV(BO->LHS);
7578 RHS = getSCEV(BO->RHS);
7579 return getUMaxExpr(LHS, RHS);
7580 }
7581 break;
7583 case Instruction::Xor:
7584 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7585 // If the RHS of xor is -1, then this is a not operation.
7586 if (CI->isMinusOne())
7587 return getNotSCEV(getSCEV(BO->LHS));
7589 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7590 // This is a variant of the check for xor with -1, and it handles
7591 // the case where instcombine has trimmed non-demanded bits out
7592 // of an xor with -1.
7593 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7594 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7595 if (LBO->getOpcode() == Instruction::And &&
7596 LCI->getValue() == CI->getValue())
7597 if (const SCEVZeroExtendExpr *Z =
7598 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7599 Type *UTy = BO->LHS->getType();
7600 const SCEV *Z0 = Z->getOperand();
7601 Type *Z0Ty = Z0->getType();
7602 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7604 // If C is a low-bits mask, the zero extend is serving to
7605 // mask off the high bits. Complement the operand and
7606 // re-apply the zext.
7607 if (CI->getValue().isMask(Z0TySize))
7608 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7610 // If C is a single bit, it may be in the sign-bit position
7611 // before the zero-extend. In this case, represent the xor
7612 // using an add, which is equivalent, and re-apply the zext.
7613 APInt Trunc = CI->getValue().trunc(Z0TySize);
7614 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7615 Trunc.isSignMask())
7616 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7617 UTy);
7618 }
7619 }
7620 break;
7622 case Instruction::Shl:
7623 // Turn shift left of a constant amount into a multiply.
7624 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7625 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7627 // If the shift count is not less than the bitwidth, the result of
7628 // the shift is undefined. Don't try to analyze it, because the
7629 // resolution chosen here may differ from the resolution chosen in
7630 // other parts of the compiler.
7631 if (SA->getValue().uge(BitWidth))
7632 break;
7634 // We can safely preserve the nuw flag in all cases. It's also safe to
7635 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7636 // requires special handling. It can be preserved as long as we're not
7637 // left shifting by bitwidth - 1.
7638 auto Flags = SCEV::FlagAnyWrap;
7639 if (BO->Op) {
7640 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7641 if ((MulFlags & SCEV::FlagNSW) &&
7642 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7643 Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNSW);
7644 if (MulFlags & SCEV::FlagNUW)
7645 Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNUW);
7646 }
7648 ConstantInt *X = ConstantInt::get(
7649 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
7650 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
7651 }
7652 break;
7654 case Instruction::AShr: {
7655 // AShr X, C, where C is a constant.
7656 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
7657 if (!CI)
7658 break;
7660 Type *OuterTy = BO->LHS->getType();
7661 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
7662 // If the shift count is not less than the bitwidth, the result of
7663 // the shift is undefined. Don't try to analyze it, because the
7664 // resolution chosen here may differ from the resolution chosen in
7665 // other parts of the compiler.
7666 if (CI->getValue().uge(BitWidth))
7667 break;
7669 if (CI->isZero())
7670 return getSCEV(BO->LHS); // shift by zero --> noop
7672 uint64_t AShrAmt = CI->getZExtValue();
7673 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
7675 Operator *L = dyn_cast<Operator>(BO->LHS);
7676 if (L && L->getOpcode() == Instruction::Shl) {
7677 // X = Shl A, n
7678 // Y = AShr X, m
7679 // Both n and m are constant.
7681 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
7682 if (L->getOperand(1) == BO->RHS)
7683 // For a two-shift sext-inreg, i.e. n = m,
7684 // use sext(trunc(x)) as the SCEV expression.
7685 return getSignExtendExpr(
7686 getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy);
7688 ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
7689 if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) {
7690 uint64_t ShlAmt = ShlAmtCI->getZExtValue();
7691 if (ShlAmt > AShrAmt) {
7692 // When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
7693 // expression. We already checked that ShlAmt < BitWidth, so
7694 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
7695 // ShlAmt - AShrAmt < Amt.
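// Illustrative check of the formula (ours, not from the source): for i8,
// n = 3 and m = 1, ashr(shl(x, 3), 1) == sext(trunc(x to i7) * 4 to i8);
// e.g. x = 16 gives shl -> -128 and ashr -> -64, while 16 * 4 = 64 is -64
// when interpreted in i7 and sign-extends back to -64.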
7696 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
7697 ShlAmt - AShrAmt);
7698 return getSignExtendExpr(
7699 getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy),
7700 getConstant(Mul)), OuterTy);
7701 }
7702 }
7703 }
7704 break;
7705 }
7706 }
7707 }
7709 switch (U->getOpcode()) {
7710 case Instruction::Trunc:
7711 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
7713 case Instruction::ZExt:
7714 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
7716 case Instruction::SExt:
7717 if (auto BO = MatchBinaryOp(U->getOperand(0), DT)) {
7718 // The NSW flag of a subtract does not always survive the conversion to
7719 // A + (-1)*B. By pushing sign extension onto its operands we are much
7720 // more likely to preserve NSW and allow later AddRec optimisations.
7722 // NOTE: This is effectively duplicating this logic from getSignExtend:
7723 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
7724 // but by that point the NSW information has potentially been lost.
7725 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
7726 Type *Ty = U->getType();
7727 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
7728 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
7729 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
7730 }
7731 }
7732 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
7734 case Instruction::BitCast:
7735 // BitCasts are no-op casts so we just eliminate the cast.
7736 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
7737 return getSCEV(U->getOperand(0));
7738 break;
7740 case Instruction::PtrToInt: {
7741 // Pointer to integer cast is straight-forward, so do model it.
7742 const SCEV *Op = getSCEV(U->getOperand(0));
7743 Type *DstIntTy = U->getType();
7744 // But only if effective SCEV (integer) type is wide enough to represent
7745 // all possible pointer values.
7746 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
7747 if (isa<SCEVCouldNotCompute>(IntOp))
7748 return getUnknown(V);
7749 return IntOp;
7750 }
7751 case Instruction::IntToPtr:
7752 // Just don't deal with inttoptr casts.
7753 return getUnknown(V);
7755 case Instruction::SDiv:
7756 // If both operands are non-negative, this is just an udiv.
7757 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
7758 isKnownNonNegative(getSCEV(U->getOperand(1))))
7759 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
7760 break;
7762 case Instruction::SRem:
7763 // If both operands are non-negative, this is just an urem.
7764 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
7765 isKnownNonNegative(getSCEV(U->getOperand(1))))
7766 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
7767 break;
7769 case Instruction::GetElementPtr:
7770 return createNodeForGEP(cast<GEPOperator>(U));
7772 case Instruction::PHI:
7773 return createNodeForPHI(cast<PHINode>(U));
7775 case Instruction::Select:
7776 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
7777 U->getOperand(2));
7779 case Instruction::Call:
7780 case Instruction::Invoke:
7781 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
7782 return getSCEV(RV);
7784 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7785 switch (II->getIntrinsicID()) {
7786 case Intrinsic::abs:
7787 return getAbsExpr(
7788 getSCEV(II->getArgOperand(0)),
7789 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
7790 case Intrinsic::umax:
7791 LHS = getSCEV(II->getArgOperand(0));
7792 RHS = getSCEV(II->getArgOperand(1));
7793 return getUMaxExpr(LHS, RHS);
7794 case Intrinsic::umin:
7795 LHS = getSCEV(II->getArgOperand(0));
7796 RHS = getSCEV(II->getArgOperand(1));
7797 return getUMinExpr(LHS, RHS);
7798 case Intrinsic::smax:
7799 LHS = getSCEV(II->getArgOperand(0));
7800 RHS = getSCEV(II->getArgOperand(1));
7801 return getSMaxExpr(LHS, RHS);
7802 case Intrinsic::smin:
7803 LHS = getSCEV(II->getArgOperand(0));
7804 RHS = getSCEV(II->getArgOperand(1));
7805 return getSMinExpr(LHS, RHS);
7806 case Intrinsic::usub_sat: {
7807 const SCEV *X = getSCEV(II->getArgOperand(0));
7808 const SCEV *Y = getSCEV(II->getArgOperand(1));
7809 const SCEV *ClampedY = getUMinExpr(X, Y);
7810 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
7812 case Intrinsic::uadd_sat: {
7813 const SCEV *X = getSCEV(II->getArgOperand(0));
7814 const SCEV *Y = getSCEV(II->getArgOperand(1));
7815 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
7816 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
7818 case Intrinsic::start_loop_iterations:
7819 // A start_loop_iterations is just equivalent to the first operand for
7820 // SCEV analysis.
7821 return getSCEV(II->getArgOperand(0));
7822 default:
7823 break;
7824 }
7825 }
7826 break;
7827 }
7829 return getUnknown(V);
7830 }
7832 //===----------------------------------------------------------------------===//
7833 //                   Iteration Count Computation Code
7834 //
7836 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
7837 bool Extend) {
7838 if (isa<SCEVCouldNotCompute>(ExitCount))
7839 return getCouldNotCompute();
7841 auto *ExitCountType = ExitCount->getType();
7842 assert(ExitCountType->isIntegerTy());
7844 if (!Extend)
7845 return getAddExpr(ExitCount, getOne(ExitCountType));
7847 auto *WiderType = Type::getIntNTy(ExitCountType->getContext(),
7848 1 + ExitCountType->getScalarSizeInBits());
7849 return getAddExpr(getNoopOrZeroExtend(ExitCount, WiderType),
7850 getOne(WiderType));
7851 }
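// Example (illustrative, not from the source): for an i8 exit count of 255,
// the trip count 256 does not fit in i8, so with extension enabled the
// count is first zero-extended to i9 and the result is 256 rather than a
// wrapped 0.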
7853 static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
7854 if (!ExitCount)
7855 return 0;
7857 ConstantInt *ExitConst = ExitCount->getValue();
7859 // Guard against huge trip counts.
7860 if (ExitConst->getValue().getActiveBits() > 32)
7861 return 0;
7863 // In case of integer overflow, this returns 0, which is correct.
7864 return ((unsigned)ExitConst->getZExtValue()) + 1;
7865 }
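// e.g. (illustrative) a backedge-taken count of 7 gives a trip count of 8,
// while a count of 0xFFFFFFFF overflows (unsigned)0xFFFFFFFF + 1 to 0,
// the conventional "unknown" answer.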
7867 unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
7868 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
7869 return getConstantTripCount(ExitCount);
7870 }
7872 unsigned
7873 ScalarEvolution::getSmallConstantTripCount(const Loop *L,
7874 const BasicBlock *ExitingBlock) {
7875 assert(ExitingBlock && "Must pass a non-null exiting block!");
7876 assert(L->isLoopExiting(ExitingBlock) &&
7877 "Exiting block must actually branch out of the loop!");
7878 const SCEVConstant *ExitCount =
7879 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
7880 return getConstantTripCount(ExitCount);
7881 }
7883 unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
7884 const auto *MaxExitCount =
7885 dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
7886 return getConstantTripCount(MaxExitCount);
7887 }
7889 const SCEV *ScalarEvolution::getConstantMaxTripCountFromArray(const Loop *L) {
7890 // We cannot infer a trip count from array accesses in an irregular loop.
7891 // FIXME: It is hard to infer a loop bound from arrays accessed in nested loops.
7892 if (!L->isLoopSimplifyForm() || !L->isInnermost())
7893 return getCouldNotCompute();
7895 // FIXME: To keep the analysis simple, we only analyze loops that have one
7896 // exiting block, and that block must be the latch. This makes it easier
7897 // to capture loops whose memory accesses are known to be executed
7898 // in each iteration.
7899 const BasicBlock *LoopLatch = L->getLoopLatch();
7900 assert(LoopLatch && "See definition of loop-simplify form.");
7901 if (L->getExitingBlock() != LoopLatch)
7902 return getCouldNotCompute();
7904 const DataLayout &DL = getDataLayout();
7905 SmallVector<const SCEV *> InferCountColl;
7906 for (auto *BB : L->getBlocks()) {
7907 // At this point we know the loop is in simplified form and has a single
7908 // exiting block. To infer a bound from a memory operation, it must be
7909 // executed on every iteration, i.e. the maximum execution count of
7910 // MemAccessBB must match the maximum execution count of the latch.
7911 // If MemAccessBB does not dominate the latch, skip it.
7912 //            Entry
7913 //              │
7914 //        ┌─────▼─────┐
7915 //        │Loop Header◄─────┐
7916 //        └──┬─────┬──┘     │
7917 //           │     │        │
7918 //  ┌────────▼──┐ ┌─▼─────┐ │
7919 //  │MemAccessBB│ │OtherBB│ │
7920 //  └────────┬──┘ └─┬─────┘ │
7921 //           │      │       │
7922 //        ┌──▼──────▼─┐     │
7923 //        │Loop Latch ├─────┘
7924 //        └─────┬─────┘
7925 //              ▼
7926 //             Exit
7927 if (!DT.dominates(BB, LoopLatch))
7930 for (Instruction &Inst : *BB) {
7931 // Find Memory Operation Instruction.
7932 auto *GEP = getLoadStorePointerOperand(&Inst);
7933 if (!GEP)
7934 continue;
7936 auto *ElemSize = dyn_cast<SCEVConstant>(getElementSize(&Inst));
7937 // Do not infer from scalar types, e.g. "ElemSize = sizeof()".
7938 if (!ElemSize)
7939 continue;
7941 // Use an existing polynomial recurrence on the trip count.
7942 auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(GEP));
7943 if (!AddRec)
7944 continue;
7945 auto *ArrBase = dyn_cast<SCEVUnknown>(getPointerBase(AddRec));
7946 auto *Step = dyn_cast<SCEVConstant>(AddRec->getStepRecurrence(*this));
7947 if (!ArrBase || !Step)
7948 continue;
7949 assert(isLoopInvariant(ArrBase, L) && "See addrec definition");
7951 // Only handle { %array + step },
7952 // FIXME: {(SCEVAddRecExpr) + step } could not be analysed here.
7953 if (AddRec->getStart() != ArrBase)
7954 continue;
7956 // Skip memory access patterns that have gaps, that repeat a memory
7957 // location, or whose GEP index wraps around: the step must be non-zero,
7958 // non-negative, and equal to the element size.
7959 if (Step->getAPInt().getActiveBits() > 32 ||
7960 Step->getAPInt().getZExtValue() !=
7961 ElemSize->getAPInt().getZExtValue() ||
7962 Step->isZero() || Step->getAPInt().isNegative())
7963 continue;
7965 // Only infer from stack arrays that have a known, fixed size.
7966 // Make sure the alloca instruction is not executed inside the loop.
7967 AllocaInst *AllocateInst = dyn_cast<AllocaInst>(ArrBase->getValue());
7968 if (!AllocateInst || L->contains(AllocateInst->getParent()))
7969 continue;
7971 // Make sure we only handle a normal, fixed-size array.
7972 auto *Ty = dyn_cast<ArrayType>(AllocateInst->getAllocatedType());
7973 auto *ArrSize = dyn_cast<ConstantInt>(AllocateInst->getArraySize());
7974 if (!Ty || !ArrSize || !ArrSize->isOne())
7975 continue;
7977 // FIXME: Since gep indices are silently zext to the indexing type,
7978 // a narrow gep index could wrap around rather than increase strictly;
7979 // we should ensure that the step increases strictly with
7980 // each loop iteration.
7981 // Now we can infer a max execution time by MemLength/StepLength.
7982 const SCEV *MemSize =
7983 getConstant(Step->getType(), DL.getTypeAllocSize(Ty));
7984 auto *MaxExeCount =
7985 dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step));
7986 if (!MaxExeCount || MaxExeCount->getAPInt().getActiveBits() > 32)
7987 continue;
7989 // If the loop reaches the maximum number of executions, we can not
7990 // access bytes starting outside the statically allocated size without
7991 // immediate UB. But it is allowed to enter the loop header one more
7992 // time.
7993 auto *InferCount = dyn_cast<SCEVConstant>(
7994 getAddExpr(MaxExeCount, getOne(MaxExeCount->getType())));
7995 // Discard the inferred count if it does not fit into 32 bits.
7996 if (!InferCount || InferCount->getAPInt().getActiveBits() > 32)
7999 InferCountColl.push_back(InferCount);
8000 }
8001 }
8003 if (InferCountColl.size() == 0)
8004 return getCouldNotCompute();
8006 return getUMinFromMismatchedTypes(InferCountColl);
8007 }
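// Worked example (illustrative, not from the source): for `char A[8]`
// written as A[i] with a one-byte step, MemSize = 8 and Step = 1, so
// MaxExeCount = 8 and the inferred constant max trip count is 9: the
// header may be entered once more before the exit is taken.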
8009 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8010 SmallVector<BasicBlock *, 8> ExitingBlocks;
8011 L->getExitingBlocks(ExitingBlocks);
8013 Optional<unsigned> Res = None;
8014 for (auto *ExitingBB : ExitingBlocks) {
8015 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8016 if (!Res)
8017 Res = Multiple;
8018 Res = (unsigned)GreatestCommonDivisor64(*Res, Multiple);
8019 }
8020 return Res.value_or(1);
8021 }
8023 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8024 const SCEV *ExitCount) {
8025 if (ExitCount == getCouldNotCompute())
8026 return 1;
8028 // Get the trip count
8029 const SCEV *TCExpr = getTripCountFromExitCount(ExitCount);
8031 const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr);
8032 if (!TC)
8033 // Attempt to factor more general cases. Returns the greatest power of
8034 // two divisor. If overflow happens, the trip count expression is still
8035 // divisible by the greatest power of 2 divisor returned.
8036 return 1U << std::min((uint32_t)31,
8037 GetMinTrailingZeros(applyLoopGuards(TCExpr, L)));
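// e.g. (illustrative) if the trip count expression is 4 * %n, it has two
// known trailing zero bits, so the greatest power-of-two multiple returned
// above is 4.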
8039 ConstantInt *Result = TC->getValue();
8041 // Guard against huge trip counts (this requires checking
8042 // for zero to handle the case where the trip count == -1 and the
8043 // addition wraps).
8044 if (!Result || Result->getValue().getActiveBits() > 32 ||
8045 Result->getValue().getActiveBits() == 0)
8046 return 1;
8048 return (unsigned)Result->getZExtValue();
8049 }
8051 /// Returns the largest constant divisor of the trip count of this loop as a
8052 /// normal unsigned value, if possible. This means that the actual trip count is
8053 /// always a multiple of the returned value (don't forget the trip count could
8054 /// very well be zero as well!).
8056 /// Returns 1 if the trip count is unknown or not guaranteed to be the
8057 /// multiple of a constant (which is also the case if the trip count is simply
8058 /// constant; use getSmallConstantTripCount for that case). It will also
8059 /// return 1 if the trip count is very large (>= 2^32).
8061 /// As explained in the comments for getSmallConstantTripCount, this assumes
8062 /// that control exits the loop via ExitingBlock.
8063 unsigned
8064 ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8065 const BasicBlock *ExitingBlock) {
8066 assert(ExitingBlock && "Must pass a non-null exiting block!");
8067 assert(L->isLoopExiting(ExitingBlock) &&
8068 "Exiting block must actually branch out of the loop!");
8069 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8070 return getSmallConstantTripMultiple(L, ExitCount);
8071 }
8073 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8074 const BasicBlock *ExitingBlock,
8075 ExitCountKind Kind) {
8076 switch (Kind) {
8077 case Exact:
8078 case SymbolicMaximum:
8079 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8080 case ConstantMaximum:
8081 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8082 };
8083 llvm_unreachable("Invalid ExitCountKind!");
8084 }
8086 const SCEV *
8087 ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
8088 SmallVector<const SCEVPredicate *, 4> &Preds) {
8089 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8090 }
8092 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8093 ExitCountKind Kind) {
8094 switch (Kind) {
8095 case Exact:
8096 return getBackedgeTakenInfo(L).getExact(L, this);
8097 case ConstantMaximum:
8098 return getBackedgeTakenInfo(L).getConstantMax(this);
8099 case SymbolicMaximum:
8100 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8101 };
8102 llvm_unreachable("Invalid ExitCountKind!");
8103 }
8105 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8106 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8107 }
8109 /// Push PHI nodes in the header of the given loop onto the given Worklist.
8110 static void PushLoopPHIs(const Loop *L,
8111 SmallVectorImpl<Instruction *> &Worklist,
8112 SmallPtrSetImpl<Instruction *> &Visited) {
8113 BasicBlock *Header = L->getHeader();
8115 // Push all Loop-header PHIs onto the Worklist stack.
8116 for (PHINode &PN : Header->phis())
8117 if (Visited.insert(&PN).second)
8118 Worklist.push_back(&PN);
8119 }
8121 const ScalarEvolution::BackedgeTakenInfo &
8122 ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8123 auto &BTI = getBackedgeTakenInfo(L);
8124 if (BTI.hasFullInfo())
8125 return BTI;
8127 auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8129 if (!Pair.second)
8130 return Pair.first->second;
8132 BackedgeTakenInfo Result =
8133 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8135 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8136 }
8138 ScalarEvolution::BackedgeTakenInfo &
8139 ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8140 // Initially insert an invalid entry for this loop. If the insertion
8141 // succeeds, proceed to actually compute a backedge-taken count and
8142 // update the value. The temporary CouldNotCompute value tells SCEV
8143 // code elsewhere that it shouldn't attempt to request a new
8144 // backedge-taken count, which could result in infinite recursion.
8145 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8146 BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8148 return Pair.first->second;
8150 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8151 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8152 // must be cleared in this scope.
8153 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8155 // In a product build, the statistics below are unused; silence warnings.
8156 (void)NumTripCountsComputed;
8157 (void)NumTripCountsNotComputed;
8158 #if LLVM_ENABLE_STATS || !defined(NDEBUG)
8159 const SCEV *BEExact = Result.getExact(L, this);
8160 if (BEExact != getCouldNotCompute()) {
8161 assert(isLoopInvariant(BEExact, L) &&
8162 isLoopInvariant(Result.getConstantMax(this), L) &&
8163 "Computed backedge-taken count isn't loop invariant for loop!");
8164 ++NumTripCountsComputed;
8165 } else if (Result.getConstantMax(this) == getCouldNotCompute() &&
8166 isa<PHINode>(L->getHeader()->begin())) {
8167 // Only count loops that have phi nodes as not being computable.
8168 ++NumTripCountsNotComputed;
8170 #endif // LLVM_ENABLE_STATS || !defined(NDEBUG)
8172 // Now that we know more about the trip count for this loop, forget any
8173 // existing SCEV values for PHI nodes in this loop since they are only
8174 // conservative estimates made without the benefit of trip count
8175 // information. This invalidation is not necessary for correctness, and is
8176 // only done to produce more precise results.
8177 if (Result.hasAnyInfo()) {
8178 // Invalidate any expression using an addrec in this loop.
8179 SmallVector<const SCEV *, 8> ToForget;
8180 auto LoopUsersIt = LoopUsers.find(L);
8181 if (LoopUsersIt != LoopUsers.end())
8182 append_range(ToForget, LoopUsersIt->second);
8183 forgetMemoizedResults(ToForget);
8185 // Invalidate constant-evolved loop header phis.
8186 for (PHINode &PN : L->getHeader()->phis())
8187 ConstantEvolutionLoopExitValue.erase(&PN);
8190 // Re-lookup the insert position, since the call to
8191 // computeBackedgeTakenCount above could result in a
8192 // recursive call to getBackedgeTakenInfo (on a different
8193 // loop), which would invalidate the iterator computed
8194 // earlier.
8195 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8198 void ScalarEvolution::forgetAllLoops() {
8199 // This method is intended to forget all info about loops. It should
8200 // invalidate caches as if the following happened:
8201 // - The trip counts of all loops have changed arbitrarily
8202 // - Every llvm::Value has been updated in place to produce a different
8203 //   result.
8204 BackedgeTakenCounts.clear();
8205 PredicatedBackedgeTakenCounts.clear();
8206 BECountUsers.clear();
8207 LoopPropertiesCache.clear();
8208 ConstantEvolutionLoopExitValue.clear();
8209 ValueExprMap.clear();
8210 ValuesAtScopes.clear();
8211 ValuesAtScopesUsers.clear();
8212 LoopDispositions.clear();
8213 BlockDispositions.clear();
8214 UnsignedRanges.clear();
8215 SignedRanges.clear();
8216 ExprValueMap.clear();
8217 HasRecMap.clear();
8218 MinTrailingZerosCache.clear();
8219 PredicatedSCEVRewrites.clear();
8222 void ScalarEvolution::forgetLoop(const Loop *L) {
8223 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8224 SmallVector<Instruction *, 32> Worklist;
8225 SmallPtrSet<Instruction *, 16> Visited;
8226 SmallVector<const SCEV *, 16> ToForget;
8228 // Iterate over all the loops and sub-loops to drop SCEV information.
8229 while (!LoopWorklist.empty()) {
8230 auto *CurrL = LoopWorklist.pop_back_val();
8232 // Drop any stored trip count value.
8233 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8234 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8236 // Drop information about predicated SCEV rewrites for this loop.
8237 for (auto I = PredicatedSCEVRewrites.begin();
8238 I != PredicatedSCEVRewrites.end();) {
8239 std::pair<const SCEV *, const Loop *> Entry = I->first;
8240 if (Entry.second == CurrL)
8241 PredicatedSCEVRewrites.erase(I++);
8242 else
8243 ++I;
8246 auto LoopUsersItr = LoopUsers.find(CurrL);
8247 if (LoopUsersItr != LoopUsers.end()) {
8248 ToForget.insert(ToForget.end(), LoopUsersItr->second.begin(),
8249 LoopUsersItr->second.end());
8252 // Drop information about expressions based on loop-header PHIs.
8253 PushLoopPHIs(CurrL, Worklist, Visited);
8255 while (!Worklist.empty()) {
8256 Instruction *I = Worklist.pop_back_val();
8258 ValueExprMapType::iterator It =
8259 ValueExprMap.find_as(static_cast<Value *>(I));
8260 if (It != ValueExprMap.end()) {
8261 eraseValueFromMap(It->first);
8262 ToForget.push_back(It->second);
8263 if (PHINode *PN = dyn_cast<PHINode>(I))
8264 ConstantEvolutionLoopExitValue.erase(PN);
8267 PushDefUseChildren(I, Worklist, Visited);
8270 LoopPropertiesCache.erase(CurrL);
8271 // Forget all contained loops too, to avoid dangling entries in the
8272 // ValuesAtScopes map.
8273 LoopWorklist.append(CurrL->begin(), CurrL->end());
8275 forgetMemoizedResults(ToForget);
8278 void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8279 forgetLoop(L->getOutermostLoop());
8282 void ScalarEvolution::forgetValue(Value *V) {
8283 Instruction *I = dyn_cast<Instruction>(V);
8284 if (!I)
8285 return;
8286 // Drop information about expressions based on loop-header PHIs.
8287 SmallVector<Instruction *, 16> Worklist;
8288 SmallPtrSet<Instruction *, 8> Visited;
8289 SmallVector<const SCEV *, 8> ToForget;
8290 Worklist.push_back(I);
8293 while (!Worklist.empty()) {
8294 I = Worklist.pop_back_val();
8295 ValueExprMapType::iterator It =
8296 ValueExprMap.find_as(static_cast<Value *>(I));
8297 if (It != ValueExprMap.end()) {
8298 eraseValueFromMap(It->first);
8299 ToForget.push_back(It->second);
8300 if (PHINode *PN = dyn_cast<PHINode>(I))
8301 ConstantEvolutionLoopExitValue.erase(PN);
8304 PushDefUseChildren(I, Worklist, Visited);
8306 forgetMemoizedResults(ToForget);
8309 void ScalarEvolution::forgetLoopDispositions(const Loop *L) {
8310 LoopDispositions.clear();
8313 /// Get the exact loop backedge taken count considering all loop exits. A
8314 /// computable result can only be returned for loops with all exiting blocks
8315 /// dominating the latch. howFarToZero assumes that the limit of each loop test
8316 /// is never skipped. This is a valid assumption as long as the loop exits via
8317 /// that test. For precise results, it is the caller's responsibility to specify
8318 /// the relevant loop exiting block using getExact(ExitingBlock, SE).
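/// For example, if a two-exit loop has exact not-taken counts of 5 and 7
/// through exiting blocks that both dominate the latch, the backedge can be
/// taken at most min(5, 7) = 5 times, so the result is the (sequential)
/// unsigned minimum over all collected exit counts.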
8319 const SCEV *
8320 ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
8321 SmallVector<const SCEVPredicate *, 4> *Preds) const {
8322 // If any exits were not computable, the loop is not computable.
8323 if (!isComplete() || ExitNotTaken.empty())
8324 return SE->getCouldNotCompute();
8326 const BasicBlock *Latch = L->getLoopLatch();
8327 // All exiting blocks we have collected must dominate the only backedge.
8328 if (!Latch)
8329 return SE->getCouldNotCompute();
8331 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8332 // count is simply a minimum out of all these calculated exit counts.
8333 SmallVector<const SCEV *, 2> Ops;
8334 for (auto &ENT : ExitNotTaken) {
8335 const SCEV *BECount = ENT.ExactNotTaken;
8336 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8337 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8338 "We should only have known counts for exiting blocks that dominate "
8339 "latch!");
8341 Ops.push_back(BECount);
8343 if (Preds)
8344 for (auto *P : ENT.Predicates)
8345 Preds->push_back(P);
8347 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8348 "Predicate should be always true!");
8351 // If an earlier exit exits on the first iteration (exit count zero), then
8352 // a later poison exit count should not propagate into the result. These are
8353 // exactly the semantics provided by umin_seq.
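// For example, for exit counts {0, poison}, umin_seq(0, poison) evaluates its
// operands left to right and folds to 0, whereas a plain umin would let the
// poison value propagate into the result.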
8354 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8357 /// Get the exact not taken count for this loop exit.
8358 const SCEV *
8359 ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
8360 ScalarEvolution *SE) const {
8361 for (auto &ENT : ExitNotTaken)
8362 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8363 return ENT.ExactNotTaken;
8365 return SE->getCouldNotCompute();
8368 const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8369 const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8370 for (auto &ENT : ExitNotTaken)
8371 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8372 return ENT.MaxNotTaken;
8374 return SE->getCouldNotCompute();
8377 /// getConstantMax - Get the constant max backedge taken count for the loop.
8378 const SCEV *
8379 ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const {
8380 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8381 return !ENT.hasAlwaysTruePredicate();
8384 if (!getConstantMax() || any_of(ExitNotTaken, PredicateNotAlwaysTrue))
8385 return SE->getCouldNotCompute();
8387 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8388 isa<SCEVConstant>(getConstantMax())) &&
8389 "No point in having a non-constant max backedge taken count!");
8390 return getConstantMax();
8393 const SCEV *
8394 ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L,
8395 ScalarEvolution *SE) {
8396 if (!SymbolicMax)
8397 SymbolicMax = SE->computeSymbolicMaxBackedgeTakenCount(L);
8398 return SymbolicMax;
8401 bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8402 ScalarEvolution *SE) const {
8403 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8404 return !ENT.hasAlwaysTruePredicate();
8406 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8409 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
8410 : ExitLimit(E, E, false, None) {
8413 ScalarEvolution::ExitLimit::ExitLimit(
8414 const SCEV *E, const SCEV *M, bool MaxOrZero,
8415 ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
8416 : ExactNotTaken(E), MaxNotTaken(M), MaxOrZero(MaxOrZero) {
8417 // If we prove the max count is zero, so is the symbolic bound. This happens
8418 // in practice due to differences in a) how context sensitive we've chosen
8419 // to be and b) how we reason about bounds implied by UB.
8420 if (MaxNotTaken->isZero())
8421 ExactNotTaken = MaxNotTaken;
8423 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8424 !isa<SCEVCouldNotCompute>(MaxNotTaken)) &&
8425 "Exact is not allowed to be less precise than Max");
8426 assert((isa<SCEVCouldNotCompute>(MaxNotTaken) ||
8427 isa<SCEVConstant>(MaxNotTaken)) &&
8428 "No point in having a non-constant max backedge taken count!");
8429 for (auto *PredSet : PredSetList)
8430 for (auto *P : *PredSet)
8431 Predicates.insert(P);
8432 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8433 "Backedge count should be int");
8434 assert((isa<SCEVCouldNotCompute>(M) || !M->getType()->isPointerTy()) &&
8435 "Max backedge count should be int");
8438 ScalarEvolution::ExitLimit::ExitLimit(
8439 const SCEV *E, const SCEV *M, bool MaxOrZero,
8440 const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
8441 : ExitLimit(E, M, MaxOrZero, {&PredSet}) {
8444 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, const SCEV *M,
8445 bool MaxOrZero)
8446 : ExitLimit(E, M, MaxOrZero, None) {
8449 /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8450 /// computable exit into a persistent ExitNotTakenInfo array.
8451 ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8452 ArrayRef<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> ExitCounts,
8453 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8454 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8455 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8457 ExitNotTaken.reserve(ExitCounts.size());
8458 std::transform(
8459 ExitCounts.begin(), ExitCounts.end(), std::back_inserter(ExitNotTaken),
8460 [&](const EdgeExitInfo &EEI) {
8461 BasicBlock *ExitBB = EEI.first;
8462 const ExitLimit &EL = EEI.second;
8463 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, EL.MaxNotTaken,
8464 EL.Predicates);
8465 });
8466 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8467 isa<SCEVConstant>(ConstantMax)) &&
8468 "No point in having a non-constant max backedge taken count!");
8471 /// Compute the number of times the backedge of the specified loop will execute.
8472 ScalarEvolution::BackedgeTakenInfo
8473 ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8474 bool AllowPredicates) {
8475 SmallVector<BasicBlock *, 8> ExitingBlocks;
8476 L->getExitingBlocks(ExitingBlocks);
8478 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8480 SmallVector<EdgeExitInfo, 4> ExitCounts;
8481 bool CouldComputeBECount = true;
8482 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8483 const SCEV *MustExitMaxBECount = nullptr;
8484 const SCEV *MayExitMaxBECount = nullptr;
8485 bool MustExitMaxOrZero = false;
8487 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8488 // and compute maxBECount.
8489 // Do a union of all the predicates here.
8490 for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
8491 BasicBlock *ExitBB = ExitingBlocks[i];
8493 // We canonicalize untaken exits to br (constant); ignore them so that
8494 // proving an exit untaken doesn't negatively impact our ability to reason
8495 // about the loop as a whole.
8496 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8497 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8498 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8499 if (ExitIfTrue == CI->isZero())
8500 continue;
8503 ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
8505 assert((AllowPredicates || EL.Predicates.empty()) &&
8506 "Predicated exit limit when predicates are not allowed!");
8508 // 1. For each exit that can be computed, add an entry to ExitCounts.
8509 // CouldComputeBECount is true only if all exits can be computed.
8510 if (EL.ExactNotTaken == getCouldNotCompute())
8511 // We couldn't compute an exact value for this exit, so
8512 // we won't be able to compute an exact value for the loop.
8513 CouldComputeBECount = false;
8514 else
8515 ExitCounts.emplace_back(ExitBB, EL);
8517 // 2. Derive the loop's MaxBECount from each exit's max number of
8518 // non-exiting iterations. Partition the loop exits into two kinds:
8519 // LoopMustExits and LoopMayExits.
8521 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8522 // is a LoopMayExit. If any computable LoopMustExit is found, then
8523 // MaxBECount is the minimum EL.MaxNotTaken of computable
8524 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8525 // EL.MaxNotTaken, where CouldNotCompute is considered greater than any
8526 // computable EL.MaxNotTaken.
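// For example, a must-exit bounded by 10 caps the whole loop at 10 iterations
// regardless of a may-exit bounded by 100; with only may-exits, any single
// exit can be skipped on a given iteration, so only the umax of all of them
// (and only if every one is computable) is a safe bound.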
8527 if (EL.MaxNotTaken != getCouldNotCompute() && Latch &&
8528 DT.dominates(ExitBB, Latch)) {
8529 if (!MustExitMaxBECount) {
8530 MustExitMaxBECount = EL.MaxNotTaken;
8531 MustExitMaxOrZero = EL.MaxOrZero;
8532 } else {
8533 MustExitMaxBECount =
8534 getUMinFromMismatchedTypes(MustExitMaxBECount, EL.MaxNotTaken);
8536 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8537 if (!MayExitMaxBECount || EL.MaxNotTaken == getCouldNotCompute())
8538 MayExitMaxBECount = EL.MaxNotTaken;
8539 else
8540 MayExitMaxBECount =
8541 getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.MaxNotTaken);
8545 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8546 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8547 // The loop backedge will be taken the maximum or zero times if there's
8548 // a single exit that must be taken the maximum or zero times.
8549 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8551 // Remember which SCEVs are used in exit limits for invalidation purposes.
8552 // We only care about non-constant SCEVs here, so we can ignore EL.MaxNotTaken
8553 // and MaxBECount, which must be SCEVConstant.
8554 for (const auto &Pair : ExitCounts)
8555 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8556 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8557 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8558 MaxBECount, MaxOrZero);
8561 ScalarEvolution::ExitLimit
8562 ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8563 bool AllowPredicates) {
8564 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8565 // If our exiting block does not dominate the latch, then its connection with
8566 // the loop's exit limit may be far from trivial.
8567 const BasicBlock *Latch = L->getLoopLatch();
8568 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8569 return getCouldNotCompute();
8571 bool IsOnlyExit = (L->getExitingBlock() != nullptr);
8572 Instruction *Term = ExitingBlock->getTerminator();
8573 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8574 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8575 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8576 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8577 "It should have one successor in loop and one exit block!");
8578 // Proceed to the next level to examine the exit condition expression.
8579 return computeExitLimitFromCond(
8580 L, BI->getCondition(), ExitIfTrue,
8581 /*ControlsExit=*/IsOnlyExit, AllowPredicates);
8584 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8585 // For switch, make sure that there is a single exit from the loop.
8586 BasicBlock *Exit = nullptr;
8587 for (auto *SBB : successors(ExitingBlock))
8588 if (!L->contains(SBB)) {
8589 if (Exit) // Multiple exit successors.
8590 return getCouldNotCompute();
8591 Exit = SBB;
8593 assert(Exit && "Exiting block must have at least one exit");
8594 return computeExitLimitFromSingleExitSwitch(L, SI, Exit,
8595 /*ControlsExit=*/IsOnlyExit);
8598 return getCouldNotCompute();
8601 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
8602 const Loop *L, Value *ExitCond, bool ExitIfTrue,
8603 bool ControlsExit, bool AllowPredicates) {
8604 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
8605 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
8606 ControlsExit, AllowPredicates);
8609 Optional<ScalarEvolution::ExitLimit>
8610 ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
8611 bool ExitIfTrue, bool ControlsExit,
8612 bool AllowPredicates) {
8614 (void)this->ExitIfTrue;
8615 (void)this->AllowPredicates;
8617 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8618 this->AllowPredicates == AllowPredicates &&
8619 "Variance in assumed invariant key components!");
8620 auto Itr = TripCountMap.find({ExitCond, ControlsExit});
8621 if (Itr == TripCountMap.end())
8622 return None;
8624 return Itr->second;
8626 void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
8627 bool ExitIfTrue,
8628 bool ControlsExit,
8629 bool AllowPredicates,
8630 const ExitLimit &EL) {
8631 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8632 this->AllowPredicates == AllowPredicates &&
8633 "Variance in assumed invariant key components!");
8635 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsExit}, EL});
8636 assert(InsertResult.second && "Expected successful insertion!");
8641 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
8642 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8643 bool ControlsExit, bool AllowPredicates) {
8645 if (auto MaybeEL =
8646 Cache.find(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates))
8647 return *MaybeEL;
8649 ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, ExitIfTrue,
8650 ControlsExit, AllowPredicates);
8651 Cache.insert(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates, EL);
8652 return EL;
8655 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
8656 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8657 bool ControlsExit, bool AllowPredicates) {
8658 // Handle BinOp conditions (And, Or).
8659 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
8660 Cache, L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates))
8661 return *LimitFromBinOp;
8663 // With an icmp, it may be feasible to compute an exact backedge-taken count.
8664 // Proceed to the next level to examine the icmp.
8665 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
8666 ExitLimit EL =
8667 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit);
8668 if (EL.hasFullInfo() || !AllowPredicates)
8669 return EL;
8671 // Try again, but use SCEV predicates this time.
8672 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit,
8673 /*AllowPredicates=*/true);
8676 // Check for a constant condition. These are normally stripped out by
8677 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
8678 // preserve the CFG and is temporarily leaving constant conditions
8679 // in place.
8680 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
8681 if (ExitIfTrue == !CI->getZExtValue())
8682 // The backedge is always taken.
8683 return getCouldNotCompute();
8685 // The backedge is never taken.
8686 return getZero(CI->getType());
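// For example, "br i1 true, label %loop, label %exit" never takes this exit,
// so no exit count can be computed through it, while "br i1 false, ..." leaves
// the loop before the first backedge and yields an exit count of zero.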
8689 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
8690 // with a constant step, we can form an equivalent icmp predicate and figure
8691 // out how many iterations will be taken before we exit.
8692 const WithOverflowInst *WO;
8693 const APInt *C;
8694 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
8695 match(WO->getRHS(), m_APInt(C))) {
8696 ConstantRange NWR =
8697 ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
8698 WO->getNoWrapKind());
8699 CmpInst::Predicate Pred;
8700 APInt NewRHSC, Offset;
8701 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
8702 if (!ExitIfTrue)
8703 Pred = ICmpInst::getInversePredicate(Pred);
8704 auto *LHS = getSCEV(WO->getLHS());
8705 if (Offset != 0)
8706 LHS = getAddExpr(LHS, getConstant(Offset));
8707 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
8708 ControlsExit, AllowPredicates);
8709 if (EL.hasAnyInfo()) return EL;
8712 // If it's not an integer or pointer comparison then compute it the hard way.
8713 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
8716 Optional<ScalarEvolution::ExitLimit>
8717 ScalarEvolution::computeExitLimitFromCondFromBinOp(
8718 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8719 bool ControlsExit, bool AllowPredicates) {
8720 // Check if the controlling expression for this loop is an And or Or.
8721 Value *Op0, *Op1;
8722 bool IsAnd = false;
8723 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
8724 IsAnd = true;
8725 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
8726 IsAnd = false;
8727 else
8728 return None;
8730 // EitherMayExit is true in these two cases:
8731 // br (and Op0 Op1), loop, exit
8732 // br (or Op0 Op1), exit, loop
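// For example, IsAnd = true with ExitIfTrue = false matches the first form:
// the loop exits as soon as either operand turns false. In the remaining two
// combinations, both operands must take the exiting value on the same
// iteration before the branch leaves the loop.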
8733 bool EitherMayExit = IsAnd ^ ExitIfTrue;
8734 ExitLimit EL0 = computeExitLimitFromCondCached(Cache, L, Op0, ExitIfTrue,
8735 ControlsExit && !EitherMayExit,
8736 AllowPredicates);
8737 ExitLimit EL1 = computeExitLimitFromCondCached(Cache, L, Op1, ExitIfTrue,
8738 ControlsExit && !EitherMayExit,
8739 AllowPredicates);
8741 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
8742 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
8743 if (isa<ConstantInt>(Op1))
8744 return Op1 == NeutralElement ? EL0 : EL1;
8745 if (isa<ConstantInt>(Op0))
8746 return Op0 == NeutralElement ? EL1 : EL0;
8748 const SCEV *BECount = getCouldNotCompute();
8749 const SCEV *MaxBECount = getCouldNotCompute();
8750 if (EitherMayExit) {
8751 // Both conditions must hold for the loop to keep executing, so it exits
8752 // as soon as either one fails. Choose the less conservative count.
8753 if (EL0.ExactNotTaken != getCouldNotCompute() &&
8754 EL1.ExactNotTaken != getCouldNotCompute()) {
8755 BECount = getUMinFromMismatchedTypes(
8756 EL0.ExactNotTaken, EL1.ExactNotTaken,
8757 /*Sequential=*/!isa<BinaryOperator>(ExitCond));
8759 if (EL0.MaxNotTaken == getCouldNotCompute())
8760 MaxBECount = EL1.MaxNotTaken;
8761 else if (EL1.MaxNotTaken == getCouldNotCompute())
8762 MaxBECount = EL0.MaxNotTaken;
8763 else
8764 MaxBECount = getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);
8765 } else {
8766 // Both conditions must take the exiting value at the same time for the
8767 // loop to exit. For now, be conservative.
8768 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
8769 BECount = EL0.ExactNotTaken;
8772 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
8773 // to be more aggressive when computing BECount than when computing
8774 // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and
8775 // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken
8776 // to not.
8777 if (isa<SCEVCouldNotCompute>(MaxBECount) &&
8778 !isa<SCEVCouldNotCompute>(BECount))
8779 MaxBECount = getConstant(getUnsignedRangeMax(BECount));
8781 return ExitLimit(BECount, MaxBECount, false,
8782 { &EL0.Predicates, &EL1.Predicates });
8785 ScalarEvolution::ExitLimit
8786 ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
8787 ICmpInst *ExitCond,
8788 bool ExitIfTrue,
8789 bool ControlsExit,
8790 bool AllowPredicates) {
8791 // If the condition was exit on true, convert the condition to exit on false
8792 ICmpInst::Predicate Pred;
8793 if (!ExitIfTrue)
8794 Pred = ExitCond->getPredicate();
8795 else
8796 Pred = ExitCond->getInversePredicate();
8797 const ICmpInst::Predicate OriginalPred = Pred;
8799 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
8800 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
8802 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsExit,
8803 AllowPredicates);
8804 if (EL.hasAnyInfo()) return EL;
8806 auto *ExhaustiveCount =
8807 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
8809 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
8810 return ExhaustiveCount;
8812 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
8813 ExitCond->getOperand(1), L, OriginalPred);
8815 ScalarEvolution::ExitLimit
8816 ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
8817 ICmpInst::Predicate Pred,
8818 const SCEV *LHS, const SCEV *RHS,
8819 bool ControlsExit,
8820 bool AllowPredicates) {
8822 // Try to evaluate any dependencies out of the loop.
8823 LHS = getSCEVAtScope(LHS, L);
8824 RHS = getSCEVAtScope(RHS, L);
8826 // At this point, we would like to compute how many iterations of the
8827 // loop the predicate will return true for these inputs.
8828 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
8829 // If there is a loop-invariant, force it into the RHS.
8830 std::swap(LHS, RHS);
8831 Pred = ICmpInst::getSwappedPredicate(Pred);
8834 bool ControllingFiniteLoop =
8835 ControlsExit && loopHasNoAbnormalExits(L) && loopIsFiniteByAssumption(L);
8836 // Simplify the operands before analyzing them.
8837 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0,
8838 (EnableFiniteLoopControl ? ControllingFiniteLoop
8839 : false));
8841 // If we have a comparison of a chrec against a constant, try to use value
8842 // ranges to answer this query.
8843 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
8844 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
8845 if (AddRec->getLoop() == L) {
8846 // Form the constant range.
8847 ConstantRange CompRange =
8848 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
8850 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
8851 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
8854 // If this loop must exit based on this condition (or execute undefined
8855 // behaviour), and we can prove the test sequence produced must repeat
8856 // the same values on self-wrap of the IV, then we can infer that IV
8857 // doesn't self wrap because if it did, we'd have an infinite (undefined)
8858 // loop.
8859 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
8860 // TODO: We can peel off any functions which are invertible *in L*. Loop
8861 // invariant terms are effectively constants for our purposes here.
8862 auto *InnerLHS = LHS;
8863 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
8864 InnerLHS = ZExt->getOperand();
8865 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS)) {
8866 auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
8867 if (!AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
8868 StrideC && StrideC->getAPInt().isPowerOf2()) {
8869 auto Flags = AR->getNoWrapFlags();
8870 Flags = setFlags(Flags, SCEV::FlagNW);
8871 SmallVector<const SCEV*> Operands{AR->operands()};
8872 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
8873 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
8878 switch (Pred) {
8879 case ICmpInst::ICMP_NE: { // while (X != Y)
8880 // Convert to: while (X-Y != 0)
8881 if (LHS->getType()->isPointerTy()) {
8882 LHS = getLosslessPtrToIntExpr(LHS);
8883 if (isa<SCEVCouldNotCompute>(LHS))
8884 return LHS;
8886 if (RHS->getType()->isPointerTy()) {
8887 RHS = getLosslessPtrToIntExpr(RHS);
8888 if (isa<SCEVCouldNotCompute>(RHS))
8889 return RHS;
8891 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit,
8892 AllowPredicates);
8893 if (EL.hasAnyInfo()) return EL;
8896 case ICmpInst::ICMP_EQ: { // while (X == Y)
8897 // Convert to: while (X-Y == 0)
8898 if (LHS->getType()->isPointerTy()) {
8899 LHS = getLosslessPtrToIntExpr(LHS);
8900 if (isa<SCEVCouldNotCompute>(LHS))
8901 return LHS;
8903 if (RHS->getType()->isPointerTy()) {
8904 RHS = getLosslessPtrToIntExpr(RHS);
8905 if (isa<SCEVCouldNotCompute>(RHS))
8906 return RHS;
8908 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
8909 if (EL.hasAnyInfo()) return EL;
8912 case ICmpInst::ICMP_SLT:
8913 case ICmpInst::ICMP_ULT: { // while (X < Y)
8914 bool IsSigned = Pred == ICmpInst::ICMP_SLT;
8915 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit,
8916 AllowPredicates);
8917 if (EL.hasAnyInfo()) return EL;
8920 case ICmpInst::ICMP_SGT:
8921 case ICmpInst::ICMP_UGT: { // while (X > Y)
8922 bool IsSigned = Pred == ICmpInst::ICMP_SGT;
8923 ExitLimit EL =
8924 howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit,
8925 AllowPredicates);
8926 if (EL.hasAnyInfo()) return EL;
8933 return getCouldNotCompute();
8936 ScalarEvolution::ExitLimit
8937 ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
8938 SwitchInst *Switch,
8939 BasicBlock *ExitingBlock,
8940 bool ControlsExit) {
8941 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
8943 // Give up if the exit is the default dest of a switch.
8944 if (Switch->getDefaultDest() == ExitingBlock)
8945 return getCouldNotCompute();
8947 assert(L->contains(Switch->getDefaultDest()) &&
8948 "Default case must not exit the loop!");
8949 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
8950 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
8952 // while (X != Y) --> while (X-Y != 0)
8953 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit);
8954 if (EL.hasAnyInfo())
8955 return EL;
8957 return getCouldNotCompute();
8960 static ConstantInt *
8961 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
8962 ScalarEvolution &SE) {
8963 const SCEV *InVal = SE.getConstant(C);
8964 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
8965 assert(isa<SCEVConstant>(Val) &&
8966 "Evaluation of SCEV at constant didn't fold correctly?");
8967 return cast<SCEVConstant>(Val)->getValue();
8970 ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
8971 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
8972 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
8973 if (!RHS)
8974 return getCouldNotCompute();
8976 const BasicBlock *Latch = L->getLoopLatch();
8977 if (!Latch)
8978 return getCouldNotCompute();
8980 const BasicBlock *Predecessor = L->getLoopPredecessor();
8981 if (!Predecessor)
8982 return getCouldNotCompute();
8984 // Return true if V is of the form "LHS `shift_op` <positive constant>".
8985 // Return LHS in OutLHS and shift_op in OutOpCode.
8986 auto MatchPositiveShift =
8987 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
8989 using namespace PatternMatch;
8991 ConstantInt *ShiftAmt;
8992 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
8993 OutOpCode = Instruction::LShr;
8994 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
8995 OutOpCode = Instruction::AShr;
8996 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
8997 OutOpCode = Instruction::Shl;
8998 else
8999 return false;
9001 return ShiftAmt->getValue().isStrictlyPositive();
9004 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9005 // the loop:
9007 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9008 // %iv.shifted = lshr i32 %iv, <positive constant>
9010 // Return true on a successful match. Return the corresponding PHI node (%iv
9011 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9012 auto MatchShiftRecurrence =
9013 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9014 Optional<Instruction::BinaryOps> PostShiftOpCode;
9017 Instruction::BinaryOps OpC;
9020 // If we encounter a shift instruction, "peel off" the shift operation,
9021 // and remember that we did so. Later when we inspect %iv's backedge
9022 // value, we will make sure that the backedge value uses the same
9025 // Note: the peeled shift operation does not have to be the same
9026 // instruction as the one feeding into the PHI's backedge value. We only
9027 // really care about it being the same *kind* of shift instruction --
9028 // that's all that is required for our later inferences to hold.
9029 if (MatchPositiveShift(LHS, V, OpC)) {
9030 PostShiftOpCode = OpC;
9035 PNOut = dyn_cast<PHINode>(LHS);
9036 if (!PNOut || PNOut->getParent() != L->getHeader())
9039 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9040 Value *OpLHS;
9042 return
9043 // The backedge value for the PHI node must be a shift by a positive
9044 // amount
9045 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9047 // of the PHI node itself
9048 OpLHS == PNOut &&
9050 // and the kind of shift should match the kind of shift we peeled
9051 // off, if any.
9052 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9055 PHINode *PN;
9056 Instruction::BinaryOps OpCode;
9057 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9058 return getCouldNotCompute();
9060 const DataLayout &DL = getDataLayout();
9062 // The key rationale for this optimization is that for some kinds of shift
9063 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9064 // within a finite number of iterations. If the condition guarding the
9065 // backedge (in the sense that the backedge is taken if the condition is true)
9066 // is false for the value the shift recurrence stabilizes to, then we know
9067 // that the backedge is taken only a finite number of times.
9069 ConstantInt *StableValue = nullptr;
9070 switch (OpCode) {
9071 default:
9072 llvm_unreachable("Impossible case!");
9074 case Instruction::AShr: {
9075 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9076 // bitwidth(K) iterations.
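// For example, for i8 K = -16 and a shift amount of 1 the sequence is
// -16, -8, -4, -2, -1, -1, ... (stable at -1), while K = 16 gives
// 16, 8, 4, 2, 1, 0, 0, ... (stable at 0); either way, at most 8 steps.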
9077 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9078 KnownBits Known = computeKnownBits(FirstValue, DL, 0, &AC,
9079 Predecessor->getTerminator(), &DT);
9080 auto *Ty = cast<IntegerType>(RHS->getType());
9081 if (Known.isNonNegative())
9082 StableValue = ConstantInt::get(Ty, 0);
9083 else if (Known.isNegative())
9084 StableValue = ConstantInt::get(Ty, -1, true);
9085 else
9086 return getCouldNotCompute();
9090 case Instruction::LShr:
9091 case Instruction::Shl:
9092 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9093 // stabilize to 0 in at most bitwidth(K) iterations.
9094 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9098 Constant *Result =
9099 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9100 assert(Result->getType()->isIntegerTy(1) &&
9101 "Otherwise cannot be an operand to a branch instruction");
9103 if (Result->isZeroValue()) {
9104 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9105 const SCEV *UpperBound =
9106 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9107 return ExitLimit(getCouldNotCompute(), UpperBound, false);
9110 return getCouldNotCompute();
9113 /// Return true if we can constant fold an instruction of the specified type,
9114 /// assuming that all operands were constants.
9115 static bool CanConstantFold(const Instruction *I) {
9116 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9117 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9118 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9119 return true;
9121 if (const CallInst *CI = dyn_cast<CallInst>(I))
9122 if (const Function *F = CI->getCalledFunction())
9123 return canConstantFoldCallTo(CI, F);
9124 return false;
9127 /// Determine whether this instruction can constant evolve within this loop
9128 /// assuming its operands can all constant evolve.
9129 static bool canConstantEvolve(Instruction *I, const Loop *L) {
9130 // An instruction outside of the loop can't be derived from a loop PHI.
9131 if (!L->contains(I)) return false;
9133 if (isa<PHINode>(I)) {
9134 // We don't currently keep track of the control flow needed to evaluate
9135 // PHIs, so we cannot handle PHIs inside of loops.
9136 return L->getHeader() == I->getParent();
9139 // If we won't be able to constant fold this expression even if the operands
9140 // are constants, bail early.
9141 return CanConstantFold(I);
9144 /// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9145 /// recursing through each instruction operand until reaching a loop header phi.
9146 static PHINode *
9147 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9148 DenseMap<Instruction *, PHINode *> &PHIMap,
9149 unsigned Depth) {
9150 if (Depth > MaxConstantEvolvingDepth)
9151 return nullptr;
9153 // Otherwise, we can evaluate this instruction if all of its operands are
9154 // constant or derived from a PHI node themselves.
9155 PHINode *PHI = nullptr;
9156 for (Value *Op : UseInst->operands()) {
9157 if (isa<Constant>(Op)) continue;
9159 Instruction *OpInst = dyn_cast<Instruction>(Op);
9160 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9162 PHINode *P = dyn_cast<PHINode>(OpInst);
9163 if (!P)
9164 // If this operand is already visited, reuse the prior result.
9165 // We may have P != PHI if this is the deepest point at which the
9166 // inconsistent paths meet.
9167 P = PHIMap.lookup(OpInst);
9168 if (!P) {
9169 // Recurse and memoize the results, whether a phi is found or not.
9170 // This recursive call invalidates pointers into PHIMap.
9171 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9172 PHIMap[OpInst] = P;
9173 }
9174 if (!P)
9175 return nullptr; // Not evolving from PHI
9176 if (PHI && PHI != P)
9177 return nullptr; // Evolving from multiple different PHIs.
9178 PHI = P;
9179 }
9180 // This is an expression evolving from a constant PHI!
9181 return PHI;
9184 /// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9185 /// in the loop that V is derived from. We allow arbitrary operations along the
9186 /// way, but the operands of an operation must either be constants or a value
9187 /// derived from a constant PHI. If this expression does not fit with these
9188 /// constraints, return null.
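/// For example, given
///
///   %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
///   %iv.next = add i32 %iv, 1
///   %cmp = icmp ult i32 %iv.next, 100
///
/// the PHI %iv is returned for %cmp, since %cmp is derived from %iv purely
/// through foldable operations whose other operands are constants.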
9189 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9190 Instruction *I = dyn_cast<Instruction>(V);
9191 if (!I || !canConstantEvolve(I, L)) return nullptr;
9193 if (PHINode *PN = dyn_cast<PHINode>(I))
9194 return PN;
9196 // Record non-constant instructions contained by the loop.
9197 DenseMap<Instruction *, PHINode *> PHIMap;
9198 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9201 /// EvaluateExpression - Given an expression that passes the
9202 /// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9203 /// in the loop has the value PHIVal. If we can't fold this expression for some
9204 /// reason, return null.
9205 static Constant *EvaluateExpression(Value *V, const Loop *L,
9206 DenseMap<Instruction *, Constant *> &Vals,
9207 const DataLayout &DL,
9208 const TargetLibraryInfo *TLI) {
9209 // Convenient constant check, but redundant for recursive calls.
9210 if (Constant *C = dyn_cast<Constant>(V)) return C;
9211 Instruction *I = dyn_cast<Instruction>(V);
9212 if (!I) return nullptr;
9214 if (Constant *C = Vals.lookup(I)) return C;
9216 // An instruction inside the loop depends on a value outside the loop that we
9217 // weren't given a mapping for, or a value such as a call inside the loop.
9218 if (!canConstantEvolve(I, L)) return nullptr;
9220 // An unmapped PHI can be due to a branch or another loop inside this loop,
9221 // or due to this not being the initial iteration through a loop where we
9222 // couldn't compute the evolution of this particular PHI last time.
9223 if (isa<PHINode>(I)) return nullptr;
9225 std::vector<Constant*> Operands(I->getNumOperands());
9227 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9228 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9229 if (!Operand) {
9230 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9231 if (!Operands[i]) return nullptr;
9232 continue;
9233 }
9234 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9235 Vals[Operand] = C;
9236 if (!C) return nullptr;
9237 Operands[i] = C;
9240 return ConstantFoldInstOperands(I, Operands, DL, TLI);
9244 // If every incoming value to PN except the one for BB is a specific Constant,
9245 // return that, else return nullptr.
9246 static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9247 Constant *IncomingVal = nullptr;
9249 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9250 if (PN->getIncomingBlock(i) == BB)
9251 continue;
9253 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9254 if (!CurrentVal)
9255 return nullptr;
9257 if (IncomingVal != CurrentVal) {
9258 if (IncomingVal)
9259 return nullptr;
9260 IncomingVal = CurrentVal;
9261 }
9262 }
9264 return IncomingVal;
9267 /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9268 /// in the header of its containing loop, we know the loop executes a
9269 /// constant number of times, and the PHI node is just a recurrence
9270 /// involving constants, fold it.
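/// For example, for a header PHI starting at 0 and incremented by 3 each
/// iteration, a backedge-taken count of 4 makes the symbolic execution below
/// produce the constant exit value 12 after four updates of the recurrence.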
9271 Constant *
9272 ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9273 const APInt &BEs,
9274 const Loop *L) {
9275 auto I = ConstantEvolutionLoopExitValue.find(PN);
9276 if (I != ConstantEvolutionLoopExitValue.end())
9277 return I->second;
9279 if (BEs.ugt(MaxBruteForceIterations))
9280 return ConstantEvolutionLoopExitValue[PN] = nullptr; // Not going to evaluate it.
9282 Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
9284 DenseMap<Instruction *, Constant *> CurrentIterVals;
9285 BasicBlock *Header = L->getHeader();
9286 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9288 BasicBlock *Latch = L->getLoopLatch();
9289 if (!Latch)
9290 return nullptr;
9292 for (PHINode &PHI : Header->phis()) {
9293 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9294 CurrentIterVals[&PHI] = StartCST;
9296 if (!CurrentIterVals.count(PN))
9297 return RetVal = nullptr;
9299 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9301 // Execute the loop symbolically to determine the exit value.
9302 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9303 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9305 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9306 unsigned IterationNum = 0;
9307 const DataLayout &DL = getDataLayout();
9308 for (; ; ++IterationNum) {
9309 if (IterationNum == NumIterations)
9310 return RetVal = CurrentIterVals[PN]; // Got exit value!
9312 // Compute the value of the PHIs for the next iteration.
9313 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9314 DenseMap<Instruction *, Constant *> NextIterVals;
9315 Constant *NextPHI =
9316 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9317 if (!NextPHI)
9318 return nullptr; // Couldn't evaluate!
9319 NextIterVals[PN] = NextPHI;
9321 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9323 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9324 // cease to be able to evaluate one of them or if they stop evolving,
9325 // because that doesn't necessarily prevent us from computing PN.
9326 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9327 for (const auto &I : CurrentIterVals) {
9328 PHINode *PHI = dyn_cast<PHINode>(I.first);
9329 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9330 PHIsToCompute.emplace_back(PHI, I.second);
9332 // We use two distinct loops because EvaluateExpression may invalidate any
9333 // iterators into CurrentIterVals.
9334 for (const auto &I : PHIsToCompute) {
9335 PHINode *PHI = I.first;
9336 Constant *&NextPHI = NextIterVals[PHI];
9337 if (!NextPHI) { // Not already computed.
9338 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9339 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9341 if (NextPHI != I.second)
9342 StoppedEvolving = false;
9345 // If all entries in CurrentIterVals == NextIterVals then we can stop
9346 // iterating, the loop can't continue to change.
9347 if (StoppedEvolving)
9348 return RetVal = CurrentIterVals[PN];
9350 CurrentIterVals.swap(NextIterVals);
9354 const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9355 Value *Cond,
9356 bool ExitWhen) {
9357 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9358 if (!PN) return getCouldNotCompute();
9360 // If the loop is canonicalized, the PHI will have exactly two entries.
9361 // That's the only form we support here.
9362 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9364 DenseMap<Instruction *, Constant *> CurrentIterVals;
9365 BasicBlock *Header = L->getHeader();
9366 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9368 BasicBlock *Latch = L->getLoopLatch();
9369 assert(Latch && "Should follow from NumIncomingValues == 2!");
9371 for (PHINode &PHI : Header->phis()) {
9372 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9373 CurrentIterVals[&PHI] = StartCST;
9375 if (!CurrentIterVals.count(PN))
9376 return getCouldNotCompute();
9378 // Okay, we found a PHI node that defines the trip count of this loop. Execute
9379 // the loop symbolically to determine when the condition gets a value of
9380 // ExitWhen.
9381 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9382 const DataLayout &DL = getDataLayout();
9383 for (unsigned IterationNum = 0; IterationNum != MaxIterations; ++IterationNum) {
9384 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9385 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9387 // Couldn't symbolically evaluate.
9388 if (!CondVal) return getCouldNotCompute();
9390 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9391 ++NumBruteForceTripCountsComputed;
9392 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9395 // Update all the PHI nodes for the next iteration.
9396 DenseMap<Instruction *, Constant *> NextIterVals;
9398 // Create a list of which PHIs we need to compute. We want to do this before
9399 // calling EvaluateExpression on them because that may invalidate iterators
9400 // into CurrentIterVals.
9401 SmallVector<PHINode *, 8> PHIsToCompute;
9402 for (const auto &I : CurrentIterVals) {
9403 PHINode *PHI = dyn_cast<PHINode>(I.first);
9404 if (!PHI || PHI->getParent() != Header) continue;
9405 PHIsToCompute.push_back(PHI);
9407 for (PHINode *PHI : PHIsToCompute) {
9408 Constant *&NextPHI = NextIterVals[PHI];
9409 if (NextPHI) continue; // Already computed!
9411 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9412 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9414 CurrentIterVals.swap(NextIterVals);
9417 // Too many iterations were needed to evaluate.
9418 return getCouldNotCompute();
9421 const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9422 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9423 ValuesAtScopes[V];
9424 // Check to see if we've folded this expression at this loop before.
9425 for (auto &LS : Values)
9426 if (LS.first == L)
9427 return LS.second ? LS.second : V;
9429 Values.emplace_back(L, nullptr);
9431 // Otherwise compute it.
9432 const SCEV *C = computeSCEVAtScope(V, L);
9433 for (auto &LS : reverse(ValuesAtScopes[V]))
9434 if (LS.first == L) {
9435 LS.second = C;
9436 if (!isa<SCEVConstant>(C))
9437 ValuesAtScopesUsers[C].push_back({L, V});
9438 break;
9439 }
9441 return C;
9443 /// This builds up a Constant using the ConstantExpr interface. That way, we
9444 /// will return Constants for objects which aren't represented by a
9445 /// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9446 /// Returns NULL if the SCEV isn't representable as a Constant.
9447 static Constant *BuildConstantFromSCEV(const SCEV *V) {
9448 switch (V->getSCEVType()) {
9449 case scCouldNotCompute:
9450 case scAddRecExpr:
9451 return nullptr;
9452 case scConstant:
9453 return cast<SCEVConstant>(V)->getValue();
9454 case scUnknown:
9455 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9456 case scSignExtend: {
9457 const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
9458 if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand()))
9459 return ConstantExpr::getSExt(CastOp, SS->getType());
9460 return nullptr;
9461 }
9462 case scZeroExtend: {
9463 const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
9464 if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand()))
9465 return ConstantExpr::getZExt(CastOp, SZ->getType());
9466 return nullptr;
9467 }
9468 case scPtrToInt: {
9469 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9470 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9471 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9472 return nullptr;
9473 }
9475 case scTruncate: {
9476 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9477 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9478 return ConstantExpr::getTrunc(CastOp, ST->getType());
9479 return nullptr;
9480 }
9481 case scAddExpr: {
9482 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9483 Constant *C = nullptr;
9484 for (const SCEV *Op : SA->operands()) {
9485 Constant *OpC = BuildConstantFromSCEV(Op);
9486 if (!OpC)
9487 return nullptr;
9488 if (!C) {
9489 C = OpC;
9490 continue;
9491 }
9492 assert(!C->getType()->isPointerTy() &&
9493 "Can only have one pointer, and it must be last");
9494 if (auto *PT = dyn_cast<PointerType>(OpC->getType())) {
9495 // The offsets have been converted to bytes. We can add bytes to an
9496 // i8* by GEP with the byte count in the first index.
9497 Type *DestPtrTy =
9498 Type::getInt8PtrTy(PT->getContext(), PT->getAddressSpace());
9499 OpC = ConstantExpr::getBitCast(OpC, DestPtrTy);
9500 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9501 OpC, C);
9502 } else {
9503 C = ConstantExpr::getAdd(C, OpC);
9504 }
9505 }
9506 return C;
9507 }
9508 case scMulExpr: {
9509 const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
9510 Constant *C = nullptr;
9511 for (const SCEV *Op : SM->operands()) {
9512 assert(!Op->getType()->isPointerTy() && "Can't multiply pointers");
9513 Constant *OpC = BuildConstantFromSCEV(Op);
9514 if (!OpC)
9515 return nullptr;
9516 C = C ? ConstantExpr::getMul(C, OpC) : OpC;
9517 }
9518 return C;
9519 }
9520 case scUDivExpr: {
9521 const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V);
9522 if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS()))
9523 if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS()))
9524 if (LHS->getType() == RHS->getType())
9525 return ConstantExpr::getUDiv(LHS, RHS);
9526 return nullptr;
9527 }
9528 case scSMaxExpr:
9529 case scUMaxExpr:
9530 case scSMinExpr:
9531 case scUMinExpr:
9532 case scSequentialUMinExpr:
9533 return nullptr; // TODO: smax, umax, smin, umin, umin_seq.
9535 llvm_unreachable("Unknown SCEV kind!");
9538 const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
9539 if (isa<SCEVConstant>(V)) return V;
9541 // If this instruction is evolved from a constant-evolving PHI, compute the
9542 // exit value from the loop without using SCEVs.
9543 if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
9544 if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
9545 if (PHINode *PN = dyn_cast<PHINode>(I)) {
9546 const Loop *CurrLoop = this->LI[I->getParent()];
9547 // Looking for loop exit value.
9548 if (CurrLoop && CurrLoop->getParentLoop() == L &&
9549 PN->getParent() == CurrLoop->getHeader()) {
9550 // Okay, there is no closed form solution for the PHI node. Check
9551 // to see if the loop that contains it has a known backedge-taken
9552 // count. If so, we may be able to force computation of the exit
9554 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
9555 // This trivial case can show up in some degenerate cases where
9556 // the incoming IR has not yet been fully simplified.
9557 if (BackedgeTakenCount->isZero()) {
9558 Value *InitValue = nullptr;
9559 bool MultipleInitValues = false;
9560 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
9561 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
9562 if (!InitValue)
9563 InitValue = PN->getIncomingValue(i);
9564 else if (InitValue != PN->getIncomingValue(i)) {
9565 MultipleInitValues = true;
9566 break;
9570 if (!MultipleInitValues && InitValue)
9571 return getSCEV(InitValue);
9573 // Do we have a loop invariant value flowing around the backedge
9574 // for a loop which must execute the backedge?
9575 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
9576 isKnownPositive(BackedgeTakenCount) &&
9577 PN->getNumIncomingValues() == 2) {
9579 unsigned InLoopPred =
9580 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
9581 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
9582 if (CurrLoop->isLoopInvariant(BackedgeVal))
9583 return getSCEV(BackedgeVal);
9585 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
9586 // Okay, we know how many times the containing loop executes. If
9587 // this is a constant evolving PHI node, get the final value at
9588 // the specified iteration number.
9589 Constant *RV = getConstantEvolutionLoopExitValue(
9590 PN, BTCC->getAPInt(), CurrLoop);
9591 if (RV) return getSCEV(RV);
9595 // If there is a single-input Phi, evaluate it at our scope. If we can
9596 // prove that this replacement does not break LCSSA form, use new value.
9597 if (PN->getNumOperands() == 1) {
9598 const SCEV *Input = getSCEV(PN->getOperand(0));
9599 const SCEV *InputAtScope = getSCEVAtScope(Input, L);
9600 // TODO: We can generalize it using LI.replacementPreservesLCSSAForm,
9601 // for the simplest case just support constants.
9602 if (isa<SCEVConstant>(InputAtScope)) return InputAtScope;
9606 // Okay, this is an expression that we cannot symbolically evaluate
9607 // into a SCEV. Check to see if it's possible to symbolically evaluate
9608 // the arguments into constants, and if so, try to constant propagate the
9609 // result. This is particularly useful for computing loop exit values.
9610 if (CanConstantFold(I)) {
9611 SmallVector<Constant *, 4> Operands;
9612 bool MadeImprovement = false;
9613 for (Value *Op : I->operands()) {
9614 if (Constant *C = dyn_cast<Constant>(Op)) {
9615 Operands.push_back(C);
9616 continue;
9617 }
9619 // If any of the operands is non-constant and if they are
9620 // non-integer and non-pointer, don't even try to analyze them
9621 // with scev techniques.
9622 if (!isSCEVable(Op->getType()))
9623 return V;
9625 const SCEV *OrigV = getSCEV(Op);
9626 const SCEV *OpV = getSCEVAtScope(OrigV, L);
9627 MadeImprovement |= OrigV != OpV;
9629 Constant *C = BuildConstantFromSCEV(OpV);
9630 if (!C) return V;
9631 if (C->getType() != Op->getType())
9632 C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
9633 Op->getType(),
9634 false),
9635 C, Op->getType());
9636 Operands.push_back(C);
9639 // Check to see if getSCEVAtScope actually made an improvement.
9640 if (MadeImprovement) {
9641 Constant *C = nullptr;
9642 const DataLayout &DL = getDataLayout();
9643 C = ConstantFoldInstOperands(I, Operands, DL, &TLI);
9644 if (!C)
9645 return V;
9646 return getSCEV(C);
9650 // This is some other type of SCEVUnknown, just return it.
9651 return V;
9654 if (isa<SCEVCommutativeExpr>(V) || isa<SCEVSequentialMinMaxExpr>(V)) {
9655 const auto *Comm = cast<SCEVNAryExpr>(V);
9656 // Avoid performing the look-up in the common case where the specified
9657 // expression has no loop-variant portions.
9658 for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
9659 const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
9660 if (OpAtScope != Comm->getOperand(i)) {
9661 // Okay, at least one of these operands is loop variant but might be
9662 // foldable. Build a new instance of the folded commutative expression.
9663 SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(),
9664 Comm->op_begin()+i);
9665 NewOps.push_back(OpAtScope);
9667 for (++i; i != e; ++i) {
9668 OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
9669 NewOps.push_back(OpAtScope);
9671 if (isa<SCEVAddExpr>(Comm))
9672 return getAddExpr(NewOps, Comm->getNoWrapFlags());
9673 if (isa<SCEVMulExpr>(Comm))
9674 return getMulExpr(NewOps, Comm->getNoWrapFlags());
9675 if (isa<SCEVMinMaxExpr>(Comm))
9676 return getMinMaxExpr(Comm->getSCEVType(), NewOps);
9677 if (isa<SCEVSequentialMinMaxExpr>(Comm))
9678 return getSequentialMinMaxExpr(Comm->getSCEVType(), NewOps);
9679 llvm_unreachable("Unknown commutative / sequential min/max SCEV type!");
9682 // If we got here, all operands are loop invariant.
9683 return Comm;
9686 if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
9687 const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
9688 const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
9689 if (LHS == Div->getLHS() && RHS == Div->getRHS())
9690 return Div; // must be loop invariant
9691 return getUDivExpr(LHS, RHS);
9694 // If this is a loop recurrence for a loop that does not contain L, then we
9695 // are dealing with the final value computed by the loop.
9696 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
9697 // First, attempt to evaluate each operand.
9698 // Avoid performing the look-up in the common case where the specified
9699 // expression has no loop-variant portions.
9700 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
9701 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
9702 if (OpAtScope == AddRec->getOperand(i))
9703 continue;
9705 // Okay, at least one of these operands is loop variant but might be
9706 // foldable. Build a new instance of the folded commutative expression.
9707 SmallVector<const SCEV *, 8> NewOps(AddRec->op_begin(),
9708 AddRec->op_begin()+i);
9709 NewOps.push_back(OpAtScope);
9710 for (++i; i != e; ++i)
9711 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
9713 const SCEV *FoldedRec =
9714 getAddRecExpr(NewOps, AddRec->getLoop(),
9715 AddRec->getNoWrapFlags(SCEV::FlagNW));
9716 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
9717 // The addrec may be folded to a nonrecurrence, for example, if the
9718 // induction variable is multiplied by zero after constant folding. Go
9719 // ahead and return the folded value.
9720 if (!AddRec)
9721 return FoldedRec;
9722 break;
9725 // If the scope is outside the addrec's loop, evaluate it by using the
9726 // loop exit value of the addrec.
9727 if (!AddRec->getLoop()->contains(L)) {
9728 // To evaluate this recurrence, we need to know how many times the AddRec
9729 // loop iterates. Compute this now.
9730 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
9731 if (BackedgeTakenCount == getCouldNotCompute()) return AddRec;
9733 // Then, evaluate the AddRec.
9734 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
  if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
    if (Op == Cast->getOperand())
      return Cast; // must be loop invariant
    return getCastExpr(Cast->getSCEVType(), Op, Cast->getType());
  }

  llvm_unreachable("Unknown SCEV type!");
}

const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
  return getSCEVAtScope(getSCEV(V), L);
}
const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
    return stripInjectiveFunctions(ZExt->getOperand());
  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
    return stripInjectiveFunctions(SExt->getOperand());
  return S;
}
/// Finds the minimum unsigned root of the following equation:
///
///     A * X = B (mod N)
///
/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
/// A and B isn't important.
///
/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
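///
/// For example, with BW = 4 (so N = 16), A = 4 and B = 8, the equation is
/// 4 * X = 8 (mod 16): D = gcd(4, 16) = 4, B is divisible by D, the
/// multiplicative inverse of A/D = 1 modulo N/D = 4 is I = 1, and the
/// minimum unsigned root is (I * B mod N) / D = 8 / 4 = 2; indeed,
/// 4 * 2 = 8 (mod 16).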
static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
                                                ScalarEvolution &SE) {
  uint32_t BW = A.getBitWidth();
  assert(BW == SE.getTypeSizeInBits(B->getType()));
  assert(A != 0 && "A must be non-zero.");

  // 1. D = gcd(A, N)
  //
  // The gcd of A and N may have only one prime factor: 2. The number of
  // trailing zeros in A is its multiplicity.
  uint32_t Mult2 = A.countTrailingZeros();
  // D = 2^Mult2

  // 2. Check if B is divisible by D.
  //
  // B is divisible by D if and only if the multiplicity of prime factor 2 for B
  // is not less than multiplicity of this prime factor for D.
  if (SE.GetMinTrailingZeros(B) < Mult2)
    return SE.getCouldNotCompute();

  // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
  // modulo (N / D).
  //
  // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
  // (N / D) in general. The inverse itself always fits into BW bits, though,
  // so we immediately truncate it.
  APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D
  APInt Mod(BW + 1, 0);
  Mod.setBit(BW - Mult2); // Mod = N / D
  APInt I = AD.multiplicativeInverse(Mod).trunc(BW);

  // 4. Compute the minimum unsigned root of the equation:
  // I * (B / D) mod (N / D)
  // To simplify the computation, we factor out the divide by D:
  // (I * B mod N) / D
  const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
  return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
}
/// For a given quadratic addrec, generate coefficients of the corresponding
/// quadratic equation, multiplied by a common value to ensure that they are
/// integers.
/// The returned value is a tuple { A, B, C, M, BitWidth }, where
/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
/// were multiplied by, and BitWidth is the bit width of the original addrec
/// coefficients.
/// This function returns None if the addrec coefficients are not compile-
/// time constants.
static Optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
  assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
  const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
  const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
  const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
  LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
                    << *AddRec << '\n');

  // We currently can only solve this if the coefficients are constants.
  if (!LC || !MC || !NC) {
    LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
    return None;
  }
  APInt L = LC->getAPInt();
  APInt M = MC->getAPInt();
  APInt N = NC->getAPInt();
  assert(!N.isZero() && "This is not a quadratic addrec");

  unsigned BitWidth = LC->getAPInt().getBitWidth();
  unsigned NewWidth = BitWidth + 1;
  LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
                    << BitWidth << '\n');
  // The sign-extension (as opposed to a zero-extension) here matches the
  // extension used in SolveQuadraticEquationWrap (with the same motivation).
  N = N.sext(NewWidth);
  M = M.sext(NewWidth);
  L = L.sext(NewWidth);

  // The increments are M, M+N, M+2N, ..., so the accumulated values are
  //   L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
  //   L+M, L+2M+N, L+3M+3N, ...
  // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
  //
  // The equation Acc = 0 is then
  //   L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
  // In a quadratic form it becomes:
  //   N n^2 + (2M-N) n + 2L = 0.
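  //
  // For example, for the addrec {0,+,3,+,2} (L = 0, M = 3, N = 2) the
  // accumulated value is Acc = 3n + n(n-1) = n^2 + 2n, and the equation
  // below becomes 2n^2 + 4n + 0 = 0, i.e. Acc multiplied by the common
  // factor 2.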
  APInt A = N;
  APInt B = 2 * M - A;
  APInt C = 2 * L;
  APInt T = APInt(NewWidth, 2);
  LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
                    << "x + " << C << ", coeff bw: " << NewWidth
                    << ", multiplied by " << T << '\n');
  return std::make_tuple(A, B, C, T, BitWidth);
}
/// Helper function to compare optional APInts:
/// (a) if X and Y both exist, return min(X, Y),
/// (b) if neither X nor Y exist, return None,
/// (c) if exactly one of X and Y exists, return that value.
static Optional<APInt> MinOptional(Optional<APInt> X, Optional<APInt> Y) {
  if (X && Y) {
    unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
    APInt XW = X->sext(W);
    APInt YW = Y->sext(W);
    return XW.slt(YW) ? *X : *Y;
  }
  if (!X && !Y)
    return None;
  return X ? *X : *Y;
}
/// Helper function to truncate an optional APInt to a given BitWidth.
/// When solving addrec-related equations, it is preferable to return a value
/// that has the same bit width as the original addrec's coefficients. If the
/// solution fits in the original bit width, truncate it (except for i1).
/// Returning a value of a different bit width may inhibit some optimizations.
///
/// In general, a solution to a quadratic equation generated from an addrec
/// may require BW+1 bits, where BW is the bit width of the addrec's
/// coefficients. The reason is that the coefficients of the quadratic
/// equation are BW+1 bits wide (to avoid truncation when converting from
/// the addrec to the equation).
static Optional<APInt> TruncIfPossible(Optional<APInt> X, unsigned BitWidth) {
  if (!X)
    return None;
  unsigned W = X->getBitWidth();
  if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
    return X->trunc(BitWidth);
  return *X;
}
/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
/// iterations. The values L, M, N are assumed to be signed, and they
/// should all have the same bit widths.
/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
/// where BW is the bit width of the addrec's coefficients.
/// If the calculated value is a BW-bit integer (for BW > 1), it will be
/// returned as such, otherwise the bit width of the returned value may
/// be greater than BW.
///
/// This function returns None if
/// (a) the addrec coefficients are not constant, or
/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
///     like x^2 = 5, no integer solutions exist, in other cases an integer
///     solution may exist, but SolveQuadraticEquationWrap may fail to find it.
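///
/// For example, for the chrec {2,+,-2,+,2} we have
/// c(n) = 2 - 2n + n(n-1) = (n-1)(n-2), so the least n >= 0 with c(n) = 0
/// is n = 1.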
static Optional<APInt>
SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
  APInt A, B, C, M;
  unsigned BitWidth;
  auto T = GetQuadraticEquation(AddRec);
  if (!T)
    return None;
  std::tie(A, B, C, M, BitWidth) = *T;

  LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
  Optional<APInt> X =
      APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
  if (!X)
    return None;

  ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
  ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
  if (!V->isZero())
    return None;

  return TruncIfPossible(X, BitWidth);
}
/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
/// iterations. The values M, N are assumed to be signed, and they
/// should all have the same bit widths.
/// Find the least n such that c(n) does not belong to the given range,
/// while c(n-1) does.
///
/// This function returns None if
/// (a) the addrec coefficients are not constant, or
/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
///     bounds of the range.
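///
/// For example, the chrec {0,+,1,+,1} takes the values 0, 1, 3, 6, 10, ...
/// (c(n) = n(n+1)/2). For the range [0, 10) the result is n = 4, since
/// c(4) = 10 is the first value outside the range while c(3) = 6 is still
/// inside.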
static Optional<APInt>
SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
                          const ConstantRange &Range, ScalarEvolution &SE) {
  assert(AddRec->getOperand(0)->isZero() &&
         "Starting value of addrec should be 0");
  LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
                    << Range << ", addrec " << *AddRec << '\n');
  // This case is handled in getNumIterationsInRange. Here we can assume that
  // we start in the range.
  assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
         "Addrec's initial value should be in range");

  APInt A, B, C, M;
  unsigned BitWidth;
  auto T = GetQuadraticEquation(AddRec);
  if (!T)
    return None;
  // Be careful about the return value: there can be two reasons for not
  // returning an actual number. First, if no solutions to the equations
  // were found, and second, if the solutions don't leave the given range.
  // The first case means that the actual solution is "unknown", the second
  // means that it's known, but not valid. If the solution is unknown, we
  // cannot make any conclusions.
  // Return a pair: the optional solution and a flag indicating if the
  // solution was found.
  auto SolveForBoundary =
      [&](APInt Bound) -> std::pair<Optional<APInt>, bool> {
    // Solve for signed overflow and unsigned overflow, pick the lower
    // solution.
    LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
                      << Bound << " (before multiplying by " << M << ")\n");
    Bound *= M; // The quadratic equation multiplier.

    Optional<APInt> SO = None;
    if (BitWidth > 1) {
      LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
                           "signed overflow\n");
      SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
    }
    LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
                         "unsigned overflow\n");
    Optional<APInt> UO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound,
                                                              BitWidth + 1);

    auto LeavesRange = [&](const APInt &X) {
      ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
      ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
      if (Range.contains(V0->getValue()))
        return false;
      // X should be at least 1, so X-1 is non-negative.
      ConstantInt *C1 = ConstantInt::get(SE.getContext(), X - 1);
      ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
      if (Range.contains(V1->getValue()))
        return true;
      return false;
    };

    // If SolveQuadraticEquationWrap returns None, it means that there can
    // be a solution, but the function failed to find it. We cannot treat it
    // as "no solution".
    if (!SO || !UO)
      return { None, false };

    // Check the smaller value first to see if it leaves the range.
    // At this point, both SO and UO must have values.
    Optional<APInt> Min = MinOptional(SO, UO);
    if (LeavesRange(*Min))
      return { Min, true };
    Optional<APInt> Max = Min == SO ? UO : SO;
    if (LeavesRange(*Max))
      return { Max, true };

    // Solutions were found, but were eliminated, hence the "true".
    return { None, true };
  };
  std::tie(A, B, C, M, BitWidth) = *T;
  // Lower bound is inclusive, subtract 1 to represent the exiting value.
  APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
  APInt Upper = Range.getUpper().sext(A.getBitWidth());
  auto SL = SolveForBoundary(Lower);
  auto SU = SolveForBoundary(Upper);
  // If any of the solutions was unknown, no meaningful conclusions can
  // be made.
  if (!SL.second || !SU.second)
    return None;

  // Claim: The correct solution is not some value between Min and Max.
  //
  // Justification: Assuming that Min and Max are different values, one of
  // them is when the first signed overflow happens, the other is when the
  // first unsigned overflow happens. Crossing the range boundary is only
  // possible via an overflow (treating 0 as a special case of it, modeling
  // an overflow as crossing k*2^W for some k).
  //
  // The interesting case here is when Min was eliminated as an invalid
  // solution, but Max was not. The argument is that if there was another
  // overflow between Min and Max, it would also have been eliminated if
  // it was considered.
  //
  // For a given boundary, it is possible to have two overflows of the same
  // type (signed/unsigned) without having the other type in between: this
  // can happen when the vertex of the parabola is between the iterations
  // corresponding to the overflows. This is only possible when the two
  // overflows cross k*2^W for the same k. In such case, if the second one
  // left the range (and was the first one to do so), the first overflow
  // would have to enter the range, which would mean that either we had left
  // the range before or that we started outside of it. Both of these cases
  // are contradictions.
  //
  // Claim: In the case where SolveForBoundary returns None, the correct
  // solution is not some value between the Max for this boundary and the
  // Min of the other boundary.
  //
  // Justification: Assume that we had such Max_A and Min_B corresponding
  // to range boundaries A and B and such that Max_A < Min_B. If there was
  // a solution between Max_A and Min_B, it would have to be caused by an
  // overflow corresponding to either A or B. It cannot correspond to B,
  // since Min_B is the first occurrence of such an overflow. If it
  // corresponded to A, it would have to be either a signed or an unsigned
  // overflow that is larger than both eliminated overflows for A. But
  // between the eliminated overflows and this overflow, the values would
  // cover the entire value space, thus crossing the other boundary, which
  // is a contradiction.
  //
  return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
}
ScalarEvolution::ExitLimit
ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
                              bool AllowPredicates) {

  // This is only used for loops with a "x != y" exit test. The exit condition
  // is now expressed as a single expression, V = x-y. So the exit test is
  // effectively V != 0. We know and take advantage of the fact that this
  // expression is only used in a comparison-with-zero context.
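  //
  // For example, a loop with the exit test "i != n" is analyzed here with
  // V = i - n, and the backedge-taken count is the first iteration at which
  // V becomes zero.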
  SmallPtrSet<const SCEVPredicate *, 4> Predicates;

  // If the value is a constant:
  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
    // If the value is already zero, the branch will execute zero times.
    if (C->getValue()->isZero()) return C;
    return getCouldNotCompute(); // Otherwise it will loop infinitely.
  }
  const SCEVAddRecExpr *AddRec =
      dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));

  if (!AddRec && AllowPredicates)
    // Try to make this an AddRec using runtime tests, in the first X
    // iterations of this loop, where X is the SCEV expression found by the
    // algorithm below.
    AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);

  if (!AddRec || AddRec->getLoop() != L)
    return getCouldNotCompute();
  // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
  // the quadratic equation to solve it.
  if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
    // We can only use this value if the chrec ends up with an exact zero
    // value at this index. When solving for "X*X != 5", for example, we
    // should not accept a root of 2.
    if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
      const auto *R = cast<SCEVConstant>(getConstant(*S));
      return ExitLimit(R, R, false, Predicates);
    }
    return getCouldNotCompute();
  }

  // Otherwise we can only handle this if it is affine.
  if (!AddRec->isAffine())
    return getCouldNotCompute();
  // If this is an affine expression, the execution count of this branch is
  // the minimum unsigned root of the following equation:
  //
  //     Start + Step*N = 0 (mod 2^BW)
  //
  // equivalent to:
  //
  //     Step*N = -Start (mod 2^BW)
  //
  // where BW is the common bit width of Start and Step.
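  //
  // For example, with Start = 10 and Step = -1 on i8 this reads
  // (-1)*N = -10 (mod 256), whose minimum unsigned root is N = 10.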
  // Get the initial value for the loop.
  const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
  const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());

  // For now we handle only constant steps.
  //
  // TODO: Handle a nonconstant Step given AddRec<NUW>. If the
  // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
  // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step.
  // We have not yet seen any such cases.
  const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
  if (!StepC || StepC->getValue()->isZero())
    return getCouldNotCompute();

  // For positive steps (counting up until unsigned overflow):
  //   N = -Start/Step (as unsigned)
  // For negative steps (counting down to zero):
  //   N = Start/-Step (as unsigned)
  // First compute the unsigned distance from zero in the direction of Step.
  bool CountDown = StepC->getAPInt().isNegative();
  const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
  // Handle unitary steps, which cannot wraparound.
  // 1*N = -Start; -1*N = Start (mod 2^BW), so:
  //   N = Distance (as unsigned)
  if (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne()) {
    APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
    MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));

    // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is
    // rotated, we end up with a loop whose backedge-taken count is n - 1.
    // Detect this case, and see if we can improve the bound.
    //
    // Explicitly handling this here is necessary because getUnsignedRange
    // isn't context-sensitive; it doesn't know that we only care about the
    // range inside the loop.
    const SCEV *Zero = getZero(Distance->getType());
    const SCEV *One = getOne(Distance->getType());
    const SCEV *DistancePlusOne = getAddExpr(Distance, One);
    if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
      // If Distance + 1 doesn't overflow, we can compute the maximum distance
      // as "unsigned_max(Distance + 1) - 1".
      ConstantRange CR = getUnsignedRange(DistancePlusOne);
      MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
    }
    return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates);
  }
  // If the condition controls the loop exit (the loop exits only if the
  // expression is true) and the addition is no-wrap, we can use unsigned
  // divide to compute the backedge count. In this case, the step may not
  // divide the distance, but we don't care because if the condition is
  // "missed" the loop will have undefined behavior due to wrapping.
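  //
  // For example, with Distance = 8 and Step = 3, the division below yields
  // 8 /u 3 = 2 even though the IV never equals zero exactly; that is only
  // acceptable because such a loop would wrap and therefore has undefined
  // behavior.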
  if (ControlsExit && AddRec->hasNoSelfWrap() &&
      loopHasNoAbnormalExits(AddRec->getLoop())) {
    const SCEV *Exact =
        getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
    const SCEV *Max = getCouldNotCompute();
    if (Exact != getCouldNotCompute()) {
      APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
      Max = getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
    }
    return ExitLimit(Exact, Max, false, Predicates);
  }
  // Solve the general equation.
  const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(),
                                               getNegativeSCEV(Start), *this);

  const SCEV *M = E;
  if (E != getCouldNotCompute()) {
    APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
    M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
  }
  return ExitLimit(E, M, false, Predicates);
}
ScalarEvolution::ExitLimit
ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
  // Loops that look like: while (X == 0) are very strange indeed. We don't
  // handle them yet except for the trivial case. This could be expanded in the
  // future as needed.

  // If the value is a constant, check to see if it is known to be non-zero
  // already. If so, the backedge will execute zero times.
  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
    if (!C->getValue()->isZero())
      return getZero(C->getType());
    return getCouldNotCompute(); // Otherwise it will loop infinitely.
  }

  // We could implement others, but I really doubt anyone writes loops like
  // this, and if they did, they would already be constant folded.
  return getCouldNotCompute();
}
std::pair<const BasicBlock *, const BasicBlock *>
ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
    const {
  // If the block has a unique predecessor, then there is no path from the
  // predecessor to the block that does not go through the direct edge
  // from the predecessor to the block.
  if (const BasicBlock *Pred = BB->getSinglePredecessor())
    return {Pred, BB};

  // A loop's header is defined to be a block that dominates the loop.
  // If the header has a unique predecessor outside the loop, it must be
  // a block that has exactly one successor that can reach the loop.
  if (const Loop *L = LI.getLoopFor(BB))
    return {L->getLoopPredecessor(), L->getHeader()};

  return {nullptr, nullptr};
}
/// SCEV structural equivalence is usually sufficient for testing whether two
/// expressions are equal, however for the purposes of looking for a condition
/// guarding a loop, it can be useful to be a little more general, since a
/// front-end may have replicated the controlling expression.
static bool HasSameValue(const SCEV *A, const SCEV *B) {
  // Quick check to see if they are the same SCEV.
  if (A == B) return true;

  auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
    // Not all instructions that are "identical" compute the same value. For
    // instance, two distinct alloca instructions allocating the same type are
    // identical and do not read memory; but compute distinct values.
    return A->isIdenticalTo(B) &&
           (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
  };

  // Otherwise, if they're both SCEVUnknown, it's possible that they hold
  // two different instructions with the same value. Check for this case.
  if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
    if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
      if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
        if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
          if (ComputesEqualValues(AI, BI))
            return true;

  // Otherwise assume they may have a different value.
  return false;
}
bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
                                           const SCEV *&LHS, const SCEV *&RHS,
                                           unsigned Depth,
                                           bool ControllingFiniteLoop) {
  bool Changed = false;
  // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
  // '0 != 0'.
  auto TrivialCase = [&](bool TriviallyTrue) {
    LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
    Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
    return true;
  };
  // If we hit the max recursion limit bail out.
  if (Depth >= 3)
    return false;
  // Canonicalize a constant to the right side.
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
    // Check for both operands constant.
    if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
      if (ConstantExpr::getICmp(Pred,
                                LHSC->getValue(),
                                RHSC->getValue())->isNullValue())
        return TrivialCase(false);
      return TrivialCase(true);
    }
    // Otherwise swap the operands to put the constant on the right.
    std::swap(LHS, RHS);
    Pred = ICmpInst::getSwappedPredicate(Pred);
    Changed = true;
  }
  // If we're comparing an addrec with a value which is loop-invariant in the
  // addrec's loop, put the addrec on the left. Also make a dominance check,
  // as both operands could be addrecs loop-invariant in each other's loop.
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
    const Loop *L = AR->getLoop();
    if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
      std::swap(LHS, RHS);
      Pred = ICmpInst::getSwappedPredicate(Pred);
      Changed = true;
    }
  }
  // If there's a constant operand, canonicalize comparisons with boundary
  // cases, and canonicalize *-or-equal comparisons to regular comparisons.
  if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
    const APInt &RA = RC->getAPInt();

    bool SimplifiedByConstantRange = false;

    if (!ICmpInst::isEquality(Pred)) {
      ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
      if (ExactCR.isFullSet())
        return TrivialCase(true);
      else if (ExactCR.isEmptySet())
        return TrivialCase(false);

      APInt NewRHS;
      CmpInst::Predicate NewPred;
      if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
          ICmpInst::isEquality(NewPred)) {
        // We were able to convert an inequality to an equality.
        Pred = NewPred;
        RHS = getConstant(NewRHS);
        Changed = SimplifiedByConstantRange = true;
      }
    }

    if (!SimplifiedByConstantRange) {
      switch (Pred) {
      default:
        break;
      case ICmpInst::ICMP_EQ:
      case ICmpInst::ICMP_NE:
        // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
        if (RA.isZero())
          if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS))
            if (const SCEVMulExpr *ME =
                    dyn_cast<SCEVMulExpr>(AE->getOperand(0)))
              if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 &&
                  ME->getOperand(0)->isAllOnesValue()) {
                RHS = AE->getOperand(1);
                LHS = ME->getOperand(1);
                Changed = true;
              }
        break;

      // The "Should have been caught earlier!" messages refer to the fact
      // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
      // should have fired on the corresponding cases, and canonicalized the
      // check to trivial case.
      case ICmpInst::ICMP_UGE:
        assert(!RA.isMinValue() && "Should have been caught earlier!");
        Pred = ICmpInst::ICMP_UGT;
        RHS = getConstant(RA - 1);
        Changed = true;
        break;
      case ICmpInst::ICMP_ULE:
        assert(!RA.isMaxValue() && "Should have been caught earlier!");
        Pred = ICmpInst::ICMP_ULT;
        RHS = getConstant(RA + 1);
        Changed = true;
        break;
      case ICmpInst::ICMP_SGE:
        assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
        Pred = ICmpInst::ICMP_SGT;
        RHS = getConstant(RA - 1);
        Changed = true;
        break;
      case ICmpInst::ICMP_SLE:
        assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
        Pred = ICmpInst::ICMP_SLT;
        RHS = getConstant(RA + 1);
        Changed = true;
        break;
      }
    }
  }
  // Check for obvious equality.
  if (HasSameValue(LHS, RHS)) {
    if (ICmpInst::isTrueWhenEqual(Pred))
      return TrivialCase(true);
    if (ICmpInst::isFalseWhenEqual(Pred))
      return TrivialCase(false);
  }
  // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
  // adding or subtracting 1 from one of the operands. This can be done for
  // one of two reasons:
  // 1) The range of the RHS does not include the (signed/unsigned) boundaries
  // 2) The loop is finite, with this comparison controlling the exit. Since
  //    the loop is finite, the bound cannot include the corresponding boundary
  //    (otherwise it would loop forever).
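  // For example, "x s<= 5" becomes "x s< 6" and "x u>= 7" becomes "x u> 6";
  // when the RHS may be the boundary value itself, the adjustment is
  // attempted on the LHS instead.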
  switch (Pred) {
  case ICmpInst::ICMP_SLE:
    if (ControllingFiniteLoop || !getSignedRangeMax(RHS).isMaxSignedValue()) {
      RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
                       SCEV::FlagNSW);
      Pred = ICmpInst::ICMP_SLT;
      Changed = true;
    } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
                       SCEV::FlagNSW);
      Pred = ICmpInst::ICMP_SLT;
      Changed = true;
    }
    break;
  case ICmpInst::ICMP_SGE:
    if (ControllingFiniteLoop || !getSignedRangeMin(RHS).isMinSignedValue()) {
      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
                       SCEV::FlagNSW);
      Pred = ICmpInst::ICMP_SGT;
      Changed = true;
    } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
      LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
                       SCEV::FlagNSW);
      Pred = ICmpInst::ICMP_SGT;
      Changed = true;
    }
    break;
  case ICmpInst::ICMP_ULE:
    if (ControllingFiniteLoop || !getUnsignedRangeMax(RHS).isMaxValue()) {
      RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
                       SCEV::FlagNUW);
      Pred = ICmpInst::ICMP_ULT;
      Changed = true;
    } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
      Pred = ICmpInst::ICMP_ULT;
      Changed = true;
    }
    break;
  case ICmpInst::ICMP_UGE:
    if (ControllingFiniteLoop || !getUnsignedRangeMin(RHS).isMinValue()) {
      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
      Pred = ICmpInst::ICMP_UGT;
      Changed = true;
    } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
      LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
                       SCEV::FlagNUW);
      Pred = ICmpInst::ICMP_UGT;
      Changed = true;
    }
    break;
  default:
    break;
  }

  // TODO: More simplifications are possible here.

  // Recursively simplify until we either hit a recursion limit or nothing
  // changes.
  if (Changed)
    return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1,
                                ControllingFiniteLoop);

  return Changed;
}
bool ScalarEvolution::isKnownNegative(const SCEV *S) {
  return getSignedRangeMax(S).isNegative();
}

bool ScalarEvolution::isKnownPositive(const SCEV *S) {
  return getSignedRangeMin(S).isStrictlyPositive();
}

bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
  return !getSignedRangeMin(S).isNegative();
}

bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
  return !getSignedRangeMax(S).isStrictlyPositive();
}

bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
  return getUnsignedRangeMin(S) != 0;
}
std::pair<const SCEV *, const SCEV *>
ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
  // Compute SCEV on entry of loop L.
  const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
  if (Start == getCouldNotCompute())
    return { Start, Start };
  // Compute post increment SCEV for loop L.
  const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
  assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
  return { Start, PostInc };
}
bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred,
                                          const SCEV *LHS, const SCEV *RHS) {
  // First collect all loops.
  SmallPtrSet<const Loop *, 8> LoopsUsed;
  getUsedLoops(LHS, LoopsUsed);
  getUsedLoops(RHS, LoopsUsed);

  if (LoopsUsed.empty())
    return false;

  // Domination relationship must be a linear order on collected loops.
#ifndef NDEBUG
  for (auto *L1 : LoopsUsed)
    for (auto *L2 : LoopsUsed)
      assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
              DT.dominates(L2->getHeader(), L1->getHeader())) &&
             "Domination relationship is not a linear order");
#endif

  const Loop *MDL =
      *std::max_element(LoopsUsed.begin(), LoopsUsed.end(),
                        [&](const Loop *L1, const Loop *L2) {
                          return DT.properlyDominates(L1->getHeader(),
                                                      L2->getHeader());
                        });

  // Get init and post increment value for LHS.
  auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
  // if LHS contains unknown non-invariant SCEV then bail out.
  if (SplitLHS.first == getCouldNotCompute())
    return false;
  assert(SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
  // Get init and post increment value for RHS.
  auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
  // if RHS contains unknown non-invariant SCEV then bail out.
  if (SplitRHS.first == getCouldNotCompute())
    return false;
  assert(SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
  // It is possible that init SCEV contains an invariant load but it does
  // not dominate MDL and is not available at MDL loop entry, so we should
  // check it here.
  if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
      !isAvailableAtLoopEntry(SplitRHS.first, MDL))
    return false;

  // It seems backedge guard check is faster than entry one so in some cases
  // it can speed up whole estimation by short circuit.
  return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
                                     SplitRHS.second) &&
         isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
}
bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
                                       const SCEV *LHS, const SCEV *RHS) {
  // Canonicalize the inputs first.
  (void)SimplifyICmpOperands(Pred, LHS, RHS);

  if (isKnownViaInduction(Pred, LHS, RHS))
    return true;

  if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
    return true;

  // Otherwise see what can be done with some simple reasoning.
  return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
}
Optional<bool> ScalarEvolution::evaluatePredicate(ICmpInst::Predicate Pred,
                                                  const SCEV *LHS,
                                                  const SCEV *RHS) {
  if (isKnownPredicate(Pred, LHS, RHS))
    return true;
  else if (isKnownPredicate(ICmpInst::getInversePredicate(Pred), LHS, RHS))
    return false;
  return None;
}
bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred,
                                         const SCEV *LHS, const SCEV *RHS,
                                         const Instruction *CtxI) {
  // TODO: Analyze guards and assumes from Context's block.
  return isKnownPredicate(Pred, LHS, RHS) ||
         isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
}
Optional<bool> ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred,
                                                    const SCEV *LHS,
                                                    const SCEV *RHS,
                                                    const Instruction *CtxI) {
  Optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
  if (KnownWithoutContext)
    return KnownWithoutContext;

  if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
    return true;
  else if (isBasicBlockEntryGuardedByCond(CtxI->getParent(),
                                          ICmpInst::getInversePredicate(Pred),
                                          LHS, RHS))
    return false;
  return None;
}
bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
                                              const SCEVAddRecExpr *LHS,
                                              const SCEV *RHS) {
  const Loop *L = LHS->getLoop();
  return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
         isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
}
Optional<ScalarEvolution::MonotonicPredicateType>
ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
                                           ICmpInst::Predicate Pred) {
  auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);

#ifndef NDEBUG
  // Verify an invariant: inverting the predicate should turn a monotonically
  // increasing change to a monotonically decreasing one, and vice versa.
  if (Result) {
    auto ResultSwapped =
        getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));

    assert(ResultSwapped && "should be able to analyze both!");
    assert(ResultSwapped.getValue() != Result.getValue() &&
           "monotonicity should flip as we flip the predicate");
  }
#endif

  return Result;
}
Optional<ScalarEvolution::MonotonicPredicateType>
ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
                                               ICmpInst::Predicate Pred) {
  // A zero step value for LHS means the induction variable is essentially a
  // loop invariant value. We don't really depend on the predicate actually
  // flipping from false to true (for increasing predicates, and the other way
  // around for decreasing predicates), all we care about is that *if* the
  // predicate changes then it only changes from false to true.
  //
  // A zero step value in itself is not very useful, but there may be places
  // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
  // as general as possible.
  //
  // Only handle LE/LT/GE/GT predicates.
  if (!ICmpInst::isRelational(Pred))
    return None;

  bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
  assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
         "Should be greater or less!");

  // Check that AR does not wrap.
  if (ICmpInst::isUnsigned(Pred)) {
    if (!LHS->hasNoUnsignedWrap())
      return None;
    return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
  }
  assert(ICmpInst::isSigned(Pred) &&
         "Relational predicate is either signed or unsigned!");
  if (!LHS->hasNoSignedWrap())
    return None;

  const SCEV *Step = LHS->getStepRecurrence(*this);

  if (isKnownNonNegative(Step))
    return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;

  if (isKnownNonPositive(Step))
    return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;

  return None;
}
Optional<ScalarEvolution::LoopInvariantPredicate>
ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred,
                                           const SCEV *LHS, const SCEV *RHS,
                                           const Loop *L) {
  // If there is a loop-invariant, force it into the RHS, otherwise bail out.
  if (!isLoopInvariant(RHS, L)) {
    if (!isLoopInvariant(LHS, L))
      return None;

    std::swap(LHS, RHS);
    Pred = ICmpInst::getSwappedPredicate(Pred);
  }

  const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!ArLHS || ArLHS->getLoop() != L)
    return None;

  auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
  if (!MonotonicType)
    return None;
  // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
  // true as the loop iterates, and the backedge is control dependent on
  // "ArLHS `Pred` RHS" == true then we can reason as follows:
  //
  //   * if the predicate was false in the first iteration then the predicate
  //     is never evaluated again, since the loop exits without taking the
  //     backedge.
  //   * if the predicate was true in the first iteration then it will
  //     continue to be true for all future iterations since it is
  //     monotonically increasing.
  //
  // For both the above possibilities, we can replace the loop varying
  // predicate with its value on the first iteration of the loop (which is
  // loop invariant).
  //
  // A similar reasoning applies for a monotonically decreasing predicate, by
  // replacing true with false and false with true in the above two bullets.
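  //
  // For example, take an addrec i = {S,+,-1}<nsw> and the predicate
  // "i s< C". Since the step is negative, the predicate can only change
  // from false to true as the loop iterates, so if the backedge is taken
  // only when "i s< C" holds, the predicate can be replaced by its value
  // "S s< C" on the first iteration.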
  bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
  auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);

  if (!isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
    return None;

  return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(), RHS);
}
Optional<ScalarEvolution::LoopInvariantPredicate>
ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
    const Instruction *CtxI, const SCEV *MaxIter) {
  // Try to prove the following set of facts:
  // - The predicate is monotonic in the iteration space.
  // - If the check does not fail on the 1st iteration:
  //   - No overflow will happen during first MaxIter iterations;
  //   - It will not fail on the MaxIter'th iteration.
  // If the check does fail on the 1st iteration, we leave the loop and no
  // other checks matter.
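  //
  // For example (illustrative), for AR = {0,+,1}, Pred = u< and a
  // loop-invariant bound N: if the backedge still guarantees
  // "MaxIter u< N" on the last iteration, and "0 u<= MaxIter" rules out
  // wrapping, then the exit condition can be evaluated as "0 u< N" during
  // the first MaxIter iterations.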
  // If there is a loop-invariant, force it into the RHS, otherwise bail out.
  if (!isLoopInvariant(RHS, L)) {
    if (!isLoopInvariant(LHS, L))
      return None;

    std::swap(LHS, RHS);
    Pred = ICmpInst::getSwappedPredicate(Pred);
  }

  auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!AR || AR->getLoop() != L)
    return None;

  // The predicate must be relational (i.e. <, <=, >=, >).
  if (!ICmpInst::isRelational(Pred))
    return None;

  // TODO: Support steps other than +/- 1.
  const SCEV *Step = AR->getStepRecurrence(*this);
  auto *One = getOne(Step->getType());
  auto *MinusOne = getNegativeSCEV(One);
  if (Step != One && Step != MinusOne)
    return None;

  // Type mismatch here means that MaxIter is potentially larger than max
  // unsigned value in start type, which means we cannot prove no wrap for the
  // indvar.
  if (AR->getType() != MaxIter->getType())
    return None;

  // Value of IV on suggested last iteration.
  const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
  // Does it still meet the requirement?
  if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
    return None;

  // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
  // not exceed max unsigned value of this type), this effectively proves
  // that there is no wrap during the iteration. To prove that there is no
  // signed/unsigned wrap, we need to check that
  // Start <= Last for step = 1 or Start >= Last for step = -1.
  ICmpInst::Predicate NoOverflowPred =
      CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
  if (Step == MinusOne)
    NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred);
  const SCEV *Start = AR->getStart();
  if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
    return None;

  // Everything is fine.
  return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
}
bool ScalarEvolution::isKnownPredicateViaConstantRanges(
    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
  if (HasSameValue(LHS, RHS))
    return ICmpInst::isTrueWhenEqual(Pred);

  // This code is split out from isKnownPredicate because it is called from
  // within isLoopEntryGuardedByCond.

  auto CheckRanges = [&](const ConstantRange &RangeLHS,
                         const ConstantRange &RangeRHS) {
    return RangeLHS.icmp(Pred, RangeRHS);
  };

  // The check at the top of the function catches the case where the values are
  // known to be equal.
  if (Pred == CmpInst::ICMP_EQ)
    return false;

  if (Pred == CmpInst::ICMP_NE) {
    auto SL = getSignedRange(LHS);
    auto SR = getSignedRange(RHS);
    if (CheckRanges(SL, SR))
      return true;
    auto UL = getUnsignedRange(LHS);
    auto UR = getUnsignedRange(RHS);
    if (CheckRanges(UL, UR))
      return true;
    auto *Diff = getMinusSCEV(LHS, RHS);
    return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
  }

  if (CmpInst::isSigned(Pred)) {
    auto SL = getSignedRange(LHS);
    auto SR = getSignedRange(RHS);
    return CheckRanges(SL, SR);
  }

  auto UL = getUnsignedRange(LHS);
  auto UR = getUnsignedRange(RHS);
  return CheckRanges(UL, UR);
}
bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
                                                    const SCEV *LHS,
                                                    const SCEV *RHS) {
  // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
  // C1 and C2 are constant integers. If either X or Y are not add expressions,
  // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
  // OutC1 and OutC2.
  auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
                                      APInt &OutC1, APInt &OutC2,
                                      SCEV::NoWrapFlags ExpectedFlags) {
    const SCEV *XNonConstOp, *XConstOp;
    const SCEV *YNonConstOp, *YConstOp;
    SCEV::NoWrapFlags XFlagsPresent;
    SCEV::NoWrapFlags YFlagsPresent;

    if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
      XConstOp = getZero(X->getType());
      XNonConstOp = X;
      XFlagsPresent = ExpectedFlags;
    }
    if (!isa<SCEVConstant>(XConstOp) ||
        (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
      return false;

    if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
      YConstOp = getZero(Y->getType());
      YNonConstOp = Y;
      YFlagsPresent = ExpectedFlags;
    }

    if (!isa<SCEVConstant>(YConstOp) ||
        (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
      return false;

    if (YNonConstOp != XNonConstOp)
      return false;

    OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
    OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();

    return true;
  };

  APInt C1;
  APInt C2;
  switch (Pred) {
  default:
    break;

  case ICmpInst::ICMP_SGE:
    std::swap(LHS, RHS);
    LLVM_FALLTHROUGH;
  case ICmpInst::ICMP_SLE:
    // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
    if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
      return true;
    break;

  case ICmpInst::ICMP_SGT:
    std::swap(LHS, RHS);
    LLVM_FALLTHROUGH;
  case ICmpInst::ICMP_SLT:
    // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
    if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
      return true;
    break;

  case ICmpInst::ICMP_UGE:
    std::swap(LHS, RHS);
    LLVM_FALLTHROUGH;
  case ICmpInst::ICMP_ULE:
    // (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2.
    if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ule(C2))
      return true;
    break;

  case ICmpInst::ICMP_UGT:
    std::swap(LHS, RHS);
    LLVM_FALLTHROUGH;
  case ICmpInst::ICMP_ULT:
    // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
    if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ult(C2))
      return true;
    break;
  }

  return false;
}
bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
                                                   const SCEV *LHS,
                                                   const SCEV *RHS) {
  if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
    return false;

  // Allowing arbitrary number of activations of isKnownPredicateViaSplitting
  // on the stack can result in exponential time complexity.
  SaveAndRestore<bool> Restore(ProvingSplitPredicate, true);

  // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
  //
  // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
  // isKnownPredicate. isKnownPredicate is more powerful, but also more
  // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
  // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
  // use isKnownPredicate later if needed.
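  //
  // For example, to prove I u< 100 it suffices that 100 s>= 0 (trivially
  // true), I s>= 0 and I s< 100, because the unsigned and signed orderings
  // agree on non-negative values.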
  return isKnownNonNegative(RHS) &&
         isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
         isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
}
bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB,
                                        ICmpInst::Predicate Pred,
                                        const SCEV *LHS, const SCEV *RHS) {
  // No need to even try if we know the module has no guards.
  if (!HasGuards)
    return false;

  return any_of(*BB, [&](const Instruction &I) {
    using namespace llvm::PatternMatch;

    Value *Condition;
    return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
                         m_Value(Condition))) &&
           isImpliedCond(Pred, LHS, RHS, Condition, false);
  });
}
/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
/// protected by a conditional between LHS and RHS. This is used to
/// eliminate casts.
bool
ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
                                             ICmpInst::Predicate Pred,
                                             const SCEV *LHS, const SCEV *RHS) {
  // Interpret a null as meaning no loop, where there is obviously no guard
  // (interprocedural conditions notwithstanding).
  if (!L) return true;

  if (VerifyIR)
    assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
           "This cannot be done on broken IR!");

  if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
    return true;

  BasicBlock *Latch = L->getLoopLatch();
  if (!Latch)
    return false;

  BranchInst *LoopContinuePredicate =
      dyn_cast<BranchInst>(Latch->getTerminator());
  if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
      isImpliedCond(Pred, LHS, RHS,
                    LoopContinuePredicate->getCondition(),
                    LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
    return true;

  // We don't want more than one activation of the following loops on the stack
  // -- that can lead to O(n!) time complexity.
  if (WalkingBEDominatingConds)
    return false;

  SaveAndRestore<bool> ClearOnExit(WalkingBEDominatingConds, true);

  // See if we can exploit a trip count to prove the predicate.
  const auto &BETakenInfo = getBackedgeTakenInfo(L);
  const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
  if (LatchBECount != getCouldNotCompute()) {
    // We know that Latch branches back to the loop header exactly
    // LatchBECount times. This means the backedge condition at Latch is
    // equivalent to "{0,+,1} u< LatchBECount".
    Type *Ty = LatchBECount->getType();
    auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
    const SCEV *LoopCounter =
        getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
    if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
                      LatchBECount))
      return true;
  }

  // Check conditions due to any @llvm.assume intrinsics.
  for (auto &AssumeVH : AC.assumptions()) {
    if (!AssumeVH)
      continue;
    auto *CI = cast<CallInst>(AssumeVH);
    if (!DT.dominates(CI, Latch->getTerminator()))
      continue;

    if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
      return true;
  }

  // If the loop is not reachable from the entry block, we risk running into an
  // infinite loop as we walk up into the dom tree. These loops do not matter
  // anyway, so we just return a conservative answer when we see them.
  if (!DT.isReachableFromEntry(L->getHeader()))
    return false;

  if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
    return true;

  for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
       DTN != HeaderDTN; DTN = DTN->getIDom()) {
    assert(DTN && "should reach the loop header before reaching the root!");

    BasicBlock *BB = DTN->getBlock();
    if (isImpliedViaGuard(BB, Pred, LHS, RHS))
      return true;

    BasicBlock *PBB = BB->getSinglePredecessor();
    if (!PBB)
      continue;

    BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
    if (!ContinuePredicate || !ContinuePredicate->isConditional())
      continue;

    Value *Condition = ContinuePredicate->getCondition();

    // If we have an edge `E` within the loop body that dominates the only
    // latch, the condition guarding `E` also guards the backedge. This
    // reasoning works only for loops with a single latch.
    BasicBlockEdge DominatingEdge(PBB, BB);
    if (DominatingEdge.isSingleEdge()) {
      // We're constructively (and conservatively) enumerating edges within the
      // loop body that dominate the latch. The dominator tree better agree
      // with us on this:
      assert(DT.dominates(DominatingEdge, Latch) && "should be!");
      if (isImpliedCond(Pred, LHS, RHS, Condition,
                        BB != ContinuePredicate->getSuccessor(0)))
        return true;
    }
  }

  return false;
}
bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
                                                     ICmpInst::Predicate Pred,
                                                     const SCEV *LHS,
                                                     const SCEV *RHS) {
  if (VerifyIR)
    assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
           "This cannot be done on broken IR!");

  // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
  // the facts (a >= b && a != b) separately. A typical situation is when the
  // non-strict comparison is known from ranges and non-equality is known from
  // dominating predicates. If we are proving strict comparison, we always try
  // to prove non-equality and non-strict comparison separately.
  auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
  const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
  bool ProvedNonStrictComparison = false;
  bool ProvedNonEquality = false;

  auto SplitAndProve =
      [&](std::function<bool(ICmpInst::Predicate)> Fn) -> bool {
    if (!ProvedNonStrictComparison)
      ProvedNonStrictComparison = Fn(NonStrictPredicate);
    if (!ProvedNonEquality)
      ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
    if (ProvedNonStrictComparison && ProvedNonEquality)
      return true;
    return false;
  };

  if (ProvingStrictComparison) {
    auto ProofFn = [&](ICmpInst::Predicate P) {
      return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
    };
    if (SplitAndProve(ProofFn))
      return true;
  }

  // Try to prove (Pred, LHS, RHS) using isImpliedViaGuard.
  auto ProveViaGuard = [&](const BasicBlock *Block) {
    if (isImpliedViaGuard(Block, Pred, LHS, RHS))
      return true;
    if (ProvingStrictComparison) {
      auto ProofFn = [&](ICmpInst::Predicate P) {
        return isImpliedViaGuard(Block, P, LHS, RHS);
      };
      if (SplitAndProve(ProofFn))
        return true;
    }
    return false;
  };

  // Try to prove (Pred, LHS, RHS) using isImpliedCond.
  auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
    const Instruction *CtxI = &BB->front();
    if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
      return true;
    if (ProvingStrictComparison) {
      auto ProofFn = [&](ICmpInst::Predicate P) {
        return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
      };
      if (SplitAndProve(ProofFn))
        return true;
    }
    return false;
  };

  // Starting at the block's predecessor, climb up the predecessor chain, as
  // long as there are predecessors that can be found that have unique
  // successors leading to the original block.
  const Loop *ContainingLoop = LI.getLoopFor(BB);
  const BasicBlock *PredBB;
  if (ContainingLoop && ContainingLoop->getHeader() == BB)
    PredBB = ContainingLoop->getLoopPredecessor();
  else
    PredBB = BB->getSinglePredecessor();
  for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
       Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
    if (ProveViaGuard(Pair.first))
      return true;

    const BranchInst *LoopEntryPredicate =
        dyn_cast<BranchInst>(Pair.first->getTerminator());
    if (!LoopEntryPredicate ||
        LoopEntryPredicate->isUnconditional())
      continue;

    if (ProveViaCond(LoopEntryPredicate->getCondition(),
                     LoopEntryPredicate->getSuccessor(0) != Pair.second))
      return true;
  }

  // Check conditions due to any @llvm.assume intrinsics.
  for (auto &AssumeVH : AC.assumptions()) {
    if (!AssumeVH)
      continue;
    auto *CI = cast<CallInst>(AssumeVH);
    if (!DT.dominates(CI, BB))
      continue;

    if (ProveViaCond(CI->getArgOperand(0), false))
      return true;
  }

  return false;
}
bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
                                               ICmpInst::Predicate Pred,
                                               const SCEV *LHS,
                                               const SCEV *RHS) {
  // Interpret a null as meaning no loop, where there is obviously no guard
  // (interprocedural conditions notwithstanding).
  if (!L)
    return false;

  // Both LHS and RHS must be available at loop entry.
  assert(isAvailableAtLoopEntry(LHS, L) &&
         "LHS is not available at Loop Entry");
  assert(isAvailableAtLoopEntry(RHS, L) &&
         "RHS is not available at Loop Entry");

  if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
    return true;

  return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
}
bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
                                    const SCEV *RHS,
                                    const Value *FoundCondValue, bool Inverse,
                                    const Instruction *CtxI) {
  // A false condition implies anything. Do not bother analyzing it further.
  if (FoundCondValue ==
      ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
    return true;

  if (!PendingLoopPredicates.insert(FoundCondValue).second)
    return false;

  auto ClearOnExit =
      make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });

  // Recursively handle And and Or conditions.
  const Value *Op0, *Op1;
  if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
    if (!Inverse)
      return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
             isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
  } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
    if (Inverse)
      return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
             isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
  }

  const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
  if (!ICI) return false;

  // We have now found a conditional branch that dominates the loop or controls
  // the loop latch. Check to see if it is the comparison we are looking for.
  ICmpInst::Predicate FoundPred;
  if (Inverse)
    FoundPred = ICI->getInversePredicate();
  else
    FoundPred = ICI->getPredicate();

  const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
  const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));

  return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
}

bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
                                    const SCEV *RHS,
                                    ICmpInst::Predicate FoundPred,
                                    const SCEV *FoundLHS, const SCEV *FoundRHS,
                                    const Instruction *CtxI) {
  // Balance the types.
  if (getTypeSizeInBits(LHS->getType()) <
      getTypeSizeInBits(FoundLHS->getType())) {
    // For unsigned and equality predicates, try to prove that both found
    // operands fit into a narrow unsigned range. If so, try to prove facts in
    // the narrow range.
    if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
        !FoundRHS->getType()->isPointerTy()) {
      auto *NarrowType = LHS->getType();
      auto *WideType = FoundLHS->getType();
      auto BitWidth = getTypeSizeInBits(NarrowType);
      const SCEV *MaxValue = getZeroExtendExpr(
          getConstant(APInt::getMaxValue(BitWidth)), WideType);
      if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
                                          MaxValue) &&
          isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
                                          MaxValue)) {
        const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
        const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
        if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
                                       TruncFoundRHS, CtxI))
          return true;
      }
    }

    if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
      return false;
    if (CmpInst::isSigned(Pred)) {
      LHS = getSignExtendExpr(LHS, FoundLHS->getType());
      RHS = getSignExtendExpr(RHS, FoundLHS->getType());
    } else {
      LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
      RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
    }
  } else if (getTypeSizeInBits(LHS->getType()) >
             getTypeSizeInBits(FoundLHS->getType())) {
    if (FoundLHS->getType()->isPointerTy() ||
        FoundRHS->getType()->isPointerTy())
      return false;
    if (CmpInst::isSigned(FoundPred)) {
      FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
      FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
    } else {
      FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
      FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
    }
  }
  return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
                                    FoundRHS, CtxI);
}
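
// Worked example (hypothetical widths): to prove an i8 fact "LHS u< RHS"
// from an i32 guard "FoundLHS u< FoundRHS", MaxValue above is zext(i8 255)
// to i32. If both found operands are known u<= 255, truncating them to i8
// loses nothing and the implication is retried entirely in i8. Otherwise
// LHS/RHS are widened to i32 instead (sext for signed predicates, zext for
// unsigned/equality ones), so both sides always reach
// isImpliedCondBalancedTypes with equal bit widths.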

bool ScalarEvolution::isImpliedCondBalancedTypes(
    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
    ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS,
    const Instruction *CtxI) {
  assert(getTypeSizeInBits(LHS->getType()) ==
             getTypeSizeInBits(FoundLHS->getType()) &&
         "Types should be balanced!");
  // Canonicalize the query to match the way instcombine will have
  // canonicalized the comparison.
  if (SimplifyICmpOperands(Pred, LHS, RHS))
    if (LHS == RHS)
      return CmpInst::isTrueWhenEqual(Pred);
  if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
    if (FoundLHS == FoundRHS)
      return CmpInst::isFalseWhenEqual(FoundPred);

  // Check to see if we can make the LHS or RHS match.
  if (LHS == FoundRHS || RHS == FoundLHS) {
    if (isa<SCEVConstant>(RHS)) {
      std::swap(FoundLHS, FoundRHS);
      FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
    } else {
      std::swap(LHS, RHS);
      Pred = ICmpInst::getSwappedPredicate(Pred);
    }
  }

  // Check whether the found predicate is the same as the desired predicate.
  if (FoundPred == Pred)
    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);

  // Check whether swapping the found predicate makes it the same as the
  // desired predicate.
  if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
    // We can write the implication
    // 0.  LHS Pred      RHS  <-   FoundLHS SwapPred  FoundRHS
    // using one of the following ways:
    // 1.  LHS Pred      RHS  <-   FoundRHS Pred      FoundLHS
    // 2.  RHS SwapPred  LHS  <-   FoundLHS SwapPred  FoundRHS
    // 3.  LHS Pred      RHS  <-  ~FoundLHS Pred     ~FoundRHS
    // 4. ~LHS SwapPred ~RHS  <-   FoundLHS SwapPred  FoundRHS
    // Forms 1. and 2. require swapping the operands of one condition. Don't
    // do this if it would break canonical constant/addrec ordering.
    if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
      return isImpliedCondOperands(FoundPred, RHS, LHS, FoundLHS, FoundRHS,
                                   CtxI);
    if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
      return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, CtxI);

    // There's no clear preference between forms 3. and 4., try both. Avoid
    // forming getNotSCEV of pointer values as the resulting subtract is
    // not legal.
    if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
        isImpliedCondOperands(FoundPred, getNotSCEV(LHS), getNotSCEV(RHS),
                              FoundLHS, FoundRHS, CtxI))
      return true;

    if (!FoundLHS->getType()->isPointerTy() &&
        !FoundRHS->getType()->isPointerTy() &&
        isImpliedCondOperands(Pred, LHS, RHS, getNotSCEV(FoundLHS),
                              getNotSCEV(FoundRHS), CtxI))
      return true;

    return false;
  }

  auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
                                   CmpInst::Predicate P2) {
    assert(P1 != P2 && "Handled earlier!");
    return CmpInst::isRelational(P2) &&
           P1 == CmpInst::getFlippedSignednessPredicate(P2);
  };
  if (IsSignFlippedPredicate(Pred, FoundPred)) {
    // Unsigned comparison is the same as signed comparison when both operands
    // are non-negative or both are negative.
    if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
        (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
      return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
    // Create local copies that we can freely swap and canonicalize our
    // conditions to "le/lt".
    ICmpInst::Predicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
    const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
               *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
    if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
      CanonicalPred = ICmpInst::getSwappedPredicate(CanonicalPred);
      CanonicalFoundPred = ICmpInst::getSwappedPredicate(CanonicalFoundPred);
      std::swap(CanonicalLHS, CanonicalRHS);
      std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
    }
    assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
           "Must be!");
    assert((ICmpInst::isLT(CanonicalFoundPred) ||
            ICmpInst::isLE(CanonicalFoundPred)) &&
           "Must be!");
    if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
      // Use the implication
      //   x <u y && y >=s 0 --> x <s y.
      // If we can prove the left part, the right part is also proven.
      return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
                                   CanonicalRHS, CanonicalFoundLHS,
                                   CanonicalFoundRHS);
    if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
      // Use the implication
      //   x <s y && y <s 0 --> x <u y.
      // If we can prove the left part, the right part is also proven.
      return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
                                   CanonicalRHS, CanonicalFoundLHS,
                                   CanonicalFoundRHS);
  }

  // Check if we can make progress by sharpening ranges.
  if (FoundPred == ICmpInst::ICMP_NE &&
      (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {

    const SCEVConstant *C = nullptr;
    const SCEV *V = nullptr;

    if (isa<SCEVConstant>(FoundLHS)) {
      C = cast<SCEVConstant>(FoundLHS);
      V = FoundRHS;
    } else {
      C = cast<SCEVConstant>(FoundRHS);
      V = FoundLHS;
    }

    // The guarding predicate tells us that C != V. If the known range
    // of V is [C, t), we can sharpen the range to [C + 1, t). The
    // range we consider has to correspond to the same signedness as the
    // predicate we're interested in folding.

    APInt Min = ICmpInst::isSigned(Pred) ?
        getSignedRangeMin(V) : getUnsignedRangeMin(V);

    if (Min == C->getAPInt()) {
      // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
      // This is true even if (Min + 1) wraps around -- in case of
      // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
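      //
      // Concrete illustration (hypothetical i8 values): for an unsigned
      // query with Min == 255, SharperMin wraps around to 0, and
      // "V u>= 255" does imply the trivial "V u>= 0", so the reasoning
      // stays sound even at the wraparound boundary.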

      APInt SharperMin = Min + 1;

      switch (Pred) {
      case ICmpInst::ICMP_SGE:
      case ICmpInst::ICMP_UGE:
        // We know V `Pred` SharperMin. If this implies LHS `Pred`
        // RHS, we're done.
        if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
                                  CtxI))
          return true;
        LLVM_FALLTHROUGH;

      case ICmpInst::ICMP_SGT:
      case ICmpInst::ICMP_UGT:
        // We know from the range information that (V `Pred` Min ||
        // V == Min). We know from the guarding condition that !(V
        // == Min). This gives us
        //
        //       V `Pred` Min || V == Min && !(V == Min)
        //   =>  V `Pred` Min
        //
        // If V `Pred` Min implies LHS `Pred` RHS, we're done.
        if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
          return true;
        break;

      // `LHS < RHS` and `LHS <= RHS` are handled in the same way as
      // `RHS > LHS` and `RHS >= LHS` respectively.
      case ICmpInst::ICMP_SLE:
      case ICmpInst::ICMP_ULE:
        if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
                                  LHS, V, getConstant(SharperMin), CtxI))
          return true;
        LLVM_FALLTHROUGH;

      case ICmpInst::ICMP_SLT:
      case ICmpInst::ICMP_ULT:
        if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
                                  LHS, V, getConstant(Min), CtxI))
          return true;
        break;

      default:
        // No change.
        break;
      }
    }
  }

  // Check whether the actual condition is beyond sufficient.
  if (FoundPred == ICmpInst::ICMP_EQ)
    if (ICmpInst::isTrueWhenEqual(Pred))
      if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
        return true;
  if (Pred == ICmpInst::ICMP_NE)
    if (!ICmpInst::isTrueWhenEqual(FoundPred))
      if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
        return true;

  // Otherwise assume the worst.
  return false;
}

bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
                                     const SCEV *&L, const SCEV *&R,
                                     SCEV::NoWrapFlags &Flags) {
  const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
  if (!AE || AE->getNumOperands() != 2)
    return false;

  L = AE->getOperand(0);
  R = AE->getOperand(1);
  Flags = AE->getNoWrapFlags();
  return true;
}

Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More,
                                                           const SCEV *Less) {
  // We avoid subtracting expressions here because this function is usually
  // fairly deep in the call stack (i.e. is called many times).

  // X - X = 0.
  if (More == Less)
    return APInt(getTypeSizeInBits(More->getType()), 0);

  if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
    const auto *LAR = cast<SCEVAddRecExpr>(Less);
    const auto *MAR = cast<SCEVAddRecExpr>(More);

    if (LAR->getLoop() != MAR->getLoop())
      return None;

    // We look at affine expressions only; not for correctness but to keep
    // getStepRecurrence cheap.
    if (!LAR->isAffine() || !MAR->isAffine())
      return None;

    if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
      return None;

    Less = LAR->getStart();
    More = MAR->getStart();

    // Fall through.
  }

  if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
    const auto &M = cast<SCEVConstant>(More)->getAPInt();
    const auto &L = cast<SCEVConstant>(Less)->getAPInt();
    return M - L;
  }

  SCEV::NoWrapFlags Flags;
  const SCEV *LLess = nullptr, *RLess = nullptr;
  const SCEV *LMore = nullptr, *RMore = nullptr;
  const SCEVConstant *C1 = nullptr, *C2 = nullptr;
  // Compare (X + C1) vs X.
  if (splitBinaryAdd(Less, LLess, RLess, Flags))
    if ((C1 = dyn_cast<SCEVConstant>(LLess)))
      if (RLess == More)
        return -(C1->getAPInt());

  // Compare X vs (X + C2).
  if (splitBinaryAdd(More, LMore, RMore, Flags))
    if ((C2 = dyn_cast<SCEVConstant>(LMore)))
      if (RMore == Less)
        return C2->getAPInt();

  // Compare (X + C1) vs (X + C2).
  if (C1 && C2 && RLess == RMore)
    return C2->getAPInt() - C1->getAPInt();

  return None;
}
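
// Worked example (hypothetical expressions): for {7,+,s}<%L> vs {2,+,s}<%L>
// the addrec case strips the common step and loop, reducing the query to
// the constant starts, so the result is 7 - 2 = 5. Likewise (X + 11) vs
// (X + 4) splits both adds and yields 11 - 4 = 7 without ever forming a
// subtraction SCEV.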

bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
    const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
  // Try to recognize the following pattern:
  //
  //   FoundRHS = ...
  // ...
  // loop:
  //   FoundLHS = {Start,+,W}
  // context_bb: // Basic block from the same loop
  //   known(Pred, FoundLHS, FoundRHS)
  //
  // If some predicate is known in the context of a loop, it is also known on
  // each iteration of this loop, including the first iteration. Therefore, in
  // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
  // prove the original pred using this fact.
  if (!CtxI)
    return false;
  const BasicBlock *ContextBB = CtxI->getParent();
  // Make sure AR varies in the context block.
  if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
    const Loop *L = AR->getLoop();
    // Make sure that context belongs to the loop and executes on 1st iteration
    // (if it ever executes at all).
    if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
      return false;
    if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
      return false;
    return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
  }

  if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
    const Loop *L = AR->getLoop();
    // Make sure that context belongs to the loop and executes on 1st iteration
    // (if it ever executes at all).
    if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
      return false;
    if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
      return false;
    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
  }

  return false;
}

bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
    const SCEV *FoundLHS, const SCEV *FoundRHS) {
  if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
    return false;

  const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!AddRecLHS)
    return false;

  const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
  if (!AddRecFoundLHS)
    return false;

  // We'd like to let SCEV reason about control dependencies, so we constrain
  // both the inequalities to be about add recurrences on the same loop. This
  // way we can use isLoopEntryGuardedByCond later.

  const Loop *L = AddRecFoundLHS->getLoop();
  if (L != AddRecLHS->getLoop())
    return false;

  //  FoundLHS u< FoundRHS u< -C =>  (FoundLHS + C) u< (FoundRHS + C) ... (1)
  //
  //  FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
  //                                                                  ... (2)
  //
  // Informal proof for (2), assuming (1) [*]:
  //
  // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
  //
  // Then
  //
  //       FoundLHS s< FoundRHS s< INT_MIN - C
  //  <=>  (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C   [ using (3) ]
  //  <=>  (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
  //  <=>  (FoundLHS + INT_MIN + C + INT_MIN) s<
  //       (FoundRHS + INT_MIN + C + INT_MIN)                   [ using (3) ]
  //  <=>  FoundLHS + C s< FoundRHS + C
  //
  // [*]: (1) can be proved by ruling out overflow.
  //
  // [**]: This can be proved by analyzing all the four possibilities:
  //    (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
  //    (A s>= 0, B s>= 0).
  //
  // Note:
  // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
  // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
  // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
  // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
  // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
  // C)".

  Optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
  Optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
  if (!LDiff || !RDiff || *LDiff != *RDiff)
    return false;

  if (LDiff->isMinValue())
    return true;

  APInt FoundRHSLimit;

  if (Pred == CmpInst::ICMP_ULT) {
    FoundRHSLimit = -(*RDiff);
  } else {
    assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
    FoundRHSLimit =
        APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
  }

  // Try to prove (1) or (2), as needed.
  return isAvailableAtLoopEntry(FoundRHS, L) &&
         isLoopEntryGuardedByCond(L, Pred, FoundRHS,
                                  getConstant(FoundRHSLimit));
}
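
// Worked example (hypothetical i8 values): to prove "X + 10 u< Y + 10" from
// "X u< Y", rule (1) needs "Y u< -10", i.e. Y u< 246: then neither addition
// wraps, so adding the common constant preserves the unsigned ordering.
// The code establishes exactly this bound by checking that the loop is
// entered under "FoundRHS u< -RDiff".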

bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
                                        const SCEV *LHS, const SCEV *RHS,
                                        const SCEV *FoundLHS,
                                        const SCEV *FoundRHS, unsigned Depth) {
  const PHINode *LPhi = nullptr, *RPhi = nullptr;

  auto ClearOnExit = make_scope_exit([&]() {
    if (LPhi) {
      bool Erased = PendingMerges.erase(LPhi);
      assert(Erased && "Failed to erase LPhi!");
      (void)Erased;
    }
    if (RPhi) {
      bool Erased = PendingMerges.erase(RPhi);
      assert(Erased && "Failed to erase RPhi!");
      (void)Erased;
    }
  });

  // Find the respective Phis and check that they are not already pending.
  if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
    if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
      if (!PendingMerges.insert(Phi).second)
        return false;
      LPhi = Phi;
    }

  if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
    if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
      // If we detect a loop of Phi nodes being processed by this method, for
      // example:
      //
      //   %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
      //   %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
      //
      // we don't want to deal with a case that complex, so return the
      // conservative answer false.
      if (!PendingMerges.insert(Phi).second)
        return false;
      RPhi = Phi;
    }

  // If none of LHS, RHS is a Phi, nothing to do here.
  if (!LPhi && !RPhi)
    return false;

  // If there is a SCEVUnknown Phi we are interested in, make it left.
  if (!LPhi) {
    std::swap(LHS, RHS);
    std::swap(FoundLHS, FoundRHS);
    std::swap(LPhi, RPhi);
    Pred = ICmpInst::getSwappedPredicate(Pred);
  }

  assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
  const BasicBlock *LBB = LPhi->getParent();
  const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);

  auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
    return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
           isImpliedCondOperandsViaRanges(Pred, S1, S2, FoundLHS, FoundRHS) ||
           isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
  };

  if (RPhi && RPhi->getParent() == LBB) {
    // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
    // If we compare two Phis from the same block, and for each entry block
    // the predicate is true for incoming values from this block, then the
    // predicate is also true for the Phis.
    for (const BasicBlock *IncBB : predecessors(LBB)) {
      const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
      const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
      if (!ProvedEasily(L, R))
        return false;
    }
  } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
    // Case two: RHS is also a Phi from the same basic block, and it is an
    // AddRec. It means that there is a loop which has both AddRec and Unknown
    // PHIs, so we can compare the incoming values of the AddRec from above
    // the loop and from the latch with their respective incoming values of
    // LPhi.
    // TODO: Generalize to handle loops with many inputs in a header.
    if (LPhi->getNumIncomingValues() != 2) return false;

    auto *RLoop = RAR->getLoop();
    auto *Predecessor = RLoop->getLoopPredecessor();
    assert(Predecessor && "Loop with AddRec with no predecessor?");
    const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
    if (!ProvedEasily(L1, RAR->getStart()))
      return false;
    auto *Latch = RLoop->getLoopLatch();
    assert(Latch && "Loop with AddRec with no latch?");
    const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
    if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
      return false;
  } else {
    // In all other cases go over the inputs of LHS and compare each of them
    // to RHS: the predicate is true for (LHS, RHS) if it is true for all such
    // pairs. At this point RHS is either a non-Phi, or it is a Phi from some
    // block different from LBB.
    for (const BasicBlock *IncBB : predecessors(LBB)) {
      // Check that RHS is available in this block.
      if (!dominates(RHS, IncBB))
        return false;
      const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
      // Make sure L does not refer to a value from a potentially previous
      // iteration of a loop.
      if (!properlyDominates(L, IncBB))
        return false;
      if (!ProvedEasily(L, RHS))
        return false;
    }
  }
  return true;
}
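
// Illustrative sketch (hypothetical IR): comparing a merge Phi against a
// loop-invariant bound falls into the final case above. For
//
//   merge:
//     %p = phi i32 [ %x, %bb1 ], [ %y, %bb2 ]
//
// proving "%p s< %n" reduces to proving "%x s< %n" and "%y s< %n" on the
// respective incoming edges, provided %n dominates both predecessors.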

bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred,
                                                    const SCEV *LHS,
                                                    const SCEV *RHS,
                                                    const SCEV *FoundLHS,
                                                    const SCEV *FoundRHS) {
  // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
  // sure that we are dealing with the same LHS.
  if (RHS == FoundRHS) {
    std::swap(LHS, RHS);
    std::swap(FoundLHS, FoundRHS);
    Pred = ICmpInst::getSwappedPredicate(Pred);
  }
  if (LHS != FoundLHS)
    return false;

  auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
  if (!SUFoundRHS)
    return false;

  Value *Shiftee, *ShiftValue;

  using namespace PatternMatch;
  if (match(SUFoundRHS->getValue(),
            m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
    auto *ShifteeS = getSCEV(Shiftee);
    // Prove one of the following:
    // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
    // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
    // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
    //   ---> LHS <s RHS
    // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
    //   ---> LHS <=s RHS
    if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
      return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
    if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
      if (isKnownNonNegative(ShifteeS))
        return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
  }

  return false;
}
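
// Worked example (hypothetical values): from "%i u< (%len >> 2)" we may
// conclude "%i u< %len", because a logical shift right can only shrink an
// unsigned value: (%len >> 2) u<= %len always holds, which is exactly the
// isKnownPredicate(ULE, ShifteeS, RHS) query above with Shiftee == RHS.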

bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
                                            const SCEV *LHS, const SCEV *RHS,
                                            const SCEV *FoundLHS,
                                            const SCEV *FoundRHS,
                                            const Instruction *CtxI) {
  if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
    return true;

  if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
    return true;

  if (isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS))
    return true;

  if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
                                          CtxI))
    return true;

  return isImpliedCondOperandsHelper(Pred, LHS, RHS,
                                     FoundLHS, FoundRHS);
}

/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
template <typename MinMaxExprType>
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
                                 const SCEV *Candidate) {
  const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
  if (!MinMaxExpr)
    return false;

  return is_contained(MinMaxExpr->operands(), Candidate);
}

static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
                                           ICmpInst::Predicate Pred,
                                           const SCEV *LHS, const SCEV *RHS) {
  // If both sides are affine addrecs for the same loop, with equal
  // steps, and we know the recurrences don't wrap, then we only
  // need to check the predicate on the starting values.

  if (!ICmpInst::isRelational(Pred))
    return false;

  const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!LAR)
    return false;
  const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
  if (!RAR)
    return false;
  if (LAR->getLoop() != RAR->getLoop())
    return false;
  if (!LAR->isAffine() || !RAR->isAffine())
    return false;

  if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
    return false;

  SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
    SCEV::FlagNSW : SCEV::FlagNUW;
  if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
    return false;

  return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
}
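
// Worked example (hypothetical addrecs): for {4,+,2}<nuw><%L> and
// {10,+,2}<nuw><%L> the two recurrences advance in lockstep and neither
// wraps, so the unsigned inequality between them reduces to the inequality
// between the starts: 4 u< 10, hence the predicate holds on every iteration.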

/// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
/// expression?
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
                                        ICmpInst::Predicate Pred,
                                        const SCEV *LHS, const SCEV *RHS) {
  switch (Pred) {
  default:
    return false;

  case ICmpInst::ICMP_SGE:
    std::swap(LHS, RHS);
    LLVM_FALLTHROUGH;
  case ICmpInst::ICMP_SLE:
    return
        // min(A, ...) <= A
        IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
        // A <= max(A, ...)
        IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);

  case ICmpInst::ICMP_UGE:
    std::swap(LHS, RHS);
    LLVM_FALLTHROUGH;
  case ICmpInst::ICMP_ULE:
    return
        // min(A, ...) <= A
        // FIXME: what about umin_seq?
        IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
        // A <= max(A, ...)
        IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
  }

  llvm_unreachable("covered switch fell through?!");
}

bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
                                             const SCEV *LHS, const SCEV *RHS,
                                             const SCEV *FoundLHS,
                                             const SCEV *FoundRHS,
                                             unsigned Depth) {
  assert(getTypeSizeInBits(LHS->getType()) ==
             getTypeSizeInBits(RHS->getType()) &&
         "LHS and RHS have different sizes?");
  assert(getTypeSizeInBits(FoundLHS->getType()) ==
             getTypeSizeInBits(FoundRHS->getType()) &&
         "FoundLHS and FoundRHS have different sizes?");
  // We want to avoid hurting the compile time with analysis of too big trees.
  if (Depth > MaxSCEVOperationsImplicationDepth)
    return false;

  // We only want to work with GT comparison so far.
  if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
    Pred = CmpInst::getSwappedPredicate(Pred);
    std::swap(LHS, RHS);
    std::swap(FoundLHS, FoundRHS);
  }

  // For unsigned, try to reduce it to the corresponding signed comparison.
  if (Pred == ICmpInst::ICMP_UGT)
    // We can replace an unsigned predicate with its signed counterpart if all
    // involved values are non-negative.
    // TODO: We could have better support for unsigned.
    if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
      // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
      // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
      // use this fact to prove that LHS and RHS are non-negative.
      const SCEV *MinusOne = getMinusOne(LHS->getType());
      if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
                                FoundRHS) &&
          isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
                                FoundRHS))
        Pred = ICmpInst::ICMP_SGT;
    }

  if (Pred != ICmpInst::ICMP_SGT)
    return false;

  auto GetOpFromSExt = [&](const SCEV *S) {
    if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
      return Ext->getOperand();
    // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
    // the constant in some cases.
    return S;
  };

  // Acquire values from extensions.
  auto *OrigLHS = LHS;
  auto *OrigFoundLHS = FoundLHS;
  LHS = GetOpFromSExt(LHS);
  FoundLHS = GetOpFromSExt(FoundLHS);

  // Checks whether the SGT predicate can be proved trivially or using the
  // found context.
  auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
    return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
           isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
                                  FoundRHS, Depth + 1);
  };

  if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
    // We want to avoid creation of any new non-constant SCEV. Since we are
    // going to compare the operands to RHS, we should be certain that we don't
    // need any size extensions for this. So let's decline all cases when the
    // sizes of types of LHS and RHS do not match.
    // TODO: Maybe try to get RHS from sext to catch more cases?
    if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
      return false;

    // Should not overflow.
    if (!LHSAddExpr->hasNoSignedWrap())
      return false;

    auto *LL = LHSAddExpr->getOperand(0);
    auto *LR = LHSAddExpr->getOperand(1);
    auto *MinusOne = getMinusOne(RHS->getType());

    // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
    auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
      return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
    };
    // Try to prove the following rule:
    // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
    // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
    if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
      return true;
  } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
    Value *LL, *LR;
    // FIXME: Once we have SDiv implemented, we can get rid of this matching.

    using namespace llvm::PatternMatch;

    if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
      // Rules for division.
      // We are going to perform some comparisons with Denominator and its
      // derivative expressions. In the general case, creating a SCEV for it
      // may lead to a complex analysis of the entire graph, and in particular
      // it can request trip count recalculation for the same loop. This would
      // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
      // this, we only want to create SCEVs that are constants in this section.
      // So we bail if Denominator is not a constant.
      if (!isa<ConstantInt>(LR))
        return false;

      auto *Denominator = cast<SCEVConstant>(getSCEV(LR));

      // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
      // then a SCEV for the numerator already exists and matches with
      // FoundLHS.
      auto *Numerator = getExistingSCEV(LL);
      if (!Numerator || Numerator->getType() != FoundLHS->getType())
        return false;

      // Make sure that the numerator matches with FoundLHS and the denominator
      // is positive.
      if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
        return false;

      auto *DTy = Denominator->getType();
      auto *FRHSTy = FoundRHS->getType();
      if (DTy->isPointerTy() != FRHSTy->isPointerTy())
        // One of the types is a pointer and the other one is not. We cannot
        // extend them properly to a wider type, so let us just reject this
        // case.
        // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
        // to avoid this check.
        return false;

      // Given that:
      // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
      auto *WTy = getWiderType(DTy, FRHSTy);
      auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
      auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);

      // Try to prove the following rule:
      // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
      // For example, given that FoundLHS > 2, FoundLHS is at least 3. If we
      // divide it by Denominator < 4, we will have at least 1.
      auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
      if (isKnownNonPositive(RHS) &&
          IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
        return true;

      // Try to prove the following rule:
      // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
      // For example, given that FoundLHS > -3, FoundLHS is at least -2.
      // If we divide it by Denominator > 2, then:
      // 1. If FoundLHS is negative, then the result is 0.
      // 2. If FoundLHS is non-negative, then the result is non-negative.
      // Either way, the result is non-negative.
      auto *MinusOne = getMinusOne(WTy);
      auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
      if (isKnownNegative(RHS) &&
          IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
        return true;
    }
  }

  // If our expression contained SCEVUnknown Phis, and we split it down and now
  // need to prove something for them, try to prove the predicate for every
  // possible incoming value of those Phis.
  if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
    return true;

  return false;
}
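
// Worked example (hypothetical values): suppose we know "%x s> 2" and want
// "%x / 3 s> -1" (signed division, constant denominator). Here FoundLHS is
// %x, Denominator is 3 and RHS is -1: since %x s> 2 means %x s>= 3,
// dividing by 3 yields at least 1, which is indeed s> -1. This is the first
// division rule above, with FoundRHS = 2 s> Denominator - 2 = 1 and
// RHS = -1 <= 0.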

static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
                                        const SCEV *LHS, const SCEV *RHS) {
  // zext x u<= sext x, sext x s<= zext x
  switch (Pred) {
  case ICmpInst::ICMP_SGE:
    std::swap(LHS, RHS);
    LLVM_FALLTHROUGH;
  case ICmpInst::ICMP_SLE: {
    // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
    const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
    const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
    if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
      return true;
    break;
  }
  case ICmpInst::ICMP_UGE:
    std::swap(LHS, RHS);
    LLVM_FALLTHROUGH;
  case ICmpInst::ICMP_ULE: {
    // If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
    const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
    const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
    if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
      return true;
    break;
  }
  default:
    break;
  }
  return false;
}

bool
ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
                                                 const SCEV *LHS,
                                                 const SCEV *RHS) {
  return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
         isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
         IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
         IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
         isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
}

bool
ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
                                             const SCEV *LHS, const SCEV *RHS,
                                             const SCEV *FoundLHS,
                                             const SCEV *FoundRHS) {
  switch (Pred) {
  default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
  case ICmpInst::ICMP_EQ:
  case ICmpInst::ICMP_NE:
    if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
      return true;
    break;
  case ICmpInst::ICMP_SLT:
  case ICmpInst::ICMP_SLE:
    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
      return true;
    break;
  case ICmpInst::ICMP_SGT:
  case ICmpInst::ICMP_SGE:
    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
      return true;
    break;
  case ICmpInst::ICMP_ULT:
  case ICmpInst::ICMP_ULE:
    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
      return true;
    break;
  case ICmpInst::ICMP_UGT:
  case ICmpInst::ICMP_UGE:
    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
      return true;
    break;
  }

  // Maybe it can be proved via operations?
  if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
    return true;

  return false;
}

bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
                                                     const SCEV *LHS,
                                                     const SCEV *RHS,
                                                     const SCEV *FoundLHS,
                                                     const SCEV *FoundRHS) {
  if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
    // The restriction on `FoundRHS` could be lifted easily -- it exists only
    // to reduce the compile time impact of this optimization.
    return false;

  Optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
  if (!Addend)
    return false;

  const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();

  // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
  // antecedent "`FoundLHS` `Pred` `FoundRHS`".
  ConstantRange FoundLHSRange =
      ConstantRange::makeExactICmpRegion(Pred, ConstFoundRHS);

  // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
  ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));

  // We can also compute the range of values for `LHS` that satisfy the
  // consequent, "`LHS` `Pred` `RHS`":
  const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
  // The antecedent implies the consequent if every value of `LHS` that
  // satisfies the antecedent also satisfies the consequent.
  return LHSRange.icmp(Pred, ConstRHS);
}
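
// Worked example (hypothetical constants): from "FoundLHS u< 10" and
// LHS == FoundLHS + 5 we get LHSRange = [5, 15). Every value in [5, 15)
// is u< 15, so the antecedent proves "LHS u< 15" -- but not "LHS u< 14",
// since 14 still lies in the range.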

bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
                                        bool IsSigned) {
  assert(isKnownPositive(Stride) && "Positive stride expected!");

  unsigned BitWidth = getTypeSizeInBits(RHS->getType());
  const SCEV *One = getOne(Stride->getType());

  if (IsSigned) {
    APInt MaxRHS = getSignedRangeMax(RHS);
    APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
    APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));

    // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
    return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
  }

  APInt MaxRHS = getUnsignedRangeMax(RHS);
  APInt MaxValue = APInt::getMaxValue(BitWidth);
  APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));

  // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
  return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
}

bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
                                        bool IsSigned) {

  unsigned BitWidth = getTypeSizeInBits(RHS->getType());
  const SCEV *One = getOne(Stride->getType());

  if (IsSigned) {
    APInt MinRHS = getSignedRangeMin(RHS);
    APInt MinValue = APInt::getSignedMinValue(BitWidth);
    APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));

    // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
    return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
  }

  APInt MinRHS = getUnsignedRangeMin(RHS);
  APInt MinValue = APInt::getMinValue(BitWidth);
  APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));

  // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
  return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
}

const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
  // umin(N, 1) + floor((N - umin(N, 1)) / D)
  // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
  // expression fixes the case of N = 0.
  const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
  const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
  return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
}
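
// Worked example (concrete values): for N = 7, D = 3 this computes
// umin(7, 1) + floor((7 - 1) / 3) = 1 + 2 = 3 = ceil(7 / 3); for N = 0 it
// computes umin(0, 1) + floor(0 / 3) = 0, avoiding the wrap that the naive
// "1 + floor((N - 1) / D)" would hit at N = 0.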

const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
                                                    const SCEV *Stride,
                                                    const SCEV *End,
                                                    unsigned BitWidth,
                                                    bool IsSigned) {
  // The logic in this function assumes we can represent a positive stride.
  // If we can't, the backedge-taken count must be zero.
  if (IsSigned && BitWidth == 1)
    return getZero(Stride->getType());

  // This code has only been closely audited for negative strides in the
  // unsigned comparison case; it may be correct for signed comparison, but
  // that needs to be established.
  assert((!IsSigned || !isKnownNonPositive(Stride)) &&
         "Stride is expected strictly positive for signed case!");

  // Calculate the maximum backedge count based on the range of values
  // permitted by Start, End, and Stride.
  APInt MinStart =
      IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);

  APInt MinStride =
      IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);

  // We assume either the stride is positive, or the backedge-taken count
  // is zero. So force StrideForMaxBECount to be at least one.
  APInt One(BitWidth, 1);
  APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
                                       : APIntOps::umax(One, MinStride);

  APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
                            : APInt::getMaxValue(BitWidth);
  APInt Limit = MaxValue - (StrideForMaxBECount - 1);

  // Although End can be a MAX expression we estimate MaxEnd considering only
  // the case End = RHS of the loop termination condition. This is safe because
  // in the other case (End - Start) is zero, leading to a zero maximum
  // backedge taken count.
  APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
                          : APIntOps::umin(getUnsignedRangeMax(End), Limit);

  // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
  MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
                    : APIntOps::umax(MaxEnd, MinStart);

  return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
                         getConstant(StrideForMaxBECount) /* Step */);
}
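
// Worked numeric sketch (hypothetical ranges): with MinStart = 2,
// MaxEnd = 11 and StrideForMaxBECount = 3, the bound is
// ceil((11 - 2) / 3) = 3 -- no matter where in their ranges Start and End
// actually fall, the backedge cannot be taken more than 3 times.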

ScalarEvolution::ExitLimit
ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                  const Loop *L, bool IsSigned,
                                  bool ControlsExit, bool AllowPredicates) {
  SmallPtrSet<const SCEVPredicate *, 4> Predicates;

  const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
  bool PredicatedIV = false;

  auto canAssumeNoSelfWrap = [&](const SCEVAddRecExpr *AR) {
    // Can we prove this loop *must* be UB if overflow of IV occurs?
    // Reasoning goes as follows:
    // * Suppose the IV did self wrap.
    // * If Stride evenly divides the iteration space, then once wrap
    //   occurs, the loop must revisit the same values.
    // * We know that RHS is invariant, and that none of those values
    //   caused this exit to be taken previously. Thus, this exit is
    //   dynamically dead.
    // * If this is the sole exit, then a dead exit implies the loop
    //   must be infinite if there are no abnormal exits.
    // * If the loop were infinite, then it must either not be mustprogress
    //   or have side effects. Otherwise, it must be UB.
    // * It can't (by assumption) be UB, so we have contradicted our
    //   premise and can conclude the IV did not in fact self-wrap.
    if (!isLoopInvariant(RHS, L))
      return false;

    auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
    if (!StrideC || !StrideC->getAPInt().isPowerOf2())
      return false;

    if (!ControlsExit || !loopHasNoAbnormalExits(L))
      return false;

    return loopIsFiniteByAssumption(L);
  };

  if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
    if (AR && AR->getLoop() == L && AR->isAffine()) {
      auto canProveNUW = [&]() {
        if (!isLoopInvariant(RHS, L))
          return false;

        if (!isKnownNonZero(AR->getStepRecurrence(*this)))
          // We need the sequence defined by AR to strictly increase in the
          // unsigned integer domain for the logic below to hold.
          return false;

        const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
        const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
        // If RHS <=u Limit, then there must exist a value V in the sequence
        // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
        // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
        // overflow occurs. This limit also implies that a signed comparison
        // (in the wide bitwidth) is equivalent to an unsigned comparison as
        // the high bits on both sides must be zero.
        APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
        APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
        Limit = Limit.zext(OuterBitWidth);
        return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
      };
      auto Flags = AR->getNoWrapFlags();
      if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
        Flags = setFlags(Flags, SCEV::FlagNUW);

      setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
      if (AR->hasNoUnsignedWrap()) {
        // Emulate what getZeroExtendExpr would have done during construction
        // if we'd been able to infer the fact just above at that time.
        const SCEV *Step = AR->getStepRecurrence(*this);
        Type *Ty = ZExt->getType();
        auto *S = getAddRecExpr(
            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
            getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
        IV = dyn_cast<SCEVAddRecExpr>(S);
      }
    }
  }

  if (!IV && AllowPredicates) {
    // Try to make this an AddRec using runtime tests, in the first X
    // iterations of this loop, where X is the SCEV expression found by the
    // algorithm below.
    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
    PredicatedIV = true;
  }

  // Avoid weird loops.
  if (!IV || IV->getLoop() != L || !IV->isAffine())
    return getCouldNotCompute();

  // A precondition of this method is that the condition being analyzed
  // reaches an exiting branch which dominates the latch. Given that, we can
  // assume that an increment which violates the nowrap specification and
  // produces poison must cause undefined behavior when the resulting poison
  // value is branched upon and thus we can conclude that the backedge is
  // taken no more often than would be required to produce that poison value.
  // Note that a well defined loop can exit on the iteration which violates
  // the nowrap specification if there is another exit (either explicit or
  // implicit/exceptional) which causes the loop to execute before the
  // exiting instruction we're analyzing would trigger UB.
  auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
  bool NoWrap = ControlsExit && IV->getNoWrapFlags(WrapType);
  ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;

  const SCEV *Stride = IV->getStepRecurrence(*this);

  bool PositiveStride = isKnownPositive(Stride);

  // Avoid negative or zero stride values.
  if (!PositiveStride) {
    // We can compute the correct backedge taken count for loops with unknown
    // strides if we can prove that the loop is not an infinite loop with side
    // effects. Here's the loop structure we are trying to handle -
    //
    // i = start
    // do {
    //   A[i] = i;
    //   i += s;
    // } while (i < end);
    //
    // The backedge taken count for such loops is evaluated as -
    // (max(end, start + stride) - start - 1) /u stride
    //
    // The additional preconditions that we need to check to prove correctness
    // of the above formula are as follows -
    //
    // a) IV is either nuw or nsw depending upon signedness (indicated by the
    //    NoWrap flag).
    // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
    //    no side effects within the loop)
    // c) loop has a single static exit (with no abnormal exits)
    //
    // Precondition a) implies that if the stride is negative, this is a single
    // trip loop. The backedge taken count formula reduces to zero in this
    // case.
    //
    // Preconditions b) and c) combine to imply that if rhs is invariant in L,
    // then a zero stride means the backedge can't be taken without executing
    // undefined behavior.
    //
    // The positive stride case is the same as isKnownPositive(Stride)
    // returning true (original behavior of the function).
    if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
        !loopHasNoAbnormalExits(L))
      return getCouldNotCompute();

    // This bailout is protecting the logic in computeMaxBECountForLT which
    // has not yet been sufficiently audited or tested with negative strides.
    // We used to filter out all known-non-positive cases here; we're in the
    // process of being less restrictive bit by bit.
    if (IsSigned && isKnownNonPositive(Stride))
      return getCouldNotCompute();

    if (!isKnownNonZero(Stride)) {
      // If we have a step of zero, and RHS isn't invariant in L, we don't know
      // if it might eventually be greater than start and if so, on which
      // iteration. We can't even produce a useful upper bound.
      if (!isLoopInvariant(RHS, L))
        return getCouldNotCompute();

      // We allow a potentially zero stride, but we need to divide by stride
      // below. Since the loop can't be infinite and this check must control
      // the sole exit, we can infer the exit must be taken on the first
      // iteration (e.g. backedge count = 0) if the stride is zero. Given
      // that, we know the numerator in the divides below must be zero, so we
      // can pick an arbitrary non-zero value for the denominator (e.g. stride)
      // and produce the right result.
      // FIXME: Handle the case where Stride is poison?
      auto wouldZeroStrideBeUB = [&]() {
        // Proof by contradiction. Suppose the stride were zero. If we can
        // prove that the backedge *is* taken on the first iteration, then
        // since we know this condition controls the sole exit, we must have
        // an infinite loop. We can't have a (well defined) infinite loop per
        // the check just above.
        // Note: The (Start - Stride) term is used to get the start' term from
        // (start' + stride,+,stride). Remember that we only care about the
        // result of this expression when stride == 0 at runtime.
        auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
        return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
      };
      if (!wouldZeroStrideBeUB()) {
        Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
      }
    }
  } else if (!Stride->isOne() && !NoWrap) {
    auto isUBOnWrap = [&]() {
      // From no-self-wrap, we need to then prove no-(un)signed-wrap. This
      // follows trivially from the fact that every (un)signed-wrapped, but
      // not self-wrapped, value must be less than the last value before the
      // (un)signed wrap. Since we know that last value didn't exit, neither
      // will any smaller one.
      return canAssumeNoSelfWrap(IV);
    };

    // Avoid proven overflow cases: this will ensure that the backedge taken
    // count will not generate any unsigned overflow. Relaxed no-overflow
    // conditions exploit NoWrapFlags, allowing us to optimize in the presence
    // of undefined behavior, as in C.
    if (canIVOverflowOnLT(RHS, Stride, IsSigned) && !isUBOnWrap())
      return getCouldNotCompute();
  }

  // On all paths just preceding, we established the following invariant:
  // IV can be assumed not to overflow up to and including the exiting
  // iteration. We proved this in one of two ways:
  // 1) We can show overflow doesn't occur before the exiting iteration
  //    (1a) via canIVOverflowOnLT, or (1b) because the step is one.
  // 2) We can show that if overflow occurs, the loop must execute UB
  //    before any possible exit.
  // Note that we have not yet proved RHS invariant (in general).

  const SCEV *Start = IV->getStart();

  // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
  // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
  // Use integer-typed versions for actual computation; we can't subtract
  // pointers in general.
  const SCEV *OrigStart = Start;
  const SCEV *OrigRHS = RHS;
  if (Start->getType()->isPointerTy()) {
    Start = getLosslessPtrToIntExpr(Start);
    if (isa<SCEVCouldNotCompute>(Start))
      return Start;
  }
  if (RHS->getType()->isPointerTy()) {
    RHS = getLosslessPtrToIntExpr(RHS);
    if (isa<SCEVCouldNotCompute>(RHS))
      return RHS;
  }

  // When the RHS is not invariant, we do not know the end bound of the loop
  // and cannot calculate the ExactBECount needed by ExitLimit. However, we
  // can calculate the MaxBECount, given the start, stride and max value for
  // the end bound of the loop (RHS), and the fact that IV does not overflow
  // (which is checked above).
  if (!isLoopInvariant(RHS, L)) {
    const SCEV *MaxBECount = computeMaxBECountForLT(
        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
    return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
                     false /*MaxOrZero*/, Predicates);
  }

  // We use the expression (max(End,Start)-Start)/Stride to describe the
  // backedge count, as if the backedge is taken at least once max(End,Start)
  // is End and so the result is as above, and if not max(End,Start) is Start
  // so we get a backedge count of zero.
  const SCEV *BECount = nullptr;
  auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
  assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
  assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
  assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
  // Can we prove (max(RHS,Start) > Start - Stride)?
  if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
      isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
    // In this case, we can use a refined formula for computing backedge taken
    // count. The general formula remains:
    //   "End-Start /uceiling Stride" where "End = max(RHS,Start)"
    // We want to use the alternate formula:
    //   "((End - 1) - (Start - Stride)) /u Stride"
    // Let's do a quick case analysis to show these are equivalent under
    // our precondition that max(RHS,Start) > Start - Stride.
    // * For RHS <= Start, the backedge-taken count must be zero.
    //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
    //   "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
    //   "(Stride - 1) /u Stride" which is indeed zero for all non-zero values
    //   of Stride. For a zero stride, we've replaced Stride with
    //   umax(Stride, 1) above, reducing this to the stride-of-one case.
    // * For RHS >= Start, the backedge count must be
    //   "RHS-Start /uceil Stride".
    //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
    //   "((RHS - 1) - (Start - Stride)) /u Stride" which reassociates to
    //   "((RHS - (Start - Stride) - 1) /u Stride".
    //   Our preconditions trivially imply no overflow in that form.
    const SCEV *MinusOne = getMinusOne(Stride->getType());
    const SCEV *Numerator =
        getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
    BECount = getUDivExpr(Numerator, Stride);
  }

  const SCEV *BECountIfBackedgeTaken = nullptr;
  if (!BECount) {
    auto canProveRHSGreaterThanEqualStart = [&]() {
      auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
      if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart))
        return true;

      // (RHS > Start - 1) implies RHS >= Start.
      // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
      //   "Start - 1" doesn't overflow.
      // * For signed comparison, if Start - 1 does overflow, it's equal
      //   to INT_MAX, and "RHS >s INT_MAX" is trivially false.
      // * For unsigned comparison, if Start - 1 does overflow, it's equal
      //   to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
      //
      // FIXME: Should isLoopEntryGuardedByCond do this for us?
      auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
      auto *StartMinusOne = getAddExpr(OrigStart,
                                       getMinusOne(OrigStart->getType()));
      return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
    };

    // If we know that RHS >= Start in the context of loop, then we know that
    // max(RHS, Start) = RHS at this point.
    const SCEV *End;
    if (canProveRHSGreaterThanEqualStart()) {
      End = RHS;
    } else {
      // If RHS < Start, the backedge will be taken zero times. So in
      // general, we can write the backedge-taken count as:
      //
      //     RHS >= Start ? ceil(RHS - Start) / Stride : 0
      //
      // We convert it to the following to make it more convenient for SCEV:
      //
      //     ceil(max(RHS, Start) - Start) / Stride
      End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);

      // See what would happen if we assume the backedge is taken. This is
      // used to compute MaxBECount.
      BECountIfBackedgeTaken =
          getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
    }

    // At this point, we know:
    //
    // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
    // 2. The index variable doesn't overflow.
    //
    // Therefore, we know N exists such that
    // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
    // doesn't overflow.
    //
    // Using this information, try to prove whether the addition in
    // "(Start - End) + (Stride - 1)" has unsigned overflow.
    const SCEV *One = getOne(Stride->getType());
    bool MayAddOverflow = [&] {
      if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
        if (StrideC->getAPInt().isPowerOf2()) {
          // Suppose Stride is a power of two, and Start/End are unsigned
          // integers. Let UMAX be the largest representable unsigned
          // integer.
          //
          // By the preconditions of this function, we know
          // "(Start + Stride * N) >= End", and this doesn't overflow.
          // As a formula:
          //
          //   End <= (Start + Stride * N) <= UMAX
          //
          // Subtracting Start from all the terms:
          //
          //   End - Start <= Stride * N <= UMAX - Start
          //
          // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
          //
          //   End - Start <= Stride * N <= UMAX
          //
          // Stride * N is a multiple of Stride. Therefore,
          //
          //   End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
          //
          // Since Stride is a power of two, UMAX + 1 is divisible by Stride.
          // Therefore, UMAX mod Stride == Stride - 1. So we can write:
          //
          //   End - Start <= Stride * N <= UMAX - Stride + 1
          //
          // Dropping the middle term:
          //
          //   End - Start <= UMAX - Stride + 1
          //
          // Adding Stride - 1 to both sides:
          //
          //   (End - Start) + (Stride - 1) <= UMAX
          //
          // In other words, the addition doesn't have unsigned overflow.
          //
          // A similar proof works if we treat Start/End as signed values.
          // Just rewrite steps before "End - Start <= Stride * N <= UMAX" to
          // use signed max instead of unsigned max. Note that we're trying
          // to prove a lack of unsigned overflow in either case.
          return false;
        }
      }
      if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
        // If Start is equal to Stride, (End - Start) + (Stride - 1) == End - 1.
        // If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1 <u End.
        // If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End - 1 <s End.
        //
        // If Start is equal to Stride - 1, (End - Start) + (Stride - 1) ==
        // End.
        return false;
      }
      return true;
    }();

    const SCEV *Delta = getMinusSCEV(End, Start);
    if (!MayAddOverflow) {
      // floor((D + (S - 1)) / S)
      // We prefer this formulation if it's legal because it's fewer
      // operations.
      BECount =
          getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
    } else {
      BECount = getUDivCeilSCEV(Delta, Stride);
    }
  }

  const SCEV *MaxBECount;
  bool MaxOrZero = false;
  if (isa<SCEVConstant>(BECount)) {
    MaxBECount = BECount;
  } else if (BECountIfBackedgeTaken &&
             isa<SCEVConstant>(BECountIfBackedgeTaken)) {
    // If we know exactly how many times the backedge will be taken if it's
    // taken at least once, then the backedge count will either be that or
    // zero.
    MaxBECount = BECountIfBackedgeTaken;
    MaxOrZero = true;
  } else {
    MaxBECount = computeMaxBECountForLT(
        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
  }

  if (isa<SCEVCouldNotCompute>(MaxBECount) &&
      !isa<SCEVCouldNotCompute>(BECount))
    MaxBECount = getConstant(getUnsignedRangeMax(BECount));

  return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates);
}
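
// Worked example (hypothetical loop): for an addrec {5,+,2}<nuw> tested
// with ULT against an invariant bound of 12, End = umax(12, 5) = 12 and the
// backedge-taken count is ceil((12 - 5) / 2) = 4: the IV takes the values
// 5, 7, 9, 11 before the comparison first fails at 13.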
12783 ScalarEvolution::ExitLimit
12784 ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
12785 const Loop *L, bool IsSigned,
12786 bool ControlsExit, bool AllowPredicates) {
12787 SmallPtrSet<const SCEVPredicate *, 4> Predicates;
12788 // We handle only IV > Invariant
12789 if (!isLoopInvariant(RHS, L))
12790 return getCouldNotCompute();
12792 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
12793 if (!IV && AllowPredicates)
12794 // Try to make this an AddRec using runtime tests, in the first X
12795 // iterations of this loop, where X is the SCEV expression found by the
12796 // algorithm below.
12797 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
12799 // Avoid weird loops
12800 if (!IV || IV->getLoop() != L || !IV->isAffine())
12801 return getCouldNotCompute();
12803 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
12804 bool NoWrap = ControlsExit && IV->getNoWrapFlags(WrapType);
12805 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
12807 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
12809 // Avoid negative or zero stride values
12810 if (!isKnownPositive(Stride))
12811 return getCouldNotCompute();
12813 // Avoid proven overflow cases: this will ensure that the backedge taken count
12814 // will not generate any unsigned overflow. Relaxed no-overflow conditions
12815 // exploit NoWrapFlags, allowing to optimize in presence of undefined
12816 // behaviors like the case of C language.
12817 if (!Stride->isOne() && !NoWrap)
12818 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
12819 return getCouldNotCompute();
12821 const SCEV *Start = IV->getStart();
12822 const SCEV *End = RHS;
12823 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
12824 // If we know that Start >= RHS in the context of loop, then we know that
12825 // min(RHS, Start) = RHS at this point.
12826 if (isLoopEntryGuardedByCond(
12827 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
12828 End = RHS;
12829 else
12830 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
12831 }
12833 if (Start->getType()->isPointerTy()) {
12834 Start = getLosslessPtrToIntExpr(Start);
12835 if (isa<SCEVCouldNotCompute>(Start))
12836 return Start;
12837 }
12838 if (End->getType()->isPointerTy()) {
12839 End = getLosslessPtrToIntExpr(End);
12840 if (isa<SCEVCouldNotCompute>(End))
12841 return End;
12842 }
12844 // Compute ((Start - End) + (Stride - 1)) / Stride.
12845 // FIXME: This can overflow. Holding off on fixing this for now;
12846 // howManyGreaterThans will hopefully be gone soon.
12847 const SCEV *One = getOne(Stride->getType());
12848 const SCEV *BECount = getUDivExpr(
12849 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
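// For example (illustrative): a loop whose IV is {10,+,-3} and which exits
// once IV > 2 fails has Start = 10, End = 2 and Stride = 3, giving a
// backedge-taken count of ((10 - 2) + (3 - 1)) /u 3 == 3; the backedge is
// taken after the iterations where IV is 10, 7 and 4.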
12851 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
12852 : getUnsignedRangeMax(Start);
12854 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
12855 : getUnsignedRangeMin(Stride);
12857 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
12858 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
12859 : APInt::getMinValue(BitWidth) + (MinStride - 1);
12861 // Although End can be a MIN expression we estimate MinEnd considering only
12862 // the case End = RHS. This is safe because in the other case (Start - End)
12863 // is zero, leading to a zero maximum backedge taken count.
12864 APInt MinEnd =
12865 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
12866 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
12868 const SCEV *MaxBECount = isa<SCEVConstant>(BECount)
12869 ? BECount
12870 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
12871 getConstant(MinStride));
12873 if (isa<SCEVCouldNotCompute>(MaxBECount))
12874 MaxBECount = BECount;
12876 return ExitLimit(BECount, MaxBECount, false, Predicates);
12877 }
12879 const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
12880 ScalarEvolution &SE) const {
12881 if (Range.isFullSet()) // Infinite loop.
12882 return SE.getCouldNotCompute();
12884 // If the start is a non-zero constant, shift the range to simplify things.
12885 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
12886 if (!SC->getValue()->isZero()) {
12887 SmallVector<const SCEV *, 4> Operands(operands());
12888 Operands[0] = SE.getZero(SC->getType());
12889 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
12890 getNoWrapFlags(FlagNW));
12891 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
12892 return ShiftedAddRec->getNumIterationsInRange(
12893 Range.subtract(SC->getAPInt()), SE);
12894 // This is strange and shouldn't happen.
12895 return SE.getCouldNotCompute();
12896 }
12898 // The only time we can solve this is when we have all constant indices.
12899 // Otherwise, we cannot determine the overflow conditions.
12900 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
12901 return SE.getCouldNotCompute();
12903 // Okay at this point we know that all elements of the chrec are constants and
12904 // that the start element is zero.
12906 // First check to see if the range contains zero. If not, the first
12907 // iteration exits.
12908 unsigned BitWidth = SE.getTypeSizeInBits(getType());
12909 if (!Range.contains(APInt(BitWidth, 0)))
12910 return SE.getZero(getType());
12912 if (isAffine()) {
12913 // If this is an affine expression then we have this situation:
12914 // Solve {0,+,A} in Range === Ax in Range
12916 // We know that zero is in the range. If A is positive then we know that
12917 // the upper value of the range must be the first possible exit value.
12918 // If A is negative then the lower of the range is the last possible loop
12919 // value. Also note that we already checked for a full range.
12920 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
12921 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
12923 // The exit value should be (End+A)/A.
12924 APInt ExitVal = (End + A).udiv(A);
12925 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
12927 // Evaluate at the exit value. If we really did fall out of the valid
12928 // range, then we computed our trip count, otherwise wrap around or other
12929 // things must have happened.
12930 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
12931 if (Range.contains(Val->getValue()))
12932 return SE.getCouldNotCompute(); // Something strange happened
12934 // Ensure that the previous value is in the range.
12935 assert(Range.contains(
12936 EvaluateConstantChrecAtConstant(this,
12937 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
12938 "Linear scev computation is off in a bad way!");
12939 return SE.getConstant(ExitValue);
12940 }
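// A worked instance of the affine case (illustrative only): for {0,+,3}
// and Range = [0, 10), A == 3 and End == 9, so ExitVal == (9 + 3) /u 3 == 4.
// Evaluating {0,+,3} at 4 yields 12, which is outside the range, while the
// previous value 9 is still inside, so exactly 4 iterations stay in range.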
12942 if (isQuadratic()) {
12943 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
12944 return SE.getConstant(*S);
12945 }
12947 return SE.getCouldNotCompute();
12948 }
12950 const SCEVAddRecExpr *
12951 SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
12952 assert(getNumOperands() > 1 && "AddRec with zero step?");
12953 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
12954 // but in this case we cannot guarantee that the value returned will be an
12955 // AddRec because SCEV does not have a fixed point where it stops
12956 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
12957 // may happen if we reach arithmetic depth limit while simplifying. So we
12958 // construct the returned value explicitly.
12959 SmallVector<const SCEV *, 3> Ops;
12960 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
12961 // (this + Step) is {A+B,+,B+C,+...,+,N}.
12962 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
12963 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
12964 // We know that the last operand is not a constant zero (otherwise it would
12965 // have been popped out earlier). This guarantees us that if the result has
12966 // the same last operand, then it will also not be popped out, meaning that
12967 // the returned value will be an AddRec.
12968 const SCEV *Last = getOperand(getNumOperands() - 1);
12969 assert(!Last->isZero() && "Recurrency with zero step?");
12970 Ops.push_back(Last);
12971 return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
12972 SCEV::FlagAnyWrap));
12973 }
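// For illustration: the post-increment form of {0,+,1,+,2} is
// {0+1,+,1+2,+,2} == {1,+,3,+,2}; the last operand is carried over
// unchanged, which is what keeps the returned expression an AddRec.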
12975 // Return true when S contains at least an undef value.
12976 bool ScalarEvolution::containsUndefs(const SCEV *S) const {
12977 return SCEVExprContains(S, [](const SCEV *S) {
12978 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
12979 return isa<UndefValue>(SU->getValue());
12980 return false;
12981 });
12982 }
12984 // Return true when S contains a value that is a nullptr.
12985 bool ScalarEvolution::containsErasedValue(const SCEV *S) const {
12986 return SCEVExprContains(S, [](const SCEV *S) {
12987 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
12988 return SU->getValue() == nullptr;
12989 return false;
12990 });
12991 }
12993 /// Return the size of an element read or written by Inst.
12994 const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
12995 Type *Ty;
12996 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
12997 Ty = Store->getValueOperand()->getType();
12998 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
12999 Ty = Load->getType();
13000 else
13001 return nullptr;
13003 Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty));
13004 return getSizeOfExpr(ETy, Ty);
13005 }
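// A hypothetical use (not part of this file): for "store i32 %v, i32* %p"
// this returns the SCEV for sizeof(i32), a constant 4 in the effective
// pointer-width integer type, which clients such as dependence analysis
// can divide byte distances by.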
13007 //===----------------------------------------------------------------------===//
13008 // SCEVCallbackVH Class Implementation
13009 //===----------------------------------------------------------------------===//
13011 void ScalarEvolution::SCEVCallbackVH::deleted() {
13012 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13013 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13014 SE->ConstantEvolutionLoopExitValue.erase(PN);
13015 SE->eraseValueFromMap(getValPtr());
13016 // this now dangles!
13017 }
13019 void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13020 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13022 // Forget all the expressions associated with users of the old value,
13023 // so that future queries will recompute the expressions using the new
13024 // value.
13025 Value *Old = getValPtr();
13026 SmallVector<User *, 16> Worklist(Old->users());
13027 SmallPtrSet<User *, 8> Visited;
13028 while (!Worklist.empty()) {
13029 User *U = Worklist.pop_back_val();
13030 // Deleting the Old value will cause this to dangle. Postpone
13031 // that until everything else is done.
13032 if (U == Old)
13033 continue;
13034 if (!Visited.insert(U).second)
13035 continue;
13036 if (PHINode *PN = dyn_cast<PHINode>(U))
13037 SE->ConstantEvolutionLoopExitValue.erase(PN);
13038 SE->eraseValueFromMap(U);
13039 llvm::append_range(Worklist, U->users());
13040 }
13041 // Delete the Old value.
13042 if (PHINode *PN = dyn_cast<PHINode>(Old))
13043 SE->ConstantEvolutionLoopExitValue.erase(PN);
13044 SE->eraseValueFromMap(Old);
13045 // this now dangles!
13046 }
13048 ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13049 : CallbackVH(V), SE(se) {}
13051 //===----------------------------------------------------------------------===//
13052 // ScalarEvolution Class Implementation
13053 //===----------------------------------------------------------------------===//
13055 ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
13056 AssumptionCache &AC, DominatorTree &DT,
13057 LoopInfo &LI)
13058 : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
13059 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13060 LoopDispositions(64), BlockDispositions(64) {
13061 // To use guards for proving predicates, we need to scan every instruction in
13062 // relevant basic blocks, and not just terminators. Doing this is a waste of
13063 // time if the IR does not actually contain any calls to
13064 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13066 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13067 // to _add_ guards to the module when there weren't any before, and wants
13068 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13069 // efficient in lieu of being smart in that rather obscure case.
13071 auto *GuardDecl = F.getParent()->getFunction(
13072 Intrinsic::getName(Intrinsic::experimental_guard));
13073 HasGuards = GuardDecl && !GuardDecl->use_empty();
13074 }
13076 ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
13077 : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
13078 LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13079 ValueExprMap(std::move(Arg.ValueExprMap)),
13080 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13081 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13082 PendingMerges(std::move(Arg.PendingMerges)),
13083 MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)),
13084 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13085 PredicatedBackedgeTakenCounts(
13086 std::move(Arg.PredicatedBackedgeTakenCounts)),
13087 BECountUsers(std::move(Arg.BECountUsers)),
13088 ConstantEvolutionLoopExitValue(
13089 std::move(Arg.ConstantEvolutionLoopExitValue)),
13090 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13091 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13092 LoopDispositions(std::move(Arg.LoopDispositions)),
13093 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13094 BlockDispositions(std::move(Arg.BlockDispositions)),
13095 SCEVUsers(std::move(Arg.SCEVUsers)),
13096 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13097 SignedRanges(std::move(Arg.SignedRanges)),
13098 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13099 UniquePreds(std::move(Arg.UniquePreds)),
13100 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13101 LoopUsers(std::move(Arg.LoopUsers)),
13102 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13103 FirstUnknown(Arg.FirstUnknown) {
13104 Arg.FirstUnknown = nullptr;
13105 }
13107 ScalarEvolution::~ScalarEvolution() {
13108 // Iterate through all the SCEVUnknown instances and call their
13109 // destructors, so that they release their references to their values.
13110 for (SCEVUnknown *U = FirstUnknown; U;) {
13111 SCEVUnknown *Tmp = U;
13112 U = U->Next;
13113 Tmp->~SCEVUnknown();
13114 }
13115 FirstUnknown = nullptr;
13117 ExprValueMap.clear();
13118 ValueExprMap.clear();
13120 BackedgeTakenCounts.clear();
13121 PredicatedBackedgeTakenCounts.clear();
13123 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13124 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13125 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13126 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13127 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13128 }
13130 bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
13131 return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
13132 }
13134 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13135 const Loop *L) {
13136 // Print all inner loops first
13137 for (Loop *I : *L)
13138 PrintLoopInfo(OS, SE, I);
13140 OS << "Loop ";
13141 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13142 OS << ": ";
13144 SmallVector<BasicBlock *, 8> ExitingBlocks;
13145 L->getExitingBlocks(ExitingBlocks);
13146 if (ExitingBlocks.size() != 1)
13147 OS << "<multiple exits> ";
13149 if (SE->hasLoopInvariantBackedgeTakenCount(L))
13150 OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L) << "\n";
13151 else
13152 OS << "Unpredictable backedge-taken count.\n";
13154 if (ExitingBlocks.size() > 1)
13155 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13156 OS << " exit count for " << ExitingBlock->getName() << ": "
13157 << *SE->getExitCount(L, ExitingBlock) << "\n";
13158 }
13160 OS << "Loop ";
13161 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13162 OS << ": ";
13164 if (!isa<SCEVCouldNotCompute>(SE->getConstantMaxBackedgeTakenCount(L))) {
13165 OS << "max backedge-taken count is " << *SE->getConstantMaxBackedgeTakenCount(L);
13166 if (SE->isBackedgeTakenCountMaxOrZero(L))
13167 OS << ", actual taken count either this or zero.";
13168 } else {
13169 OS << "Unpredictable max backedge-taken count. ";
13170 }
13172 OS << "\n"
13173 "Loop ";
13174 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13175 OS << ": ";
13177 SmallVector<const SCEVPredicate *, 4> Preds;
13178 auto PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13179 if (!isa<SCEVCouldNotCompute>(PBT)) {
13180 OS << "Predicated backedge-taken count is " << *PBT << "\n";
13181 OS << " Predicates:\n";
13182 for (auto *P : Preds)
13183 P->print(OS, 4);
13184 } else {
13185 OS << "Unpredictable predicated backedge-taken count. ";
13186 }
13188 OS << "\n";
13189 if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
13190 OS << "Loop ";
13191 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13192 OS << ": ";
13193 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13194 }
13195 }
13197 static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) {
13198 switch (LD) {
13199 case ScalarEvolution::LoopVariant:
13200 return "Variant";
13201 case ScalarEvolution::LoopInvariant:
13202 return "Invariant";
13203 case ScalarEvolution::LoopComputable:
13204 return "Computable";
13205 }
13206 llvm_unreachable("Unknown ScalarEvolution::LoopDisposition kind!");
13207 }
13209 void ScalarEvolution::print(raw_ostream &OS) const {
13210 // ScalarEvolution's implementation of the print method is to print
13211 // out SCEV values of all instructions that are interesting. Doing
13212 // this potentially causes it to create new SCEV objects though,
13213 // which technically conflicts with the const qualifier. This isn't
13214 // observable from outside the class though, so casting away the
13215 // const isn't dangerous.
13216 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13218 if (ClassifyExpressions) {
13219 OS << "Classifying expressions for: ";
13220 F.printAsOperand(OS, /*PrintType=*/false);
13222 for (Instruction &I : instructions(F))
13223 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
13226 const SCEV *SV = SE.getSCEV(&I);
13228 if (!isa<SCEVCouldNotCompute>(SV)) {
13230 SE.getUnsignedRange(SV).print(OS);
13232 SE.getSignedRange(SV).print(OS);
13235 const Loop *L = LI.getLoopFor(I.getParent());
13237 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
13241 if (!isa<SCEVCouldNotCompute>(AtUse)) {
13243 SE.getUnsignedRange(AtUse).print(OS);
13245 SE.getSignedRange(AtUse).print(OS);
13250 OS << "\t\t" "Exits: ";
13251 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
13252 if (!SE.isLoopInvariant(ExitValue, L)) {
13253 OS << "<<Unknown>>";
13259 for (auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
13261 OS << "\t\t" "LoopDispositions: { ";
13267 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13268 OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, Iter));
13271 for (auto *InnerL : depth_first(L)) {
13275 OS << "\t\t" "LoopDispositions: { ";
13281 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13282 OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, InnerL));
13292 OS << "Determining loop execution counts for: ";
13293 F.printAsOperand(OS, /*PrintType=*/false);
13296 PrintLoopInfo(OS, &SE, I);
13299 ScalarEvolution::LoopDisposition
13300 ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
13301 auto &Values = LoopDispositions[S];
13302 for (auto &V : Values) {
13303 if (V.getPointer() == L)
13304 return V.getInt();
13305 }
13306 Values.emplace_back(L, LoopVariant);
13307 LoopDisposition D = computeLoopDisposition(S, L);
13308 auto &Values2 = LoopDispositions[S];
13309 for (auto &V : llvm::reverse(Values2)) {
13310 if (V.getPointer() == L) {
13311 V.setInt(D);
13312 break;
13313 }
13314 }
13315 return D;
13316 }
13318 ScalarEvolution::LoopDisposition
13319 ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
13320 switch (S->getSCEVType()) {
13321 case scConstant:
13322 return LoopInvariant;
13323 case scPtrToInt:
13324 case scTruncate:
13325 case scZeroExtend:
13326 case scSignExtend:
13327 return getLoopDisposition(cast<SCEVCastExpr>(S)->getOperand(), L);
13328 case scAddRecExpr: {
13329 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13331 // If L is the addrec's loop, it's computable.
13332 if (AR->getLoop() == L)
13333 return LoopComputable;
13335 // Add recurrences are never invariant in the function-body (null loop).
13336 if (!L)
13337 return LoopVariant;
13339 // Everything that is not defined at loop entry is variant.
13340 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
13341 return LoopVariant;
13342 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
13343 " dominate the contained loop's header?");
13345 // This recurrence is invariant w.r.t. L if AR's loop contains L.
13346 if (AR->getLoop()->contains(L))
13347 return LoopInvariant;
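// For intuition (an illustrative case): with nested loops L1 { L2 { ... } },
// the AddRec {0,+,1}<L1> does not change while L2 executes, so it is
// invariant with respect to L2 even though it varies across iterations
// of L1.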
13349 // This recurrence is variant w.r.t. L if any of its operands
13350 // are variant.
13351 for (auto *Op : AR->operands())
13352 if (!isLoopInvariant(Op, L))
13353 return LoopVariant;
13355 // Otherwise it's loop-invariant.
13356 return LoopInvariant;
13357 }
13359 case scAddExpr:
13360 case scMulExpr:
13361 case scUMaxExpr:
13362 case scSMaxExpr:
13363 case scUMinExpr:
13364 case scSequentialUMinExpr: {
13365 bool HasVarying = false;
13366 for (auto *Op : cast<SCEVNAryExpr>(S)->operands()) {
13367 LoopDisposition D = getLoopDisposition(Op, L);
13368 if (D == LoopVariant)
13369 return LoopVariant;
13370 if (D == LoopComputable)
13371 HasVarying = true;
13372 }
13373 return HasVarying ? LoopComputable : LoopInvariant;
13374 }
13375 case scUDivExpr: {
13376 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
13377 LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L);
13378 if (LD == LoopVariant)
13379 return LoopVariant;
13380 LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L);
13381 if (RD == LoopVariant)
13382 return LoopVariant;
13383 return (LD == LoopInvariant && RD == LoopInvariant) ?
13384 LoopInvariant : LoopComputable;
13385 }
13386 case scUnknown:
13387 // All non-instruction values are loop invariant. All instructions are loop
13388 // invariant if they are not contained in the specified loop.
13389 // Instructions are never considered invariant in the function body
13390 // (null loop) because they are defined within the "loop".
13391 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
13392 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
13393 return LoopInvariant;
13394 case scCouldNotCompute:
13395 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
13396 }
13397 llvm_unreachable("Unknown SCEV kind!");
13398 }
13400 bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
13401 return getLoopDisposition(S, L) == LoopInvariant;
13402 }
13404 bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
13405 return getLoopDisposition(S, L) == LoopComputable;
13406 }
13408 ScalarEvolution::BlockDisposition
13409 ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13410 auto &Values = BlockDispositions[S];
13411 for (auto &V : Values) {
13412 if (V.getPointer() == BB)
13413 return V.getInt();
13414 }
13415 Values.emplace_back(BB, DoesNotDominateBlock);
13416 BlockDisposition D = computeBlockDisposition(S, BB);
13417 auto &Values2 = BlockDispositions[S];
13418 for (auto &V : llvm::reverse(Values2)) {
13419 if (V.getPointer() == BB) {
13420 V.setInt(D);
13421 break;
13422 }
13423 }
13424 return D;
13425 }
13427 ScalarEvolution::BlockDisposition
13428 ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13429 switch (S->getSCEVType()) {
13430 case scConstant:
13431 return ProperlyDominatesBlock;
13432 case scPtrToInt:
13433 case scTruncate:
13434 case scZeroExtend:
13435 case scSignExtend:
13436 return getBlockDisposition(cast<SCEVCastExpr>(S)->getOperand(), BB);
13437 case scAddRecExpr: {
13438 // This uses a "dominates" query instead of "properly dominates" query
13439 // to test for proper dominance too, because the instruction which
13440 // produces the addrec's value is a PHI, and a PHI effectively properly
13441 // dominates its entire containing block.
13442 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13443 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
13444 return DoesNotDominateBlock;
13446 // Fall through into SCEVNAryExpr handling.
13447 LLVM_FALLTHROUGH;
13448 }
13450 case scAddExpr:
13451 case scMulExpr:
13452 case scUMaxExpr:
13453 case scSMaxExpr:
13454 case scUMinExpr:
13455 case scSequentialUMinExpr: {
13456 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
13457 bool Proper = true;
13458 for (const SCEV *NAryOp : NAry->operands()) {
13459 BlockDisposition D = getBlockDisposition(NAryOp, BB);
13460 if (D == DoesNotDominateBlock)
13461 return DoesNotDominateBlock;
13462 if (D == DominatesBlock)
13463 Proper = false;
13464 }
13465 return Proper ? ProperlyDominatesBlock : DominatesBlock;
13466 }
13467 case scUDivExpr: {
13468 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
13469 const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS();
13470 BlockDisposition LD = getBlockDisposition(LHS, BB);
13471 if (LD == DoesNotDominateBlock)
13472 return DoesNotDominateBlock;
13473 BlockDisposition RD = getBlockDisposition(RHS, BB);
13474 if (RD == DoesNotDominateBlock)
13475 return DoesNotDominateBlock;
13476 return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ?
13477 ProperlyDominatesBlock : DominatesBlock;
13478 }
13479 case scUnknown:
13480 if (Instruction *I =
13481 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
13482 if (I->getParent() == BB)
13483 return DominatesBlock;
13484 if (DT.properlyDominates(I->getParent(), BB))
13485 return ProperlyDominatesBlock;
13486 return DoesNotDominateBlock;
13488 return ProperlyDominatesBlock;
13489 case scCouldNotCompute:
13490 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
13491 }
13492 llvm_unreachable("Unknown SCEV kind!");
13493 }
13495 bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
13496 return getBlockDisposition(S, BB) >= DominatesBlock;
13497 }
13499 bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
13500 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
13501 }
13503 bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
13504 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
13505 }
13507 void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
13508 bool Predicated) {
13509 auto &BECounts =
13510 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
13511 auto It = BECounts.find(L);
13512 if (It != BECounts.end()) {
13513 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
13514 if (!isa<SCEVConstant>(ENT.ExactNotTaken)) {
13515 auto UserIt = BECountUsers.find(ENT.ExactNotTaken);
13516 assert(UserIt != BECountUsers.end());
13517 UserIt->second.erase({L, Predicated});
13518 }
13519 }
13520 BECounts.erase(It);
13521 }
13522 }
13524 void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
13525 SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
13526 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
13528 while (!Worklist.empty()) {
13529 const SCEV *Curr = Worklist.pop_back_val();
13530 auto Users = SCEVUsers.find(Curr);
13531 if (Users != SCEVUsers.end())
13532 for (auto *User : Users->second)
13533 if (ToForget.insert(User).second)
13534 Worklist.push_back(User);
13535 }
13537 for (auto *S : ToForget)
13538 forgetMemoizedResultsImpl(S);
13540 for (auto I = PredicatedSCEVRewrites.begin();
13541 I != PredicatedSCEVRewrites.end();) {
13542 std::pair<const SCEV *, const Loop *> Entry = I->first;
13543 if (ToForget.count(Entry.first))
13544 PredicatedSCEVRewrites.erase(I++);
13545 else
13546 ++I;
13547 }
13548 }
13550 void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
13551 LoopDispositions.erase(S);
13552 BlockDispositions.erase(S);
13553 UnsignedRanges.erase(S);
13554 SignedRanges.erase(S);
13555 HasRecMap.erase(S);
13556 MinTrailingZerosCache.erase(S);
13558 auto ExprIt = ExprValueMap.find(S);
13559 if (ExprIt != ExprValueMap.end()) {
13560 for (Value *V : ExprIt->second) {
13561 auto ValueIt = ValueExprMap.find_as(V);
13562 if (ValueIt != ValueExprMap.end())
13563 ValueExprMap.erase(ValueIt);
13564 }
13565 ExprValueMap.erase(ExprIt);
13566 }
13568 auto ScopeIt = ValuesAtScopes.find(S);
13569 if (ScopeIt != ValuesAtScopes.end()) {
13570 for (const auto &Pair : ScopeIt->second)
13571 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
13572 erase_value(ValuesAtScopesUsers[Pair.second],
13573 std::make_pair(Pair.first, S));
13574 ValuesAtScopes.erase(ScopeIt);
13575 }
13577 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
13578 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
13579 for (const auto &Pair : ScopeUserIt->second)
13580 erase_value(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
13581 ValuesAtScopesUsers.erase(ScopeUserIt);
13582 }
13584 auto BEUsersIt = BECountUsers.find(S);
13585 if (BEUsersIt != BECountUsers.end()) {
13586 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
13587 auto Copy = BEUsersIt->second;
13588 for (const auto &Pair : Copy)
13589 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
13590 BECountUsers.erase(BEUsersIt);
13591 }
13592 }
13594 void
13595 ScalarEvolution::getUsedLoops(const SCEV *S,
13596 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
13597 struct FindUsedLoops {
13598 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
13599 : LoopsUsed(LoopsUsed) {}
13600 SmallPtrSetImpl<const Loop *> &LoopsUsed;
13601 bool follow(const SCEV *S) {
13602 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
13603 LoopsUsed.insert(AR->getLoop());
13604 return true;
13605 }
13607 bool isDone() const { return false; }
13608 };
13610 FindUsedLoops F(LoopsUsed);
13611 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
13612 }
13614 void ScalarEvolution::getReachableBlocks(
13615 SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
13616 SmallVector<BasicBlock *> Worklist;
13617 Worklist.push_back(&F.getEntryBlock());
13618 while (!Worklist.empty()) {
13619 BasicBlock *BB = Worklist.pop_back_val();
13620 if (!Reachable.insert(BB).second)
13621 continue;
13623 Value *Cond;
13624 BasicBlock *TrueBB, *FalseBB;
13625 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
13626 m_BasicBlock(FalseBB)))) {
13627 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
13628 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
13629 continue;
13630 }
13632 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
13633 const SCEV *L = getSCEV(Cmp->getOperand(0));
13634 const SCEV *R = getSCEV(Cmp->getOperand(1));
13635 if (isKnownPredicateViaConstantRanges(Cmp->getPredicate(), L, R)) {
13636 Worklist.push_back(TrueBB);
13637 continue;
13638 }
13639 if (isKnownPredicateViaConstantRanges(Cmp->getInversePredicate(), L,
13640 R)) {
13641 Worklist.push_back(FalseBB);
13642 continue;
13643 }
13644 }
13645 }
13647 append_range(Worklist, successors(BB));
13648 }
13649 }
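// For illustration (hypothetical IR): given "br i1 true, label %a, label %b"
// only %a is pushed, and a branch on "icmp ult i32 0, 4", provably true via
// constant ranges, likewise follows only its true edge, so the false
// successor never enters the reachable set.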
13651 void ScalarEvolution::verify() const {
13652 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13653 ScalarEvolution SE2(F, TLI, AC, DT, LI);
13655 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
13657 // Map's SCEV expressions from one ScalarEvolution "universe" to another.
13658 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
13659 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
13661 const SCEV *visitConstant(const SCEVConstant *Constant) {
13662 return SE.getConstant(Constant->getAPInt());
13663 }
13665 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
13666 return SE.getUnknown(Expr->getValue());
13667 }
13669 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
13670 return SE.getCouldNotCompute();
13671 }
13672 };
13674 SCEVMapper SCM(SE2);
13675 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
13676 SE2.getReachableBlocks(ReachableBlocks, F);
13678 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
13679 if (containsUndefs(Old) || containsUndefs(New)) {
13680 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
13681 // not propagate undef aggressively). This means we can (and do) fail
13682 // verification in cases where a transform makes a value go from "undef"
13683 // to "undef+1" (say). The transform is fine, since in both cases the
13684 // result is "undef", but SCEV thinks the value increased by 1.
13685 return nullptr;
13686 }
13688 // Unless VerifySCEVStrict is set, we only compare constant deltas.
13689 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
13690 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
13691 return nullptr;
13692 return Delta;
13693 };
13696 while (!LoopStack.empty()) {
13697 auto *L = LoopStack.pop_back_val();
13698 llvm::append_range(LoopStack, *L);
13700 // Only verify BECounts in reachable loops. For an unreachable loop,
13701 // any BECount is legal.
13702 if (!ReachableBlocks.contains(L->getHeader()))
13703 continue;
13705 // Only verify cached BECounts. Computing new BECounts may change the
13706 // results of subsequent SCEV uses.
13707 auto It = BackedgeTakenCounts.find(L);
13708 if (It == BackedgeTakenCounts.end())
13709 continue;
13711 auto *CurBECount =
13712 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
13713 auto *NewBECount = SE2.getBackedgeTakenCount(L);
13715 if (CurBECount == SE2.getCouldNotCompute() ||
13716 NewBECount == SE2.getCouldNotCompute()) {
13717 // NB! This situation is legal, but is very suspicious -- whatever pass
13718 // change the loop to make a trip count go from could not compute to
13719 // computable or vice-versa *should have* invalidated SCEV. However, we
13720 // choose not to assert here (for now) since we don't want false
13721 // positives.
13722 continue;
13723 }
13725 if (SE.getTypeSizeInBits(CurBECount->getType()) >
13726 SE.getTypeSizeInBits(NewBECount->getType()))
13727 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
13728 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
13729 SE.getTypeSizeInBits(NewBECount->getType()))
13730 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
13732 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
13733 if (Delta && !Delta->isZero()) {
13734 dbgs() << "Trip Count for " << *L << " Changed!\n";
13735 dbgs() << "Old: " << *CurBECount << "\n";
13736 dbgs() << "New: " << *NewBECount << "\n";
13737 dbgs() << "Delta: " << *Delta << "\n";
13738 std::abort();
13739 }
13740 }
13742 // Collect all valid loops currently in LoopInfo.
13743 SmallPtrSet<Loop *, 32> ValidLoops;
13744 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
13745 while (!Worklist.empty()) {
13746 Loop *L = Worklist.pop_back_val();
13747 if (ValidLoops.insert(L).second)
13748 Worklist.append(L->begin(), L->end());
13749 }
13750 for (auto &KV : ValueExprMap) {
13752 // Check for SCEV expressions referencing invalid/deleted loops.
13753 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
13754 assert(ValidLoops.contains(AR->getLoop()) &&
13755 "AddRec references invalid loop");
13759 // Check that the value is also part of the reverse map.
13760 auto It = ExprValueMap.find(KV.second);
13761 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
13762 dbgs() << "Value " << *KV.first
13763 << " is in ValueExprMap but not in ExprValueMap\n";
13764 std::abort();
13765 }
13767 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
13768 if (!ReachableBlocks.contains(I->getParent()))
13769 continue;
13770 const SCEV *OldSCEV = SCM.visit(KV.second);
13771 const SCEV *NewSCEV = SE2.getSCEV(I);
13772 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
13773 if (Delta && !Delta->isZero()) {
13774 dbgs() << "SCEV for value " << *I << " changed!\n"
13775 << "Old: " << *OldSCEV << "\n"
13776 << "New: " << *NewSCEV << "\n"
13777 << "Delta: " << *Delta << "\n";
13778 std::abort();
13779 }
13780 }
13781 }
13783 for (const auto &KV : ExprValueMap) {
13784 for (Value *V : KV.second) {
13785 auto It = ValueExprMap.find_as(V);
13786 if (It == ValueExprMap.end()) {
13787 dbgs() << "Value " << *V
13788 << " is in ExprValueMap but not in ValueExprMap\n";
13789 std::abort();
13790 }
13791 if (It->second != KV.first) {
13792 dbgs() << "Value " << *V << " mapped to " << *It->second
13793 << " rather than " << *KV.first << "\n";
13794 std::abort();
13795 }
13796 }
13797 }
13799 // Verify integrity of SCEV users.
13800 for (const auto &S : UniqueSCEVs) {
13801 SmallVector<const SCEV *, 4> Ops;
13802 collectUniqueOps(&S, Ops);
13803 for (const auto *Op : Ops) {
13804 // We do not store dependencies of constants.
13805 if (isa<SCEVConstant>(Op))
13806 continue;
13807 auto It = SCEVUsers.find(Op);
13808 if (It != SCEVUsers.end() && It->second.count(&S))
13809 continue;
13810 dbgs() << "Use of operand " << *Op << " by user " << S
13811 << " is not being tracked!\n";
13812 std::abort();
13813 }
13814 }
13816 // Verify integrity of ValuesAtScopes users.
13817 for (const auto &ValueAndVec : ValuesAtScopes) {
13818 const SCEV *Value = ValueAndVec.first;
13819 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
13820 const Loop *L = LoopAndValueAtScope.first;
13821 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
13822 if (!isa<SCEVConstant>(ValueAtScope)) {
13823 auto It = ValuesAtScopesUsers.find(ValueAtScope);
13824 if (It != ValuesAtScopesUsers.end() &&
13825 is_contained(It->second, std::make_pair(L, Value)))
13826 continue;
13827 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
13828 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
13829 std::abort();
13830 }
13831 }
13832 }
13834 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
13835 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
13836 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
13837 const Loop *L = LoopAndValue.first;
13838 const SCEV *Value = LoopAndValue.second;
13839 assert(!isa<SCEVConstant>(Value));
13840 auto It = ValuesAtScopes.find(Value);
13841 if (It != ValuesAtScopes.end() &&
13842 is_contained(It->second, std::make_pair(L, ValueAtScope)))
13843 continue;
13844 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
13845 << *ValueAtScope << " missing in ValuesAtScopes\n";
13846 std::abort();
13847 }
13848 }
13849 }
13850 // Verify integrity of BECountUsers.
13851 auto VerifyBECountUsers = [&](bool Predicated) {
13852 auto &BECounts =
13853 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
13854 for (const auto &LoopAndBEInfo : BECounts) {
13855 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
13856 if (!isa<SCEVConstant>(ENT.ExactNotTaken)) {
13857 auto UserIt = BECountUsers.find(ENT.ExactNotTaken);
13858 if (UserIt != BECountUsers.end() &&
13859 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
13860 continue;
13861 dbgs() << "Value " << *ENT.ExactNotTaken << " for loop "
13862 << *LoopAndBEInfo.first << " missing from BECountUsers\n";
13863 std::abort();
13864 }
13865 }
13866 }
13867 };
13868 VerifyBECountUsers(/* Predicated */ false);
13869 VerifyBECountUsers(/* Predicated */ true);
13870 }
13872 bool ScalarEvolution::invalidate(
13873 Function &F, const PreservedAnalyses &PA,
13874 FunctionAnalysisManager::Invalidator &Inv) {
13875 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
13876 // of its dependencies is invalidated.
13877 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
13878 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
13879 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
13880 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
13881 Inv.invalidate<LoopAnalysis>(F, PA);
13882 }
13884 AnalysisKey ScalarEvolutionAnalysis::Key;
13886 ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
13887 FunctionAnalysisManager &AM) {
13888 return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F),
13889 AM.getResult<AssumptionAnalysis>(F),
13890 AM.getResult<DominatorTreeAnalysis>(F),
13891 AM.getResult<LoopAnalysis>(F));
13892 }
13894 PreservedAnalyses
13895 ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
13896 AM.getResult<ScalarEvolutionAnalysis>(F).verify();
13897 return PreservedAnalyses::all();
13898 }
13900 PreservedAnalyses
13901 ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
13902 // For compatibility with opt's -analyze feature under legacy pass manager
13903 // which was not ported to NPM. This keeps tests using
13904 // update_analyze_test_checks.py working.
13905 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
13906 << F.getName() << "':\n";
13907 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
13908 return PreservedAnalyses::all();
13909 }
13911 INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution",
13912 "Scalar Evolution Analysis", false, true)
13913 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
13914 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
13915 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
13916 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
13917 INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
13918 "Scalar Evolution Analysis", false, true)
13920 char ScalarEvolutionWrapperPass::ID = 0;
13922 ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {
13923 initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry());
13924 }
13926 bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
13927 SE.reset(new ScalarEvolution(
13928 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
13929 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
13930 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
13931 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
13932 return false;
13933 }
13935 void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
13937 void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
13938 SE->print(OS);
13939 }
13941 void ScalarEvolutionWrapperPass::verifyAnalysis() const {
13942 if (!VerifySCEV)
13943 return;
13945 SE->verify();
13946 }
13948 void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
13949 AU.setPreservesAll();
13950 AU.addRequiredTransitive<AssumptionCacheTracker>();
13951 AU.addRequiredTransitive<LoopInfoWrapperPass>();
13952 AU.addRequiredTransitive<DominatorTreeWrapperPass>();
13953 AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
13954 }
13956 const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
13957 const SCEV *RHS) {
13958 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
13959 }
13961 const SCEVPredicate *
13962 ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
13963 const SCEV *LHS, const SCEV *RHS) {
13964 FoldingSetNodeID ID;
13965 assert(LHS->getType() == RHS->getType() &&
13966 "Type mismatch between LHS and RHS");
13967 // Unique this node based on the arguments
13968 ID.AddInteger(SCEVPredicate::P_Compare);
13969 ID.AddInteger(Pred);
13970 ID.AddPointer(LHS);
13971 ID.AddPointer(RHS);
13972 void *IP = nullptr;
13973 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
13974 return S;
13975 SCEVComparePredicate *Eq = new (SCEVAllocator)
13976 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
13977 UniquePreds.InsertNode(Eq, IP);
13978 return Eq;
13979 }
13981 const SCEVPredicate *ScalarEvolution::getWrapPredicate(
13982 const SCEVAddRecExpr *AR,
13983 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
13984 FoldingSetNodeID ID;
13985 // Unique this node based on the arguments
13986 ID.AddInteger(SCEVPredicate::P_Wrap);
13987 ID.AddPointer(AR);
13988 ID.AddInteger(AddedFlags);
13989 void *IP = nullptr;
13990 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
13991 return S;
13992 auto *OF = new (SCEVAllocator)
13993 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
13994 UniquePreds.InsertNode(OF, IP);
13995 return OF;
13996 }
13998 namespace {
14000 class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14001 public:
14003 /// Rewrites \p S in the context of a loop L and the SCEV predication
14004 /// infrastructure.
14006 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14007 /// equivalences present in \p Pred.
14009 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14010 /// \p NewPreds such that the result will be an AddRecExpr.
14011 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14012 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
14013 const SCEVPredicate *Pred) {
14014 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14015 return Rewriter.visit(S);
14016 }
14018 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14019 if (Pred) {
14020 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14021 for (auto *Pred : U->getPredicates())
14022 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14023 if (IPred->getLHS() == Expr &&
14024 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14025 return IPred->getRHS();
14026 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14027 if (IPred->getLHS() == Expr &&
14028 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14029 return IPred->getRHS();
14030 }
14031 }
14032 return convertToAddRecWithPreds(Expr);
14033 }
14035 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14036 const SCEV *Operand = visit(Expr->getOperand());
14037 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14038 if (AR && AR->getLoop() == L && AR->isAffine()) {
14039 // This couldn't be folded because the operand didn't have the nuw
14040 // flag. Add the nusw flag as an assumption that we could make.
14041 const SCEV *Step = AR->getStepRecurrence(SE);
14042 Type *Ty = Expr->getType();
14043 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14044 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14045 SE.getSignExtendExpr(Step, Ty), L,
14046 AR->getNoWrapFlags());
14047 }
14048 return SE.getZeroExtendExpr(Operand, Expr->getType());
14049 }
14051 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14052 const SCEV *Operand = visit(Expr->getOperand());
14053 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14054 if (AR && AR->getLoop() == L && AR->isAffine()) {
14055 // This couldn't be folded because the operand didn't have the nsw
14056 // flag. Add the nssw flag as an assumption that we could make.
14057 const SCEV *Step = AR->getStepRecurrence(SE);
14058 Type *Ty = Expr->getType();
14059 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14060 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14061 SE.getSignExtendExpr(Step, Ty), L,
14062 AR->getNoWrapFlags());
14063 }
14064 return SE.getSignExtendExpr(Operand, Expr->getType());
14065 }
14067 private:
14068 explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
14069 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
14070 const SCEVPredicate *Pred)
14071 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14073 bool addOverflowAssumption(const SCEVPredicate *P) {
14074 if (!NewPreds) {
14075 // Check if we've already made this assumption.
14076 return Pred && Pred->implies(P);
14077 }
14078 NewPreds->insert(P);
14079 return true;
14080 }
14082 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14083 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14084 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14085 return addOverflowAssumption(A);
14086 }
14088 // If \p Expr represents a PHINode, we try to see if it can be represented
14089 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14090 // to add this predicate as a runtime overflow check, we return the AddRec.
14091 // If \p Expr does not meet these conditions (is not a PHI node, or we
14092 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14093 // return \p Expr.
14094 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14095 if (!isa<PHINode>(Expr->getValue()))
14096 return Expr;
14097 Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14098 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14099 if (!PredicatedRewrite)
14100 return Expr;
14101 for (auto *P : PredicatedRewrite->second){
14102 // Wrap predicates from outer loops are not supported.
14103 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14104 if (L != WP->getExpr()->getLoop())
14105 return Expr;
14106 }
14107 if (!addOverflowAssumption(P))
14108 return Expr;
14109 }
14110 return PredicatedRewrite->first;
14111 }
14113 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
14114 const SCEVPredicate *Pred;
14115 const Loop *L;
14116 };
14118 } // end anonymous namespace
14120 const SCEV *
14121 ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
14122 const SCEVPredicate &Preds) {
14123 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
14124 }
14126 const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
14127 const SCEV *S, const Loop *L,
14128 SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
14129 SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
14130 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
14131 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
14132 if (!AddRec)
14133 return nullptr;
14136 // Since the transformation was successful, we can now transfer the SCEV
14137 // predicates.
14138 for (auto *P : TransformPreds)
14139 Preds.insert(P);
14141 return AddRec;
14142 }
14144 /// SCEV predicates
14145 SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
14146 SCEVPredicateKind Kind)
14147 : FastID(ID), Kind(Kind) {}
14149 SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
14150 const ICmpInst::Predicate Pred,
14151 const SCEV *LHS, const SCEV *RHS)
14152 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
14153 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
14154 assert(LHS != RHS && "LHS and RHS are the same SCEV");
14155 }
14157 bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
14158 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
14160 if (!Op)
14161 return false;
14163 if (Pred != ICmpInst::ICMP_EQ)
14164 return false;
14166 return Op->LHS == LHS && Op->RHS == RHS;
14167 }
14169 bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
14171 void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
14172 if (Pred == ICmpInst::ICMP_EQ)
14173 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
14174 else
14175 OS.indent(Depth) << "Compare predicate: " << *LHS
14176 << " " << CmpInst::getPredicateName(Pred) << ") "
14177 << *RHS << "\n";
14178 }
14181 SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
14182 const SCEVAddRecExpr *AR,
14183 IncrementWrapFlags Flags)
14184 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
14186 const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
14188 bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
14189 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14191 return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
14192 }
14194 bool SCEVWrapPredicate::isAlwaysTrue() const {
14195 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
14196 IncrementWrapFlags IFlags = Flags;
14198 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
14199 IFlags = clearFlags(IFlags, IncrementNSSW);
14201 return IFlags == IncrementAnyWrap;
14202 }
14204 void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
14205 OS.indent(Depth) << *getExpr() << " Added Flags: ";
14206 if (SCEVWrapPredicate::IncrementNUSW & getFlags())
14207 OS << "<nusw>";
14208 if (SCEVWrapPredicate::IncrementNSSW & getFlags())
14209 OS << "<nssw>";
14210 OS << "\n";
14211 }
14213 SCEVWrapPredicate::IncrementWrapFlags
14214 SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
14215 ScalarEvolution &SE) {
14216 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
14217 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
14219 // We can safely transfer the NSW flag as NSSW.
14220 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
14221 ImpliedFlags = IncrementNSSW;
14223 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
14224 // If the increment is positive, the SCEV NUW flag will also imply the
14225 // WrapPredicate NUSW flag.
14226 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
14227 if (Step->getValue()->getValue().isNonNegative())
14228 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
14229 }
14231 return ImpliedFlags;
14232 }
14234 /// Union predicates don't get cached so create a dummy set ID for it.
14235 SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
14236 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
14237 for (auto *P : Preds)
14238 add(P);
14239 }
14241 bool SCEVUnionPredicate::isAlwaysTrue() const {
14242 return all_of(Preds,
14243 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
14244 }
14246 bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
14247 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
14248 return all_of(Set->Preds,
14249 [this](const SCEVPredicate *I) { return this->implies(I); });
14251 return any_of(Preds,
14252 [N](const SCEVPredicate *I) { return I->implies(N); });
14253 }
14255 void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
14256 for (auto Pred : Preds)
14257 Pred->print(OS, Depth);
14258 }
14260 void SCEVUnionPredicate::add(const SCEVPredicate *N) {
14261 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
14262 for (auto Pred : Set->Preds)
14263 add(Pred);
14264 return;
14265 }
14267 Preds.push_back(N);
14268 }
14270 PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
14271 Loop &L)
14272 : SE(SE), L(L) {
14273 SmallVector<const SCEVPredicate*, 4> Empty;
14274 Preds = std::make_unique<SCEVUnionPredicate>(Empty);
14275 }
14277 void ScalarEvolution::registerUser(const SCEV *User,
14278 ArrayRef<const SCEV *> Ops) {
14279 for (auto *Op : Ops)
14280 // We do not expect that forgetting cached data for SCEVConstants will ever
14281 // open any prospects for sharpening or introduce any correctness issues,
14282 // so we don't bother storing their dependencies.
14283 if (!isa<SCEVConstant>(Op))
14284 SCEVUsers[Op].insert(User);
14285 }
14287 const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
14288 const SCEV *Expr = SE.getSCEV(V);
14289 RewriteEntry &Entry = RewriteMap[Expr];
14291 // If we already have an entry and the version matches, return it.
14292 if (Entry.second && Generation == Entry.first)
14293 return Entry.second;
14295 // We found an entry but it's stale. Rewrite the stale entry
14296 // according to the current predicate.
14297 if (Entry.second)
14298 Expr = Entry.second;
14300 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
14301 Entry = {Generation, NewSCEV};
14302 return NewSCEV;
14303 }
14306 const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
14307 if (!BackedgeCount) {
14308 SmallVector<const SCEVPredicate *, 4> Preds;
14309 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
14310 for (auto *P : Preds)
14311 addPredicate(*P);
14312 }
14313 return BackedgeCount;
14314 }
14316 void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
14317 if (Preds->implies(&Pred))
14318 return;
14320 auto &OldPreds = Preds->getPredicates();
14321 SmallVector<const SCEVPredicate*, 4> NewPreds(OldPreds.begin(), OldPreds.end());
14322 NewPreds.push_back(&Pred);
14323 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
14324 updateGeneration();
14325 }
14327 const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
14328 return *Preds;
14329 }
14331 void PredicatedScalarEvolution::updateGeneration() {
14332 // If the generation number wrapped recompute everything.
14333 if (++Generation == 0) {
14334 for (auto &II : RewriteMap) {
14335 const SCEV *Rewritten = II.second.second;
14336 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
14337 }
14338 }
14339 }
14341 void PredicatedScalarEvolution::setNoOverflow(
14342 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
14343 const SCEV *Expr = getSCEV(V);
14344 const auto *AR = cast<SCEVAddRecExpr>(Expr);
14346 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
14348 // Clear the statically implied flags.
14349 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
14350 addPredicate(*SE.getWrapPredicate(AR, Flags));
14352 auto II = FlagsMap.insert({V, Flags});
14353 if (!II.second)
14354 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
14355 }
14357 bool PredicatedScalarEvolution::hasNoOverflow(
14358 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
14359 const SCEV *Expr = getSCEV(V);
14360 const auto *AR = cast<SCEVAddRecExpr>(Expr);
14362 Flags = SCEVWrapPredicate::clearFlags(
14363 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
14365 auto II = FlagsMap.find(V);
14367 if (II != FlagsMap.end())
14368 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
14370 return Flags == SCEVWrapPredicate::IncrementAnyWrap;
14371 }
14373 const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
14374 const SCEV *Expr = this->getSCEV(V);
14375 SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
14376 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
14377 if (!New)
14378 return nullptr;
14381 for (auto *P : NewPreds)
14382 addPredicate(*P);
14384 RewriteMap[SE.getSCEV(V)] = {Generation, New};
14385 return New;
14386 }
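// A sketch of client usage (hypothetical): even when SE.getSCEV(V) is only
// a SCEVUnknown, getAsAddRec(V) may still produce an AddRec, with every
// overflow assumption it required recorded through addPredicate() above.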
14388 PredicatedScalarEvolution::PredicatedScalarEvolution(
14389 const PredicatedScalarEvolution &Init)
14390 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
14391 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
14392 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
14393 for (auto I : Init.FlagsMap)
14394 FlagsMap.insert(I);
14395 }
14397 void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
14398 // For each block.
14399 for (auto *BB : L.getBlocks())
14400 for (auto &I : *BB) {
14401 if (!SE.isSCEVable(I.getType()))
14402 continue;
14404 auto *Expr = SE.getSCEV(&I);
14405 auto II = RewriteMap.find(Expr);
14407 if (II == RewriteMap.end())
14408 continue;
14410 // Don't print things that are not interesting.
14411 if (II->second.second == Expr)
14412 continue;
14414 OS.indent(Depth) << "[PSE]" << I << ":\n";
14415 OS.indent(Depth + 2) << *Expr << "\n";
14416 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
14417 }
14418 }
14420 // Match the mathematical pattern A - (A / B) * B, where A and B can be
14421 // arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
14422 // for URem with constant power-of-2 second operands.
14423 // It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
14424 // 4, A / B becomes X / 8).
14425 bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
14426 const SCEV *&RHS) {
14427 // Try to match 'zext (trunc A to iB) to iY', which is used
14428 // for URem with constant power-of-2 second operands. Make sure the size of
14429 // the operand A matches the size of the whole expressions.
14430 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
14431 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
14432 LHS = Trunc->getOperand();
14433 // Bail out if the type of the LHS is larger than the type of the
14434 // expression for now.
14435 if (getTypeSizeInBits(LHS->getType()) >
14436 getTypeSizeInBits(Expr->getType()))
14437 return false;
14438 if (LHS->getType() != Expr->getType())
14439 LHS = getZeroExtendExpr(LHS, Expr->getType());
14440 RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
14441 << getTypeSizeInBits(Trunc->getType()));
14442 return true;
14443 }
14444 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
14445 if (Add == nullptr || Add->getNumOperands() != 2)
14446 return false;
14448 const SCEV *A = Add->getOperand(1);
14449 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
14451 if (Mul == nullptr)
14452 return false;
14454 const auto MatchURemWithDivisor = [&](const SCEV *B) {
14455 // (SomeExpr + (-(SomeExpr / B) * B)).
14456 if (Expr == getURemExpr(A, B)) {
14457 LHS = A;
14458 RHS = B;
14459 return true;
14460 }
14461 return false;
14462 };
14464 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
14465 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
14466 return MatchURemWithDivisor(Mul->getOperand(1)) ||
14467 MatchURemWithDivisor(Mul->getOperand(2));
14469 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
14470 if (Mul->getNumOperands() == 2)
14471 return MatchURemWithDivisor(Mul->getOperand(1)) ||
14472 MatchURemWithDivisor(Mul->getOperand(0)) ||
14473 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
14474 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
14475 return false;
14476 }
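// Illustrative matches (assuming an i32 %x): "zext (trunc %x to i3) to i32"
// is recognized as %x urem 8, and "%x + (-(%x /u %y) * %y)" is recognized
// as %x urem %y.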
14478 const SCEV *
14479 ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) {
14480 SmallVector<BasicBlock*, 16> ExitingBlocks;
14481 L->getExitingBlocks(ExitingBlocks);
  // Form an expression for the maximum exit count possible for this loop. We
  // merge the max and exact information to approximate a version of
  // getConstantMaxBackedgeTakenCount which isn't restricted to just constants.
  SmallVector<const SCEV*, 4> ExitCounts;
  for (BasicBlock *ExitingBB : ExitingBlocks) {
    const SCEV *ExitCount = getExitCount(L, ExitingBB);
    if (isa<SCEVCouldNotCompute>(ExitCount))
      ExitCount = getExitCount(L, ExitingBB,
                               ScalarEvolution::ConstantMaximum);
    if (!isa<SCEVCouldNotCompute>(ExitCount)) {
      assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
             "We should only have known counts for exiting blocks that "
             "dominate latch!");
      ExitCounts.push_back(ExitCount);
    }
  }
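  // Combining the per-exit counts below requires a common type;
  // getUMinFromMismatchedTypes zero-extends narrower counts to the widest
  // type first, so e.g. exit counts %n (i32) and %m (i16) combine to
  // (%n umin (zext i16 %m to i32)).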
  if (ExitCounts.empty())
    return getCouldNotCompute();
  return getUMinFromMismatchedTypes(ExitCounts);
}
/// A rewriter to replace SCEV expressions in Map with the corresponding entry
/// in the map. It skips AddRecExpr because we cannot guarantee that the
/// replacement is loop invariant in the loop of the AddRec.
///
/// At the moment only rewriting SCEVUnknown and SCEVZeroExtendExpr is
/// supported.
class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
  const DenseMap<const SCEV *, const SCEV *> &Map;

public:
  SCEVLoopGuardRewriter(ScalarEvolution &SE,
                        DenseMap<const SCEV *, const SCEV *> &M)
      : SCEVRewriteVisitor(SE), Map(M) {}

  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }

  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
    auto I = Map.find(Expr);
    if (I == Map.end())
      return Expr;
    return I->second;
  }

  const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
    auto I = Map.find(Expr);
    if (I == Map.end())
      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
          Expr);
    return I->second;
  }
};
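
// Apply information from loop guards to \p Expr. As a minimal illustration:
// if the loop is only entered when
//   icmp ult i32 %n, 16
// holds, then a use of %n inside the loop can be rewritten to (%n umin 15)
// (see the ICMP_ULT case in CollectCondition below).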
const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
  SmallVector<const SCEV *> ExprsToRewrite;
  auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
                              const SCEV *RHS,
                              DenseMap<const SCEV *, const SCEV *>
                                  &RewriteMap) {
    // WARNING: It is generally unsound to apply any wrap flags to the proposed
    // replacement SCEV which isn't directly implied by the structure of that
    // SCEV. In particular, using contextual facts to imply flags is *NOT*
    // legal. See the scoping rules for flags in the header to understand why.

    // If LHS is a constant, apply information to the other expression.
    if (isa<SCEVConstant>(LHS)) {
      std::swap(LHS, RHS);
      Predicate = CmpInst::getSwappedPredicate(Predicate);
    }

    // Check for a condition of the form (-C1 + X < C2). InstCombine will
    // create this form when combining two checks of the form (X u< C2 + C1)
    // and (X >= C1).
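    // Worked example (values chosen for illustration): for a guard
    //   (-4 + %x) u< 10
    // makeExactICmpRegion(ULT, 10) gives the range [0, 10); subtracting the
    // add constant -4 yields the range [4, 14) for %x, so %x is rewritten to
    // (4 umax (%x umin 13)).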
    auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap,
                                 &ExprsToRewrite]() {
      auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
      if (!AddExpr || AddExpr->getNumOperands() != 2)
        return false;

      auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
      auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
      auto *C2 = dyn_cast<SCEVConstant>(RHS);
      if (!C1 || !C2 || !LHSUnknown)
        return false;

      auto ExactRegion =
          ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
              .sub(C1->getAPInt());

      // Bail out, unless we have a non-wrapping, monotonic range.
      if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
        return false;
      auto I = RewriteMap.find(LHSUnknown);
      const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
      RewriteMap[LHSUnknown] = getUMaxExpr(
          getConstant(ExactRegion.getUnsignedMin()),
          getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax())));
      ExprsToRewrite.push_back(LHSUnknown);
      return true;
    };
    if (MatchRangeCheckIdiom())
      return;

    // If we have LHS == 0, check if LHS is computing a property of some
    // unknown SCEV %v which we can rewrite %v to express explicitly.
    const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
    if (Predicate == CmpInst::ICMP_EQ && RHSC &&
        RHSC->getValue()->isNullValue()) {
      // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
      // explicitly express that.
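      // For example, a guard ((%a urem 8) == 0) lets us rewrite %a as
      // ((%a /u 8) * 8), recording that %a is a multiple of 8.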
      const SCEV *URemLHS = nullptr;
      const SCEV *URemRHS = nullptr;
      if (matchURem(LHS, URemLHS, URemRHS)) {
        if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
          auto Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS);
          RewriteMap[LHSUnknown] = Multiple;
          ExprsToRewrite.push_back(LHSUnknown);
          return;
        }
      }
    }

    // Do not apply information for constants or if RHS contains an AddRec.
    if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
      return;

    // If RHS is SCEVUnknown, make sure the information is applied to it.
    if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
      std::swap(LHS, RHS);
      Predicate = CmpInst::getSwappedPredicate(Predicate);
    }

    // Limit to expressions that can be rewritten.
    if (!isa<SCEVUnknown>(LHS) && !isa<SCEVZeroExtendExpr>(LHS))
      return;

    // Check whether LHS has already been rewritten. In that case we want to
    // chain further rewrites onto the already rewritten value.
    auto I = RewriteMap.find(LHS);
    const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS;

    const SCEV *RewrittenRHS = nullptr;
    switch (Predicate) {
    case CmpInst::ICMP_ULT:
      RewrittenRHS =
          getUMinExpr(RewrittenLHS, getMinusSCEV(RHS, getOne(RHS->getType())));
      break;
    case CmpInst::ICMP_SLT:
      RewrittenRHS =
          getSMinExpr(RewrittenLHS, getMinusSCEV(RHS, getOne(RHS->getType())));
      break;
    case CmpInst::ICMP_ULE:
      RewrittenRHS = getUMinExpr(RewrittenLHS, RHS);
      break;
    case CmpInst::ICMP_SLE:
      RewrittenRHS = getSMinExpr(RewrittenLHS, RHS);
      break;
    case CmpInst::ICMP_UGT:
      RewrittenRHS =
          getUMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType())));
      break;
    case CmpInst::ICMP_SGT:
      RewrittenRHS =
          getSMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType())));
      break;
    case CmpInst::ICMP_UGE:
      RewrittenRHS = getUMaxExpr(RewrittenLHS, RHS);
      break;
    case CmpInst::ICMP_SGE:
      RewrittenRHS = getSMaxExpr(RewrittenLHS, RHS);
      break;
    case CmpInst::ICMP_EQ:
      if (isa<SCEVConstant>(RHS))
        RewrittenRHS = RHS;
      break;
    case CmpInst::ICMP_NE:
      if (isa<SCEVConstant>(RHS) &&
          cast<SCEVConstant>(RHS)->getValue()->isNullValue())
        RewrittenRHS = getUMaxExpr(RewrittenLHS, getOne(RHS->getType()));
      break;
    default:
      break;
    }
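
    // For example, a guard (%x != 0) strengthens %x to (1 umax %x), and a
    // guard (%x u< %n) strengthens %x to (%x umin (-1 + %n)).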
    if (RewrittenRHS) {
      RewriteMap[LHS] = RewrittenRHS;
      if (LHS == RewrittenLHS)
        ExprsToRewrite.push_back(LHS);
    }
  };

  SmallVector<std::pair<Value *, bool>> Terms;
  // First, collect information from assumptions dominating the loop.
  for (auto &AssumeVH : AC.assumptions()) {
    if (!AssumeVH)
      continue;
    auto *AssumeI = cast<CallInst>(AssumeVH);
    if (!DT.dominates(AssumeI, L->getHeader()))
      continue;
    Terms.emplace_back(AssumeI->getOperand(0), true);
  }

  // Second, collect conditions from dominating branches. Starting at the loop
  // predecessor, climb up the predecessor chain, as long as there are
  // predecessors that can be found that have unique successors leading to the
  // original header.
  // TODO: share this logic with isLoopEntryGuardedByCond.
  for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
           L->getLoopPredecessor(), L->getHeader());
       Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {

    const BranchInst *LoopEntryPredicate =
        dyn_cast<BranchInst>(Pair.first->getTerminator());
    if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
      continue;

    Terms.emplace_back(LoopEntryPredicate->getCondition(),
                       LoopEntryPredicate->getSuccessor(0) == Pair.second);
  }
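
  // For instance, for a guard chain of the form
  //   entry: br i1 %c0, label %check, label %exit
  //   check: br i1 %c1, label %ph, label %exit
  //   ph:    br label %header
  // (block names chosen for illustration), the walk above records %c1 and
  // then %c0, each paired with whether the loop is entered on the true edge.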

  // Now apply the information from the collected conditions to RewriteMap.
  // Conditions are processed in reverse order, so the earliest condition is
  // processed first. This ensures the SCEVs with the shortest dependency
  // chains are constructed first.
  DenseMap<const SCEV *, const SCEV *> RewriteMap;
  for (auto &E : reverse(Terms)) {
    bool EnterIfTrue = E.second;
    SmallVector<Value *, 8> Worklist;
    SmallPtrSet<Value *, 8> Visited;
    Worklist.push_back(E.first);
    while (!Worklist.empty()) {
      Value *Cond = Worklist.pop_back_val();
      if (!Visited.insert(Cond).second)
        continue;

      if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
        auto Predicate =
            EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
        const auto *LHS = getSCEV(Cmp->getOperand(0));
        const auto *RHS = getSCEV(Cmp->getOperand(1));
        CollectCondition(Predicate, LHS, RHS, RewriteMap);
        continue;
      }

      // Descend into conditions combined with logical and/or.
      Value *L, *R;
      if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
                      : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
        Worklist.push_back(L);
        Worklist.push_back(R);
      }
    }
  }

  if (RewriteMap.empty())
    return Expr;

  // Now that all rewrite information is collected, rewrite the collected
  // expressions with the information in the map. This applies information to
  // sub-expressions.
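  // For example, with guards rewriting %a to (%a umin 15) and %b to
  // (%b umin %a), the loop below re-visits %b's replacement so that the %a
  // inside it also becomes (%a umin 15).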
  if (ExprsToRewrite.size() > 1) {
    for (const SCEV *Expr : ExprsToRewrite) {
      const SCEV *RewriteTo = RewriteMap[Expr];
      RewriteMap.erase(Expr);
      SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
      RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
    }
  }

  SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
  return Rewriter.visit(Expr);
}