//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the scalar evolution analysis
// engine, which is used primarily to analyze expressions involving induction
// variables in loops.
//
// There are several aspects to this library. First is the representation of
// scalar expressions, which are represented as subclasses of the SCEV class.
// These classes are used to represent certain types of subexpressions that we
// can handle. We only create one SCEV of a particular shape, so
// pointer-comparisons for equality are legal.
//
// One important aspect of the SCEV objects is that they are never cyclic, even
// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
// the PHI node is one of the idioms that we can represent (e.g., a polynomial
// recurrence) then we represent it directly as a recurrence node, otherwise we
// represent it as a SCEVUnknown node.
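//
// For example (illustrative IR, not taken from this file), the loop counter
//
//   %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
//   %iv.next = add nuw nsw i32 %iv, 1
//
// is cyclic in the IR, but SCEV models it as the acyclic polynomial
// recurrence {0,+,1}<nuw><nsw><%loop>.
//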
// In addition to being able to represent expressions of various types, we also
// have folders that are used to build the *canonical* representation for a
// particular expression. These folders are capable of using a variety of
// rewrite rules to simplify the expressions.
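//
// For example (illustrative), the folders canonicalize ((%x + 2) + 3) into
// (5 + %x) and (%x + %x) into (2 * %x), so structurally different inputs
// converge to a single canonical SCEV.
//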
// Once the folders are defined, we can implement the more interesting
// higher-level code, such as the code that recognizes PHI nodes of various
// types, computes the execution count of a loop, etc.
//
// TODO: We should use these routines and value representations to implement
// dependence analysis!
//
//===----------------------------------------------------------------------===//
//
// There are several good references for the techniques used in this analysis.
//
//  Chains of recurrences -- a method to expedite the evaluation
//  of closed-form functions
//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
//
//  On computational properties of chains of recurrences
//  Eugene V. Zima
//
//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
//  Robert A. van Engelen
//
//  Efficient Symbolic Analysis for Optimizing Compilers
//  Robert A. van Engelen
//
//  Using the chains of recurrences algebra for data dependence testing and
//  induction variable substitution
//  MS Thesis, Johnie Birch
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/Verifier.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>
using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "scalar-evolution"

STATISTIC(NumExitCountsComputed,
          "Number of loop exits with predictable exit counts");
STATISTIC(NumExitCountsNotComputed,
          "Number of loop exits without predictable exit counts");
STATISTIC(NumBruteForceTripCountsComputed,
          "Number of loops with trip counts computed by brute force");

#ifdef EXPENSIVE_CHECKS
bool llvm::VerifySCEV = true;
#else
bool llvm::VerifySCEV = false;
#endif

static cl::opt<unsigned>
    MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                            cl::desc("Maximum number of iterations SCEV will "
                                     "symbolically execute a constant "
                                     "derived loop"),
                            cl::init(100));

static cl::opt<bool, true> VerifySCEVOpt(
    "verify-scev", cl::Hidden, cl::location(VerifySCEV),
    cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
static cl::opt<bool> VerifySCEVStrict(
    "verify-scev-strict", cl::Hidden,
    cl::desc("Enable stricter verification when -verify-scev is passed"));
static cl::opt<bool> VerifyIR(
    "scev-verify-ir", cl::Hidden,
    cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
    cl::init(false));

static cl::opt<unsigned> MulOpsInlineThreshold(
    "scev-mulops-inline-threshold", cl::Hidden,
    cl::desc("Threshold for inlining multiplication operands into a SCEV"),
    cl::init(32));

static cl::opt<unsigned> AddOpsInlineThreshold(
    "scev-addops-inline-threshold", cl::Hidden,
    cl::desc("Threshold for inlining addition operands into a SCEV"),
    cl::init(500));

static cl::opt<unsigned> MaxSCEVCompareDepth(
    "scalar-evolution-max-scev-compare-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
    cl::init(32));

static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
    "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
    cl::init(2));

static cl::opt<unsigned> MaxValueCompareDepth(
    "scalar-evolution-max-value-compare-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive value complexity comparisons"),
    cl::init(2));

static cl::opt<unsigned>
    MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
                  cl::desc("Maximum depth of recursive arithmetics"),
                  cl::init(32));

static cl::opt<unsigned> MaxConstantEvolvingDepth(
    "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));

static cl::opt<unsigned>
    MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
                 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
                 cl::init(8));

static cl::opt<unsigned>
    MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
                  cl::desc("Max coefficients in AddRec during evolving"),
                  cl::init(16));

static cl::opt<unsigned>
    HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
                      cl::desc("Size of the expression which is considered huge"),
                      cl::init(4096));

static cl::opt<unsigned> RangeIterThreshold(
    "scev-range-iter-threshold", cl::Hidden,
    cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
    cl::init(32));

static cl::opt<bool>
ClassifyExpressions("scalar-evolution-classify-expressions",
    cl::Hidden, cl::init(true),
    cl::desc("When printing analysis, include information on every instruction"));

static cl::opt<bool> UseExpensiveRangeSharpening(
    "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
    cl::init(false),
    cl::desc("Use more powerful methods of sharpening expression ranges. May "
             "be costly in terms of compile time"));

static cl::opt<unsigned> MaxPhiSCCAnalysisSize(
    "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
    cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
             "Phi strongly connected components"),
    cl::init(8));

static cl::opt<bool>
    EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
                            cl::desc("Handle <= and >= in finite loops"),
                            cl::init(true));

static cl::opt<bool> UseContextForNoWrapFlagInference(
    "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
    cl::desc("Infer nuw/nsw flags using context where suitable"),
    cl::init(true));

//===----------------------------------------------------------------------===//
// SCEV class definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Implementation of the SCEV class.
//

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
LLVM_DUMP_METHOD void SCEV::dump() const {
  print(dbgs());
  dbgs() << '\n';
}
#endif

void SCEV::print(raw_ostream &OS) const {
  switch (getSCEVType()) {
  case scConstant:
    cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
    return;
  case scVScale:
    OS << "vscale";
    return;
  case scPtrToInt: {
    const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
    const SCEV *Op = PtrToInt->getOperand();
    OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
       << *PtrToInt->getType() << ")";
    return;
  }
  case scTruncate: {
    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
    const SCEV *Op = Trunc->getOperand();
    OS << "(trunc " << *Op->getType() << " " << *Op << " to "
       << *Trunc->getType() << ")";
    return;
  }
  case scZeroExtend: {
    const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
    const SCEV *Op = ZExt->getOperand();
    OS << "(zext " << *Op->getType() << " " << *Op << " to "
       << *ZExt->getType() << ")";
    return;
  }
  case scSignExtend: {
    const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
    const SCEV *Op = SExt->getOperand();
    OS << "(sext " << *Op->getType() << " " << *Op << " to "
       << *SExt->getType() << ")";
    return;
  }
  case scAddRecExpr: {
    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
    OS << "{" << *AR->getOperand(0);
    for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
      OS << ",+," << *AR->getOperand(i);
    OS << "}<";
    if (AR->hasNoUnsignedWrap())
      OS << "nuw><";
    if (AR->hasNoSignedWrap())
      OS << "nsw><";
    if (AR->hasNoSelfWrap() &&
        !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
      OS << "nw><";
    AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
    OS << ">";
    return;
  }
  case scAddExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr:
  case scUMinExpr: case scSMinExpr:
  case scSequentialUMinExpr: {
    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
    const char *OpStr = nullptr;
    switch (NAry->getSCEVType()) {
    case scAddExpr: OpStr = " + "; break;
    case scMulExpr: OpStr = " * "; break;
    case scUMaxExpr: OpStr = " umax "; break;
    case scSMaxExpr: OpStr = " smax "; break;
    case scUMinExpr: OpStr = " umin "; break;
    case scSMinExpr: OpStr = " smin "; break;
    case scSequentialUMinExpr:
      OpStr = " umin_seq ";
      break;
    default:
      llvm_unreachable("There are no other nary expression types.");
    }
    OS << "(";
    ListSeparator LS(OpStr);
    for (const SCEV *Op : NAry->operands())
      OS << LS << *Op;
    OS << ")";
    switch (NAry->getSCEVType()) {
    case scAddExpr: case scMulExpr:
      if (NAry->hasNoUnsignedWrap())
        OS << "<nuw>";
      if (NAry->hasNoSignedWrap())
        OS << "<nsw>";
      break;
    default:
      // Nothing to print for other nary expressions.
      break;
    }
    return;
  }
  case scUDivExpr: {
    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
    OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
    return;
  }
  case scUnknown:
    cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
    return;
  case scCouldNotCompute:
    OS << "***COULDNOTCOMPUTE***";
    return;
  }
  llvm_unreachable("Unknown SCEV kind!");
}

Type *SCEV::getType() const {
  switch (getSCEVType()) {
  case scConstant:
    return cast<SCEVConstant>(this)->getType();
  case scVScale:
    return cast<SCEVVScale>(this)->getType();
  case scPtrToInt: case scTruncate: case scZeroExtend: case scSignExtend:
    return cast<SCEVCastExpr>(this)->getType();
  case scAddRecExpr:
    return cast<SCEVAddRecExpr>(this)->getType();
  case scMulExpr:
    return cast<SCEVMulExpr>(this)->getType();
  case scUMaxExpr: case scSMaxExpr: case scUMinExpr: case scSMinExpr:
    return cast<SCEVMinMaxExpr>(this)->getType();
  case scSequentialUMinExpr:
    return cast<SCEVSequentialMinMaxExpr>(this)->getType();
  case scAddExpr:
    return cast<SCEVAddExpr>(this)->getType();
  case scUDivExpr:
    return cast<SCEVUDivExpr>(this)->getType();
  case scUnknown:
    return cast<SCEVUnknown>(this)->getType();
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}

ArrayRef<const SCEV *> SCEV::operands() const {
  switch (getSCEVType()) {
  case scConstant: case scVScale: case scUnknown:
    return {};
  case scPtrToInt: case scTruncate: case scZeroExtend: case scSignExtend:
    return cast<SCEVCastExpr>(this)->operands();
  case scAddRecExpr: case scAddExpr: case scMulExpr: case scUMaxExpr:
  case scSMaxExpr: case scUMinExpr: case scSMinExpr:
  case scSequentialUMinExpr:
    return cast<SCEVNAryExpr>(this)->operands();
  case scUDivExpr:
    return cast<SCEVUDivExpr>(this)->operands();
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}

bool SCEV::isZero() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isZero();
  return false;
}

bool SCEV::isOne() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isOne();
  return false;
}

bool SCEV::isAllOnesValue() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isMinusOne();
  return false;
}

bool SCEV::isNonConstantNegative() const {
  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
  if (!Mul) return false;

  // If there is a constant factor, it will be first.
  const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
  if (!SC) return false;

  // Return true if the value is negative, this matches things like (-42 * V).
  return SC->getAPInt().isNegative();
}

SCEVCouldNotCompute::SCEVCouldNotCompute() :
  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}

bool SCEVCouldNotCompute::classof(const SCEV *S) {
  return S->getSCEVType() == scCouldNotCompute;
}

const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
  FoldingSetNodeID ID;
  ID.AddInteger(scConstant);
  ID.AddPointer(V);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
  return getConstant(ConstantInt::get(getContext(), Val));
}

const SCEV *
ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
  return getConstant(ConstantInt::get(ITy, V, isSigned));
}

const SCEV *ScalarEvolution::getVScale(Type *Ty) {
  FoldingSetNodeID ID;
  ID.AddInteger(scVScale);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
    return S;
  SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
                           const SCEV *op, Type *ty)
    : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}

SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
                                   Type *ITy)
    : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
  assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
         "Must be a non-bit-width-changing pointer-to-integer cast!");
}

SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
                                           SCEVTypes SCEVTy, const SCEV *op,
                                           Type *ty)
    : SCEVCastExpr(ID, SCEVTy, op, ty) {}

SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
                                   Type *ty)
    : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot truncate non-integer value!");
}

SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, Type *ty)
    : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot zero extend non-integer value!");
}

SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, Type *ty)
    : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot sign extend non-integer value!");
}

void SCEVUnknown::deleted() {
  // Clear this SCEVUnknown from various maps.
  SE->forgetMemoizedResults(this);

  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Release the value.
  setValPtr(nullptr);
}

void SCEVUnknown::allUsesReplacedWith(Value *New) {
  // Clear this SCEVUnknown from various maps.
  SE->forgetMemoizedResults(this);

  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Replace the value pointer in case someone is still using this SCEVUnknown.
  setValPtr(New);
}

//===----------------------------------------------------------------------===//
// SCEV Utilities
//===----------------------------------------------------------------------===//

/// Compare the two values \p LV and \p RV in terms of their "complexity" where
/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
/// operands in SCEV expressions. \p EqCache is a set of pairs of values that
/// have been previously deemed to be "equally complex" by this routine. It is
/// intended to avoid exponential time complexity in cases like:
///
///   %a = f(%x, %y)
///   %b = f(%x, %y)
///   %c = f(%a, %b)
///   %d = f(%a, %b)
///   %e = f(%c, %d)
///   %f = f(%c, %d)
///
///   CompareValueComplexity(%f, %c)
///
/// Since we do not continue running this routine on expression trees once we
/// have seen unequal values, there is no need to track them in the cache.
static int
CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
                       const LoopInfo *const LI, Value *LV, Value *RV,
                       unsigned Depth) {
  if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
    return 0;

  // Order pointer values after integer values. This helps SCEVExpander form
  // GEPs.
  bool LIsPointer = LV->getType()->isPointerTy(),
       RIsPointer = RV->getType()->isPointerTy();
  if (LIsPointer != RIsPointer)
    return (int)LIsPointer - (int)RIsPointer;

  // Compare getValueID values.
  unsigned LID = LV->getValueID(), RID = RV->getValueID();
  if (LID != RID)
    return (int)LID - (int)RID;

  // Sort arguments by their position.
  if (const auto *LA = dyn_cast<Argument>(LV)) {
    const auto *RA = cast<Argument>(RV);
    unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
    return (int)LArgNo - (int)RArgNo;
  }

  if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
    const auto *RGV = cast<GlobalValue>(RV);

    const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
      auto LT = GV->getLinkage();
      return !(GlobalValue::isPrivateLinkage(LT) ||
               GlobalValue::isInternalLinkage(LT));
    };

    // Use the names to distinguish the two values, but only if the
    // names are semantically important.
    if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
      return LGV->getName().compare(RGV->getName());
  }

  // For instructions, compare their loop depth, and their operand count. This
  // is pretty loose.
  if (const auto *LInst = dyn_cast<Instruction>(LV)) {
    const auto *RInst = cast<Instruction>(RV);

    // Compare loop depths.
    const BasicBlock *LParent = LInst->getParent(),
                     *RParent = RInst->getParent();
    if (LParent != RParent) {
      unsigned LDepth = LI->getLoopDepth(LParent),
               RDepth = LI->getLoopDepth(RParent);
      if (LDepth != RDepth)
        return (int)LDepth - (int)RDepth;
    }

    // Compare the number of operands.
    unsigned LNumOps = LInst->getNumOperands(),
             RNumOps = RInst->getNumOperands();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    for (unsigned Idx : seq(0u, LNumOps)) {
      int Result =
          CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
                                 RInst->getOperand(Idx), Depth + 1);
      if (Result != 0)
        return Result;
    }
  }

  EqCacheValue.unionSets(LV, RV);
  return 0;
}

// Return negative, zero, or positive, if LHS is less than, equal to, or greater
// than RHS, respectively. A three-way result allows recursive comparisons to be
// memoized.
// If the max analysis depth was reached, return std::nullopt, assuming we do
// not know if they are equivalent for sure.
static std::optional<int>
CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
                      EquivalenceClasses<const Value *> &EqCacheValue,
                      const LoopInfo *const LI, const SCEV *LHS,
                      const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
  // Fast-path: SCEVs are uniqued so we can do a quick equality check.
  if (LHS == RHS)
    return 0;

  // Primarily, sort the SCEVs by their getSCEVType().
  SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
  if (LType != RType)
    return (int)LType - (int)RType;

  if (EqCacheSCEV.isEquivalent(LHS, RHS))
    return 0;

  if (Depth > MaxSCEVCompareDepth)
    return std::nullopt;

  // Aside from the getSCEVType() ordering, the particular ordering
  // isn't very important except that it's beneficial to be consistent,
  // so that (a + b) and (b + a) don't end up as different expressions.
  switch (LType) {
  case scUnknown: {
    const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
    const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);

    int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
                                   RU->getValue(), Depth + 1);
    if (X == 0)
      EqCacheSCEV.unionSets(LHS, RHS);
    return X;
  }

  case scConstant: {
    const SCEVConstant *LC = cast<SCEVConstant>(LHS);
    const SCEVConstant *RC = cast<SCEVConstant>(RHS);

    // Compare constant values.
    const APInt &LA = LC->getAPInt();
    const APInt &RA = RC->getAPInt();
    unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
    if (LBitWidth != RBitWidth)
      return (int)LBitWidth - (int)RBitWidth;
    return LA.ult(RA) ? -1 : 1;
  }

  case scVScale: {
    const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
    const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
    return LTy->getBitWidth() - RTy->getBitWidth();
  }

  case scAddRecExpr: {
    const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
    const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);

    // There is always a dominance between two recs that are used by one SCEV,
    // so we can safely sort recs by loop header dominance. We require such
    // order in getAddExpr.
    const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
    if (LLoop != RLoop) {
      const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
      assert(LHead != RHead && "Two loops share the same header?");
      if (DT.dominates(LHead, RHead))
        return 1;
      assert(DT.dominates(RHead, LHead) &&
             "No dominance between recurrences used by one SCEV?");
      return -1;
    }
    [[fallthrough]];
  }

  case scTruncate: case scZeroExtend: case scSignExtend: case scPtrToInt:
  case scAddExpr: case scMulExpr: case scUDivExpr:
  case scSMaxExpr: case scUMaxExpr: case scSMinExpr: case scUMinExpr:
  case scSequentialUMinExpr: {
    ArrayRef<const SCEV *> LOps = LHS->operands();
    ArrayRef<const SCEV *> ROps = RHS->operands();

    // Lexicographically compare n-ary-like expressions.
    unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    for (unsigned i = 0; i != LNumOps; ++i) {
      auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LOps[i],
                                     ROps[i], DT, Depth + 1);
      if (X != 0)
        return X;
    }
    EqCacheSCEV.unionSets(LHS, RHS);
    return 0;
  }

  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}

/// Given a list of SCEV objects, order them by their complexity, and group
/// objects of the same complexity together by value. When this routine is
/// finished, we know that any duplicates in the vector are consecutive and that
/// complexity is monotonically increasing.
///
/// Note that we take special precautions to ensure that we get deterministic
/// results from this routine. In other words, we don't want the results of
/// this routine to depend on where the addresses of various SCEV objects
/// happened to end up in memory.
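///
/// For example (illustrative): for the operand list (%a, 3, %b, %a), sorting
/// by complexity places the constant first and groups the duplicate %a
/// operands next to each other, e.g. (3, %a, %a, %b), which lets callers such
/// as the add folder spot the duplicates and combine them.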
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
                              LoopInfo *LI, DominatorTree &DT) {
  if (Ops.size() < 2) return;  // Noop

  EquivalenceClasses<const SCEV *> EqCacheSCEV;
  EquivalenceClasses<const Value *> EqCacheValue;

  // Whether LHS has provably less complexity than RHS.
  auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
    auto Complexity =
        CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
    return Complexity && *Complexity < 0;
  };
  if (Ops.size() == 2) {
    // This is the common case, which also happens to be trivially simple.
    // Special case it.
    const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
    if (IsLessComplex(RHS, LHS))
      std::swap(LHS, RHS);
    return;
  }

  // Do the rough sort by complexity.
  llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
    return IsLessComplex(LHS, RHS);
  });

  // Now that we are sorted by complexity, group elements of the same
  // complexity. Note that this is, at worst, N^2, but the vector is likely to
  // be extremely short in practice. Note that we take this approach because we
  // do not want to depend on the addresses of the objects we are grouping.
  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
    const SCEV *S = Ops[i];
    unsigned Complexity = S->getSCEVType();

    // If there are any objects of the same complexity and same value as this
    // one, group them.
    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
      if (Ops[j] == S) { // Found a duplicate.
        // Move it to immediately after i'th element.
        std::swap(Ops[i+1], Ops[j]);
        ++i;   // no need to rescan it.
        if (i == e-2) return;  // Done!
      }
    }
  }
}

/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
/// least HugeExprThreshold nodes).
static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
  return any_of(Ops, [](const SCEV *S) {
    return S->getExpressionSize() >= HugeExprThreshold;
  });
}

//===----------------------------------------------------------------------===//
// Simple SCEV method implementations
//===----------------------------------------------------------------------===//

/// Compute BC(It, K).  The result has width W.  Assume, K > 0.
static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
                                       ScalarEvolution &SE,
                                       Type *ResultTy) {
  // Handle the simplest case efficiently.
  if (K == 1)
    return SE.getTruncateOrZeroExtend(It, ResultTy);

  // We are using the following formula for BC(It, K):
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
  //
  // Suppose, W is the bitwidth of the return value.  We must be prepared for
  // overflow.  Hence, we must assure that the result of our computation is
  // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
  // safe in modular arithmetic.
  //
  // However, this code doesn't use exactly that formula; the formula it uses
  // is something like the following, where T is the number of factors of 2 in
  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
  // Xor.
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
  //
  // This formula is trivially equivalent to the previous formula.  However,
  // this formula can be implemented much more efficiently.  The trick is that
  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
  // arithmetic.  To do exact division in modular arithmetic, all we have
  // to do is multiply by the inverse.  Therefore, this step can be done at
  // width W.
  //
  // The next issue is how to safely do the division by 2^T.  The way this
  // is done is by doing the multiplication step at a width of at least W + T
  // bits.  This way, the bottom W+T bits of the product are accurate.  Then,
  // when we perform the division by 2^T (which is equivalent to a right shift
  // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
  // truncated out after the division by 2^T.
  //
  // In comparison to just directly using the first formula, this technique
  // is much more efficient; using the first formula requires W * K bits,
  // but this formula requires less than W + K bits.  Also, the first formula
  // requires a division step, whereas this formula only requires multiplies
  // and shifts.
  //
  // It doesn't matter whether the subtraction step is done in the calculation
  // width or the input iteration count's width; if the subtraction overflows,
  // the result must be zero anyway.  We prefer here to do it in the width of
  // the induction variable because it helps a lot for certain cases; CodeGen
  // isn't smart enough to ignore the overflow, which leads to much less
  // efficient code if the width of the subtraction is wider than the native
  // register width.
  //
  // (It's possible to not widen at all by pulling out factors of 2 before
  // the multiplication; for example, K=2 can be calculated as
  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
  // extra arithmetic, so it's not an obvious win, and it gets
  // much more complicated for K > 3.)
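  //
  // A small worked instance (added for illustration): for K = 4 we have
  // K! = 24 = 2^3 * 3, so T = 3 and K!/2^T = 3.  BC(It, 4) is then computed
  // as It*(It-1)*(It-2)*(It-3) at width W+3, divided by 2^3, truncated to
  // W bits, and multiplied by the multiplicative inverse of 3 modulo 2^W.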

  // Protection from insane SCEVs; this bound is conservative,
  // but it probably doesn't matter.
  if (K > 1000)
    return SE.getCouldNotCompute();

  unsigned W = SE.getTypeSizeInBits(ResultTy);

  // Calculate K! / 2^T and T; we divide out the factors of two before
  // multiplying for calculating K! / 2^T to avoid overflow.
  // Other overflow doesn't matter because we only care about the bottom
  // W bits of the result.
  APInt OddFactorial(W, 1);
  unsigned T = 1;
  for (unsigned i = 3; i <= K; ++i) {
    APInt Mult(W, i);
    unsigned TwoFactors = Mult.countr_zero();
    T += TwoFactors;
    Mult.lshrInPlace(TwoFactors);
    OddFactorial *= Mult;
  }

  // We need at least W + T bits for the multiplication step
  unsigned CalculationBits = W + T;

  // Calculate 2^T, at width T+W.
  APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);

  // Calculate the multiplicative inverse of K! / 2^T;
  // this multiplication factor will perform the exact division by
  // K! / 2^T.
  APInt Mod = APInt::getSignedMinValue(W+1);
  APInt MultiplyFactor = OddFactorial.zext(W+1);
  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
  MultiplyFactor = MultiplyFactor.trunc(W);

  // Calculate the product, at width T+W
  IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
                                                CalculationBits);
  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
  for (unsigned i = 1; i != K; ++i) {
    const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
    Dividend = SE.getMulExpr(Dividend,
                             SE.getTruncateOrZeroExtend(S, CalculationTy));
  }

  // Divide by 2^T
  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));

  // Truncate the result, and divide by K! / 2^T.

  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}

/// Return the value of this chain of recurrences at the specified iteration
/// number.  We can evaluate this recurrence by multiplying each element in the
/// chain by the binomial coefficient corresponding to it.  In other words, we
/// can evaluate {A,+,B,+,C,+,D} as:
///
///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
///
/// where BC(It, k) stands for binomial coefficient.
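///
/// For example (illustrative): the affine recurrence {A,+,B} evaluates to
/// A + B*It, and the quadratic {A,+,B,+,C} evaluates to
/// A + B*It + C*(It*(It-1)/2).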
const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
                                                ScalarEvolution &SE) const {
  return evaluateAtIteration(operands(), It, SE);
}

const SCEV *
SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
                                    const SCEV *It, ScalarEvolution &SE) {
  assert(Operands.size() > 0);
  const SCEV *Result = Operands[0];
  for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
    // The computation is correct in the face of overflow provided that the
    // multiplication is performed _after_ the evaluation of the binomial
    // coefficient.
    const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
    if (isa<SCEVCouldNotCompute>(Coeff))
      return Coeff;

    Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
  }
  return Result;
}

//===----------------------------------------------------------------------===//
// SCEV Expression folder implementations
//===----------------------------------------------------------------------===//

const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
                                                     unsigned Depth) {
  assert(Depth <= 1 &&
         "getLosslessPtrToIntExpr() should self-recurse at most once.");

  // We could be called with an integer-typed operands during SCEV rewrites.
  // Since the operand is an integer already, just perform zext/trunc/self cast.
  if (!Op->getType()->isPointerTy())
    return Op;

  // What would be an ID for such a SCEV cast expression?
  FoldingSetNodeID ID;
  ID.AddInteger(scPtrToInt);
  ID.AddPointer(Op);

  void *IP = nullptr;

  // Is there already an expression for such a cast?
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
    return S;

  // It isn't legal for optimizations to construct new ptrtoint expressions
  // for non-integral pointers.
  if (getDataLayout().isNonIntegralPointerType(Op->getType()))
    return getCouldNotCompute();

  Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());

  // We can only trivially model ptrtoint if SCEV's effective (integer) type
  // is sufficiently wide to represent all possible pointer values.
  // We could theoretically teach SCEV to truncate wider pointers, but
  // that isn't implemented for now.
  if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
      getDataLayout().getTypeSizeInBits(IntPtrTy))
    return getCouldNotCompute();

  // If not, is this expression something we can't reduce any further?
  if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
    // Perform some basic constant folding. If the operand of the ptr2int cast
    // is a null pointer, don't create a ptr2int SCEV expression (that will be
    // left as-is), but produce a zero constant.
    // NOTE: We could handle a more general case, but lack motivational cases.
    if (isa<ConstantPointerNull>(U->getValue()))
      return getZero(IntPtrTy);

    // Create an explicit cast node.
    // We can reuse the existing insert position since if we get here,
    // we won't have made any changes which would invalidate it.
    SCEV *S = new (SCEVAllocator)
        SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }

  assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
                       "non-SCEVUnknown's.");

  // Otherwise, we've got some expression that is more complex than just a
  // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
  // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
  // only, and the expressions must otherwise be integer-typed.
  // So sink the cast down to the SCEVUnknown's.

  /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
  /// which computes a pointer-typed value, and rewrites the whole expression
  /// tree so that *all* the computations are done on integers, and the only
  /// pointer-typed operands in the expression are SCEVUnknown.
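  ///
  /// For example (illustrative): applied to the pointer-typed expression
  /// (8 + %ptr), the rewriter produces (8 + (ptrtoint i8* %ptr to i64)),
  /// i.e. the ptrtoint cast is sunk onto the lone SCEVUnknown operand %ptr.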
  class SCEVPtrToIntSinkingRewriter
      : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
    using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;

  public:
    SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}

    static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
      SCEVPtrToIntSinkingRewriter Rewriter(SE);
      return Rewriter.visit(Scev);
    }

    const SCEV *visit(const SCEV *S) {
      Type *STy = S->getType();
      // If the expression is not pointer-typed, just keep it as-is.
      if (!STy->isPointerTy())
        return S;
      // Else, recursively sink the cast down into it.
      return Base::visit(S);
    }

    const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
      SmallVector<const SCEV *, 2> Operands;
      bool Changed = false;
      for (const auto *Op : Expr->operands()) {
        Operands.push_back(visit(Op));
        Changed |= Op != Operands.back();
      }
      return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
    }

    const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
      SmallVector<const SCEV *, 2> Operands;
      bool Changed = false;
      for (const auto *Op : Expr->operands()) {
        Operands.push_back(visit(Op));
        Changed |= Op != Operands.back();
      }
      return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
    }

    const SCEV *visitUnknown(const SCEVUnknown *Expr) {
      assert(Expr->getType()->isPointerTy() &&
             "Should only reach pointer-typed SCEVUnknown's.");
      return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
    }
  };

  // And actually perform the cast sinking.
  const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
  assert(IntOp->getType()->isIntegerTy() &&
         "We must have succeeded in sinking the cast, "
         "and ending up with an integer-typed expression!");
  return IntOp;
}

const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
  assert(Ty->isIntegerTy() && "Target type must be an integer type!");

  const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
  if (isa<SCEVCouldNotCompute>(IntOp))
    return IntOp;

  return getTruncateOrZeroExtend(IntOp, Ty);
}

const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
                                             unsigned Depth) {
  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
         "This is not a truncating conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
  Ty = getEffectiveSCEVType(Ty);

  FoldingSetNodeID ID;
  ID.AddInteger(scTruncate);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));

  // trunc(trunc(x)) --> trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
    return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);

  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);

  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);

  if (Depth > MaxCastDepth) {
    SCEV *S =
        new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }

  // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
  // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
  // if after transforming we have at most one truncate, not counting truncates
  // that replace other casts.
  if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
    auto *CommOp = cast<SCEVCommutativeExpr>(Op);
    SmallVector<const SCEV *, 4> Operands;
    unsigned numTruncs = 0;
    for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
         ++i) {
      const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
      if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
          isa<SCEVTruncateExpr>(S))
        numTruncs++;
      Operands.push_back(S);
    }
    if (numTruncs < 2) {
      if (isa<SCEVAddExpr>(Op))
        return getAddExpr(Operands);
      if (isa<SCEVMulExpr>(Op))
        return getMulExpr(Operands);
      llvm_unreachable("Unexpected SCEV type for Op.");
    }
    // Although we checked in the beginning that ID is not in the cache, it is
    // possible that during recursion and different modification ID was inserted
    // into the cache. So if we find it, just return it.
    if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
      return S;
  }

  // If the input value is a chrec scev, truncate the chrec's operands.
  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    for (const SCEV *Op : AddRec->operands())
      Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
    return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
  }

  // Return zero if truncating to known zeros.
  uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
  if (MinTrailingZeros >= getTypeSizeInBits(Ty))
    return getZero(Ty);

  // The cast wasn't folded; create an explicit cast node. We can reuse
  // the existing insert position since if we get here, we won't have
  // made any changes which would invalidate it.
  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                 Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, Op);
  return S;
}

// Get the limit of a recurrence such that incrementing by Step cannot cause
// signed overflow as long as the value of the recurrence within the
// loop does not exceed this limit before incrementing.
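// For example (illustrative, 8-bit wide): if Step is known positive with
// signed-range maximum 1, the limit is INT8_MIN - 1, which wraps to
// INT8_MAX = 127, with predicate SLT: any recurrence value s< 127 can be
// incremented by 1 without signed overflow.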
static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
                                                 ICmpInst::Predicate *Pred,
                                                 ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  if (SE->isKnownPositive(Step)) {
    *Pred = ICmpInst::ICMP_SLT;
    return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
                           SE->getSignedRangeMax(Step));
  }
  if (SE->isKnownNegative(Step)) {
    *Pred = ICmpInst::ICMP_SGT;
    return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
                           SE->getSignedRangeMin(Step));
  }
  return nullptr;
}

// Get the limit of a recurrence such that incrementing by Step cannot cause
// unsigned overflow as long as the value of the recurrence within the loop does
// not exceed this limit before incrementing.
static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
                                                   ICmpInst::Predicate *Pred,
                                                   ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  *Pred = ICmpInst::ICMP_ULT;

  return SE->getConstant(APInt::getMinValue(BitWidth) -
                         SE->getUnsignedRangeMax(Step));
}

namespace {

struct ExtendOpTraitsBase {
  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
                                                          unsigned);
};

// Used to make code generic over signed and unsigned overflow.
template <typename ExtendOp> struct ExtendOpTraits {
  // Members present:
  //
  // static const SCEV::NoWrapFlags WrapType;
  //
  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
  //
  // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
  //                                            ICmpInst::Predicate *Pred,
  //                                            ScalarEvolution *SE);
};

template <>
struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getSignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;

template <>
struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getUnsignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;

} // end anonymous namespace

// The recurrence AR has been shown to have no signed/unsigned wrap or something
// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
// easily prove NSW/NUW for its preincrement or postincrement sibling. This
// allows normalizing a sign/zero extended AddRec as such:
// {sext/zext(Step + Start),+,Step} => {Step + sext/zext(Start),+,Step}.
// As a result, the expression "Step + sext/zext(PreIncAR)" is congruent with
// "sext/zext(PostIncAR)".
template <typename ExtendOpTy>
static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE, unsigned Depth) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  const Loop *L = AR->getLoop();
  const SCEV *Start = AR->getStart();
  const SCEV *Step = AR->getStepRecurrence(*SE);

  // Check for a simple looking step prior to loop entry.
  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
  if (!SA)
    return nullptr;

  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
  // subtraction is expensive. For this purpose, perform a quick and dirty
  // difference, by checking for Step in the operand list.
  SmallVector<const SCEV *, 4> DiffOps;
  for (const SCEV *Op : SA->operands())
    if (Op != Step)
      DiffOps.push_back(Op);

  if (DiffOps.size() == SA->getNumOperands())
    return nullptr;

  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
  // `Step`:

  // 1. NSW/NUW flags on the step increment.
  auto PreStartFlags =
    ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
      SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));

  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
  // "S+X does not sign/unsign-overflow".
  //
  const SCEV *BECount = SE->getBackedgeTakenCount(L);
  if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
      !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
    return PreStart;

  // 2. Direct overflow check on the step operation's expression.
  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
  const SCEV *OperandExtendedStart =
      SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
                     (SE->*GetExtendExpr)(Step, WideTy, Depth));
  if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
    if (PreAR && AR->getNoWrapFlags(WrapType)) {
      // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
      // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
      // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`.  Cache this fact.
      SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
    }
    return PreStart;
  }

  // 3. Loop precondition.
  ICmpInst::Predicate Pred;
  const SCEV *OverflowLimit =
      ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);

  if (OverflowLimit &&
      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
    return PreStart;

  return nullptr;
}

// Get the normalized zero or sign extended expression for this AddRec's Start.
template <typename ExtendOpTy>
static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE,
                                        unsigned Depth) {
  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
  if (!PreStart)
    return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);

  return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
                                             Depth),
                        (SE->*GetExtendExpr)(PreStart, Ty, Depth));
}

// Try to prove away overflow by looking at "nearby" add recurrences.  A
// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
//
// Formally:
//
//     {S,+,X} == {S-T,+,X} + T
//  => Ext({S,+,X}) == Ext({S-T,+,X} + T)
//
//  If ({S-T,+,X} + T) does not overflow  ... (1)
//
//  RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
//
//  If {S-T,+,X} does not overflow  ... (2)
//
//  RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
//      == {Ext(S-T)+Ext(T),+,Ext(X)}
//
//  If (S-T)+T does not overflow  ... (3)
//
//  RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
//      == {Ext(S),+,Ext(X)} == LHS
//
// Thus, if (1), (2) and (3) are true for some T, then
//   Ext({S,+,X}) == {Ext(S),+,Ext(X)}
//
// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
// does not overflow" restricted to the 0th iteration.  Therefore we only need
// to check for (1) and (2).
//
// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
// is `Delta` (defined below).
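//
// A concrete instance (restating the motivating example above): for
// LHS == zext({1,+,4}), take T == 1.  If {0,+,4} is already known <nuw>
// (condition (2)), and {0,+,4} + 1 cannot unsigned-overflow because {0,+,4}
// is known `ult` `-1` (condition (1)), then zext({1,+,4}) ==
// {zext(1),+,zext(4)}.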
template <typename ExtendOpTy>
bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
                                                const SCEV *Step,
                                                const Loop *L) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;

  // We restrict `Start` to a constant to prevent SCEV from spending too much
  // time here.  It is correct (but more expensive) to continue with a
  // non-constant `Start` and do a general SCEV subtraction to compute
  // `PreStart` below.
  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
  if (!StartC)
    return false;

  APInt StartAI = StartC->getAPInt();

  for (unsigned Delta : {-2, -1, 1, 2}) {
    const SCEV *PreStart = getConstant(StartAI - Delta);

    FoldingSetNodeID ID;
    ID.AddInteger(scAddRecExpr);
    ID.AddPointer(PreStart);
    ID.AddPointer(Step);
    ID.AddPointer(L);
    void *IP = nullptr;
    const auto *PreAR =
        static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));

    // Give up if we don't already have the add recurrence we need because
    // actually constructing an add recurrence is relatively expensive.
    if (PreAR && PreAR->getNoWrapFlags(WrapType)) {  // proves (2)
      const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
      ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
      const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
          DeltaS, &Pred, this);
      if (Limit && isKnownPredicate(Pred, PreAR, Limit))  // proves (1)
        return true;
    }
  }

  return false;
}

// Finds an integer D for an expression (C + x + y + ...) such that the top
// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
// the (C + x + y + ...) expression is \p WholeAddExpr.
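// For example (illustrative, 8-bit wide): if C = 0x5A and every other operand
// has at least 4 trailing zero bits, then D = 0x0A.  (C - D + x + y + ...) has
// its low 4 bits zero, so adding D < 2^4 back just fills in those zero bits
// without generating a carry, and therefore cannot wrap.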
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
                                            const SCEVConstant *ConstantTerm,
                                            const SCEVAddExpr *WholeAddExpr) {
  const APInt &C = ConstantTerm->getAPInt();
  const unsigned BitWidth = C.getBitWidth();
  // Find number of trailing zeros of (x + y + ...) w/o the C first:
  uint32_t TZ = BitWidth;
  for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
    TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
  if (TZ) {
    // Set D to be as many least significant bits of C as possible while still
    // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
    return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
  }
  return APInt(BitWidth, 0);
}

// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
                                            const APInt &ConstantStart,
                                            const SCEV *Step) {
  const unsigned BitWidth = ConstantStart.getBitWidth();
  const uint32_t TZ = SE.getMinTrailingZeros(Step);
  if (TZ)
    return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
                         : ConstantStart;
  return APInt(BitWidth, 0);
}

static void insertFoldCacheEntry(
    const ScalarEvolution::FoldID &ID, const SCEV *S,
    DenseMap<ScalarEvolution::FoldID, const SCEV *> &FoldCache,
    DenseMap<const SCEV *, SmallVector<ScalarEvolution::FoldID, 2>>
        &FoldCacheUser) {
  auto I = FoldCache.insert({ID, S});
  if (!I.second) {
    // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
    // entry.
    auto &UserIDs = FoldCacheUser[I.first->second];
    assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
    for (unsigned I = 0; I != UserIDs.size(); ++I)
      if (UserIDs[I] == ID) {
        std::swap(UserIDs[I], UserIDs.back());
        break;
      }
    UserIDs.pop_back();
    I.first->second = S;
  }
  auto R = FoldCacheUser.insert({S, {}});
  R.first->second.push_back(ID);
}

const SCEV *
ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
  Ty = getEffectiveSCEVType(Ty);

  FoldID ID(scZeroExtend, Op, Ty);
  auto Iter = FoldCache.find(ID);
  if (Iter != FoldCache.end())
    return Iter->second;

  const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
  if (!isa<SCEVZeroExtendExpr>(S))
    insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
  return S;
}

const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
                                                   unsigned Depth) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));

  // zext(zext(x)) --> zext(x)
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);

  // Before doing any expensive analysis, check to see if we've already
  // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scZeroExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  if (Depth > MaxCastDepth) {
    SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                     Op, Ty);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }

  // zext(trunc(x)) --> zext(x) or x or trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
    // It's possible the bits taken off by the truncate were all zero bits. If
    // so, we should be able to simplify this further.
    const SCEV *X = ST->getOperand();
    ConstantRange CR = getUnsignedRange(X);
    unsigned TruncBits = getTypeSizeInBits(ST->getType());
    unsigned NewBits = getTypeSizeInBits(Ty);
    if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
            CR.zextOrTrunc(NewBits)))
      return getTruncateOrZeroExtend(X, Ty, Depth);
  }

  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can zero extend all of the
  // operands (often constants).  This allows analysis of something like
  // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      const SCEV *Start = AR->getStart();
      const SCEV *Step = AR->getStepRecurrence(*this);
      unsigned BitWidth = getTypeSizeInBits(AR->getType());
      const Loop *L = AR->getLoop();

      // If we have special knowledge that this addrec won't overflow,
      // we don't need to do any further analysis.
      if (AR->hasNoUnsignedWrap()) {
        Start =
            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
        Step = getZeroExtendExpr(Step, Ty, Depth + 1);
        return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
      }

1630 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1631 // Note that this serves two purposes: It filters out loops that are
1632 // simply not analyzable, and it covers the case where this code is
1633 // being called from within backedge-taken count analysis, such that
1634 // attempting to ask for the backedge-taken count would likely result
1635 // in infinite recursion. In the later case, the analysis code will
1636 // cope with a conservative value, and it will take care to purge
1637 // that value once it has finished.
1638 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1639 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1640 // Manually compute the final value for AR, checking for overflow.
1642 // Check whether the backedge-taken count can be losslessly casted to
1643 // the addrec's type. The count is always unsigned.
1644 const SCEV *CastedMaxBECount =
1645 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1646 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1647 CastedMaxBECount, MaxBECount->getType(), Depth);
1648 if (MaxBECount == RecastedMaxBECount) {
1649 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1650 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1651 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1652 SCEV::FlagAnyWrap, Depth + 1);
1653 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1657 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1658 const SCEV *WideMaxBECount =
1659 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1660 const SCEV *OperandExtendedAdd =
1661 getAddExpr(WideStart,
1662 getMulExpr(WideMaxBECount,
1663 getZeroExtendExpr(Step, WideTy, Depth + 1),
1664 SCEV::FlagAnyWrap, Depth + 1),
1665 SCEV::FlagAnyWrap, Depth + 1);
1666 if (ZAdd == OperandExtendedAdd) {
1667 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1668 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1669 // Return the expression with the addrec on the outside.
            Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
                                                             Depth + 1);
            Step = getZeroExtendExpr(Step, Ty, Depth + 1);
            return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
          }
1675 // Similar to above, only this time treat the step value as signed.
1676 // This covers loops that count down.
1677 OperandExtendedAdd =
1678 getAddExpr(WideStart,
1679 getMulExpr(WideMaxBECount,
1680 getSignExtendExpr(Step, WideTy, Depth + 1),
1681 SCEV::FlagAnyWrap, Depth + 1),
1682 SCEV::FlagAnyWrap, Depth + 1);
1683 if (ZAdd == OperandExtendedAdd) {
1684 // Cache knowledge of AR NW, which is propagated to this AddRec.
1685 // Negative step causes unsigned wrap, but it still can't self-wrap.
1686 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1687 // Return the expression with the addrec on the outside.
            Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
                                                             Depth + 1);
            Step = getSignExtendExpr(Step, Ty, Depth + 1);
            return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
          }
        }
      }
1696 // Normally, in the cases we can prove no-overflow via a
1697 // backedge guarding condition, we can also compute a backedge
1698 // taken count for the loop. The exceptions are assumptions and
1699 // guards present in the loop -- SCEV is not great at exploiting
1700 // these to compute max backedge taken counts, but can still use
1701 // these to prove lack of overflow. Use this fact to avoid
1702 // doing extra work that may not pay off.
1703 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1704 !AC.assumptions().empty()) {
1706 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1707 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1708 if (AR->hasNoUnsignedWrap()) {
          // Same as the nuw case above - duplicated here to avoid a compile
          // time issue. It's not clear that the order of checks matters, but
          // it's one of two possible causes for a change which was reverted.
          // Be conservative for the moment.
          Start =
              getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
          Step = getZeroExtendExpr(Step, Ty, Depth + 1);
          return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
        }
      }
1719 // For a negative step, we can extend the operands iff doing so only
1720 // traverses values in the range zext([0,UINT_MAX]).
1721 if (isKnownNegative(Step)) {
1722 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1723 getSignedRangeMin(Step));
1724 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1725 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1726 // Cache knowledge of AR NW, which is propagated to this
1727 // AddRec. Negative step causes unsigned wrap, but it
1728 // still can't self-wrap.
1729 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1730 // Return the expression with the addrec on the outside.
          Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
                                                           Depth + 1);
          Step = getSignExtendExpr(Step, Ty, Depth + 1);
          return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
        }
      }
1739 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1740 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1741 // where D maximizes the number of trailing zeros of (C - D + Step * n)
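      // (Illustrative check: for {5,+,4}, D = 1 gives
      //  zext(1) + zext({4,+,4}), and every term of the residual addrec is
      //  divisible by 4, i.e. keeps two trailing zero bits.)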
      if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
        const APInt &C = SC->getAPInt();
        const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
        if (D != 0) {
          const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
          const SCEV *SResidual =
              getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
          const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
          return getAddExpr(SZExtD, SZExtR,
                            (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
                            Depth + 1);
        }
      }
      if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
        Start =
            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
        Step = getZeroExtendExpr(Step, Ty, Depth + 1);
        return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
      }
    }
  // zext(A % B) --> zext(A) % zext(B)
  {
    const SCEV *LHS;
    const SCEV *RHS;
    if (matchURem(Op, LHS, RHS))
      return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
                         getZeroExtendExpr(RHS, Ty, Depth + 1));
  }
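  // For instance (illustrative): zext i8 (%a urem %b) to i32 can be written
  // as (zext %a) urem (zext %b), since an unsigned remainder never needs
  // more bits than its operands.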
1774 // zext(A / B) --> zext(A) / zext(B).
1775 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1776 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1777 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
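  // Likewise (illustrative): zext i8 (%a udiv %b) to i32 equals
  // (zext %a) udiv (zext %b), because an unsigned quotient never exceeds its
  // dividend.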
1779 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1780 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1781 if (SA->hasNoUnsignedWrap()) {
1782 // If the addition does not unsign overflow then we can, by definition,
1783 // commute the zero extension with the addition operation.
1784 SmallVector<const SCEV *, 4> Ops;
1785 for (const auto *Op : SA->operands())
1786 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
      return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
    }
1790 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1791 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1792 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
    // Address arithmetic often contains expressions like
1795 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1796 // This transformation is useful while proving that such expressions are
1797 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
    if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
      const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
      if (D != 0) {
        const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
        const SCEV *SResidual =
            getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
        const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
        return getAddExpr(SZExtD, SZExtR,
                          (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
                          Depth + 1);
      }
    }
  }
1812 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1813 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1814 if (SM->hasNoUnsignedWrap()) {
1815 // If the multiply does not unsign overflow then we can, by definition,
1816 // commute the zero extension with the multiply operation.
1817 SmallVector<const SCEV *, 4> Ops;
1818 for (const auto *Op : SM->operands())
1819 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
      return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
    }
    // zext(2^K * (trunc X to iN)) to iM ->
    // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
    //
    // Proof:
    //
    //     zext(2^K * (trunc X to iN)) to iM
    //   = zext((trunc X to iN) << K) to iM
    //   = zext((trunc X to i{N-K}) << K)<nuw> to iM
    //     (because shl removes the top K bits)
    //   = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
    //   = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
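    // A worked instance of the rule above (illustrative): with K = 2, N = 8
    // and M = 32,
    //   zext(4 * (trunc X to i8)) to i32
    //     = (4 * (zext (trunc X to i6) to i32))<nuw>,
    // since multiplying by 4 shifts left by two and discards i8's top two
    // bits.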
1835 if (SM->getNumOperands() == 2)
1836 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1837 if (MulLHS->getAPInt().isPowerOf2())
1838 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
          int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
                             MulLHS->getAPInt().logBase2();
          Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
          return getMulExpr(
              getZeroExtendExpr(MulLHS, Ty),
              getZeroExtendExpr(
                  getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
              SCEV::FlagNUW, Depth + 1);
        }
  }
1850 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1851 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1852 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1853 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1854 SmallVector<const SCEV *, 4> Operands;
1855 for (auto *Operand : MinMax->operands())
1856 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1857 if (isa<SCEVUMinExpr>(MinMax))
1858 return getUMinExpr(Operands);
    return getUMaxExpr(Operands);
  }
1862 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1863 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1864 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1865 SmallVector<const SCEV *, 4> Operands;
1866 for (auto *Operand : MinMax->operands())
1867 Operands.push_back(getZeroExtendExpr(Operand, Ty));
    return getUMinExpr(Operands, /*Sequential*/ true);
  }
1871 // The cast wasn't folded; create an explicit cast node.
1872 // Recompute the insert position, as it may have been invalidated.
1873 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                   Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, Op);
  return S;
}

const SCEV *
ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1883 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1884 "This is not an extending conversion!");
1885 assert(isSCEVable(Ty) &&
1886 "This is not a conversion to a SCEVable type!");
1887 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1888 Ty = getEffectiveSCEVType(Ty);
1890 FoldID ID(scSignExtend, Op, Ty);
1891 auto Iter = FoldCache.find(ID);
1892 if (Iter != FoldCache.end())
1893 return Iter->second;
  const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
  if (!isa<SCEVSignExtendExpr>(S))
    insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
  return S;
}

const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
                                                   unsigned Depth) {
1903 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1904 "This is not an extending conversion!");
1905 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1906 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1907 Ty = getEffectiveSCEVType(Ty);
1909 // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
1914 // sext(sext(x)) --> sext(x)
1915 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1916 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1918 // sext(zext(x)) --> zext(x)
1919 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1920 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1922 // Before doing any expensive analysis, check to see if we've already
1923 // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scSignExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  // Limit recursion depth.
  if (Depth > MaxCastDepth) {
    SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
                                                     Op, Ty);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }
1939 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1940 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1941 // It's possible the bits taken off by the truncate were all sign bits. If
1942 // so, we should be able to simplify this further.
1943 const SCEV *X = ST->getOperand();
1944 ConstantRange CR = getSignedRange(X);
1945 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1946 unsigned NewBits = getTypeSizeInBits(Ty);
    if (CR.truncate(TruncBits).signExtend(NewBits).contains(
            CR.sextOrTrunc(NewBits)))
      return getTruncateOrSignExtend(X, Ty, Depth);
  }
1952 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1953 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1954 if (SA->hasNoSignedWrap()) {
1955 // If the addition does not sign overflow then we can, by definition,
1956 // commute the sign extension with the addition operation.
1957 SmallVector<const SCEV *, 4> Ops;
1958 for (const auto *Op : SA->operands())
1959 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
      return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
    }
1963 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1964 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1965 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1967 // For instance, this will bring two seemingly different expressions:
1968 // 1 + sext(5 + 20 * %x + 24 * %y) and
1969 // sext(6 + 20 * %x + 24 * %y)
1970 // to the same form:
1971 // 2 + sext(4 + 20 * %x + 24 * %y)
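    // (Illustrative: D is chosen so that C - D keeps the trailing zeros of
    //  the symbolic terms; with 20 * %x + 24 * %y divisible by 4, both
    //  6 - 2 and 5 - 1 leave the common residual 4 + 20 * %x + 24 * %y.)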
    if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
      const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
      if (D != 0) {
        const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
        const SCEV *SResidual =
            getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
        const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
        return getAddExpr(SSExtD, SSExtR,
                          (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
                          Depth + 1);
      }
    }
  }
1985 // If the input value is a chrec scev, and we can prove that the value
1986 // did not overflow the old, smaller, value, we can sign extend all of the
1987 // operands (often constants). This allows analysis of something like
1988 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1989 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1990 if (AR->isAffine()) {
1991 const SCEV *Start = AR->getStart();
1992 const SCEV *Step = AR->getStepRecurrence(*this);
1993 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1994 const Loop *L = AR->getLoop();
1996 // If we have special knowledge that this addrec won't overflow,
1997 // we don't need to do any further analysis.
      if (AR->hasNoSignedWrap()) {
        Start =
            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
        Step = getSignExtendExpr(Step, Ty, Depth + 1);
        return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
      }
2005 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2006 // Note that this serves two purposes: It filters out loops that are
2007 // simply not analyzable, and it covers the case where this code is
2008 // being called from within backedge-taken count analysis, such that
2009 // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
2011 // cope with a conservative value, and it will take care to purge
2012 // that value once it has finished.
2013 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2014 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for overflow.

        // Check whether the backedge-taken count can be losslessly cast to
        // the addrec's type. The count is always unsigned.
2020 const SCEV *CastedMaxBECount =
2021 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2022 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2023 CastedMaxBECount, MaxBECount->getType(), Depth);
2024 if (MaxBECount == RecastedMaxBECount) {
2025 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2026 // Check whether Start+Step*MaxBECount has no signed overflow.
2027 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2028 SCEV::FlagAnyWrap, Depth + 1);
          const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
                                                          SCEV::FlagAnyWrap,
                                                          Depth + 1),
                                               WideTy, Depth + 1);
2033 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2034 const SCEV *WideMaxBECount =
2035 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2036 const SCEV *OperandExtendedAdd =
2037 getAddExpr(WideStart,
2038 getMulExpr(WideMaxBECount,
2039 getSignExtendExpr(Step, WideTy, Depth + 1),
2040 SCEV::FlagAnyWrap, Depth + 1),
2041 SCEV::FlagAnyWrap, Depth + 1);
2042 if (SAdd == OperandExtendedAdd) {
2043 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2044 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2045 // Return the expression with the addrec on the outside.
            Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
                                                             Depth + 1);
            Step = getSignExtendExpr(Step, Ty, Depth + 1);
            return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
          }
2051 // Similar to above, only this time treat the step value as unsigned.
2052 // This covers loops that count up with an unsigned step.
2053 OperandExtendedAdd =
2054 getAddExpr(WideStart,
2055 getMulExpr(WideMaxBECount,
2056 getZeroExtendExpr(Step, WideTy, Depth + 1),
2057 SCEV::FlagAnyWrap, Depth + 1),
2058 SCEV::FlagAnyWrap, Depth + 1);
2059 if (SAdd == OperandExtendedAdd) {
            // If AR wraps around then
            //
            //    abs(Step) * MaxBECount > unsigned-max(AR->getType())
            //    => SAdd != OperandExtendedAdd
            //
            // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
            // (SAdd == OperandExtendedAdd => AR is NW)
2068 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2070 // Return the expression with the addrec on the outside.
            Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
                                                             Depth + 1);
            Step = getZeroExtendExpr(Step, Ty, Depth + 1);
            return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
          }
        }
      }
2079 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2080 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2081 if (AR->hasNoSignedWrap()) {
        // Same as the nsw case above - duplicated here to avoid a compile
        // time issue. It's not clear that the order of checks matters, but
        // it's one of two possible causes for a change which was reverted.
        // Be conservative for the moment.
        Start =
            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
        Step = getSignExtendExpr(Step, Ty, Depth + 1);
        return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
      }
2092 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2093 // if D + (C - D + Step * n) could be proven to not signed wrap
2094 // where D maximizes the number of trailing zeros of (C - D + Step * n)
      if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
        const APInt &C = SC->getAPInt();
        const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
        if (D != 0) {
          const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
          const SCEV *SResidual =
              getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
          const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
          return getAddExpr(SSExtD, SSExtR,
                            (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
                            Depth + 1);
        }
      }
      if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
        Start =
            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
        Step = getSignExtendExpr(Step, Ty, Depth + 1);
        return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
      }
    }
2118 // If the input value is provably positive and we could not simplify
2119 // away the sext build a zext instead.
2120 if (isKnownNonNegative(Op))
2121 return getZeroExtendExpr(Op, Ty, Depth + 1);
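  // e.g. (illustrative) if Op = %a /u 4 over i8, Op lies in [0, 64), so it is
  // provably non-negative and the canonical zext form is used instead of a
  // sext.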
2123 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2124 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2125 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2126 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2127 SmallVector<const SCEV *, 4> Operands;
2128 for (auto *Operand : MinMax->operands())
2129 Operands.push_back(getSignExtendExpr(Operand, Ty));
2130 if (isa<SCEVSMinExpr>(MinMax))
2131 return getSMinExpr(Operands);
    return getSMaxExpr(Operands);
  }
2135 // The cast wasn't folded; create an explicit cast node.
2136 // Recompute the insert position, as it may have been invalidated.
2137 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
                                                   Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, { Op });
  return S;
}

const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
                                         Type *Ty) {
  switch (Kind) {
  case scTruncate:
    return getTruncateExpr(Op, Ty);
  case scZeroExtend:
    return getZeroExtendExpr(Op, Ty);
  case scSignExtend:
    return getSignExtendExpr(Op, Ty);
  case scPtrToInt:
    return getPtrToIntExpr(Op, Ty);
  default:
    llvm_unreachable("Not a SCEV cast expression!");
  }
}
2161 /// getAnyExtendExpr - Return a SCEV for the given operand extended with
2162 /// unspecified bits out to the given type.
const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
                                              Type *Ty) {
2165 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2166 "This is not an extending conversion!");
2167 assert(isSCEVable(Ty) &&
2168 "This is not a conversion to a SCEVable type!");
2169 Ty = getEffectiveSCEVType(Ty);
2171 // Sign-extend negative constants.
2172 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2173 if (SC->getAPInt().isNegative())
2174 return getSignExtendExpr(Op, Ty);
2176 // Peel off a truncate cast.
2177 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2178 const SCEV *NewOp = T->getOperand();
2179 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2180 return getAnyExtendExpr(NewOp, Ty);
2181 return getTruncateOrNoop(NewOp, Ty);
2184 // Next try a zext cast. If the cast is folded, use it.
2185 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
  if (!isa<SCEVZeroExtendExpr>(ZExt))
    return ZExt;

  // Next try a sext cast. If the cast is folded, use it.
  const SCEV *SExt = getSignExtendExpr(Op, Ty);
  if (!isa<SCEVSignExtendExpr>(SExt))
    return SExt;
2194 // Force the cast to be folded into the operands of an addrec.
2195 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2196 SmallVector<const SCEV *, 4> Ops;
2197 for (const SCEV *Op : AR->operands())
2198 Ops.push_back(getAnyExtendExpr(Op, Ty));
2199 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2202 // If the expression is obviously signed, use the sext cast value.
  if (isa<SCEVSMaxExpr>(Op))
    return SExt;

  // Absent any other information, use the zext cast value.
  return ZExt;
}
2210 /// Process the given Ops list, which is a list of operands to be added under
2211 /// the given scale, update the given map. This is a helper function for
2212 /// getAddRecExpr. As an example of what it does, given a sequence of operands
2213 /// that would form an add expression like this:
2215 /// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2217 /// where A and B are constants, update the map with these values:
2219 /// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2221 /// and add 13 + A*B*29 to AccumulatedConstant.
2222 /// This will allow getAddRecExpr to produce this:
2224 /// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2226 /// This form often exposes folding opportunities that are hidden in
2227 /// the original operand list.
2229 /// Return true iff it appears that any interesting folding opportunities
2230 /// may be exposed. This helps getAddRecExpr short-circuit extra work in
2231 /// the common case where no interesting opportunities are present, and
2232 /// is also used as a check to avoid infinite recursion.
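///
/// A quick numeric check of the example above (illustrative): with A = 2 and
/// B = 3, the constants contribute 13 + 2*3*29 = 187, and m receives scale
/// 1 + 2*3 = 7.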
static bool
CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2235 SmallVectorImpl<const SCEV *> &NewOps,
2236 APInt &AccumulatedConstant,
2237 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2238 ScalarEvolution &SE) {
2239 bool Interesting = false;
  // Iterate over the add operands. They are sorted, with constants first.
  unsigned i = 0;
  while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
    ++i;
    // Pull a buried constant out to the outside.
    if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
      Interesting = true;
    AccumulatedConstant += Scale * C->getAPInt();
  }

  // Next comes everything else. We're especially interested in multiplies
  // here, but they're in the middle, so just visit the rest with one loop.
  for (; i != Ops.size(); ++i) {
    const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
    if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
      APInt NewScale =
          Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
      if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
        // A multiplication of a constant with another add; recurse.
        const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
        Interesting |=
            CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
                                         Add->operands(), NewScale, SE);
      } else {
        // A multiplication of a constant with some other value. Update
        // the map.
        SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
        const SCEV *Key = SE.getMulExpr(MulOps);
        auto Pair = M.insert({Key, NewScale});
        if (Pair.second) {
          NewOps.push_back(Pair.first->first);
        } else {
          Pair.first->second += NewScale;
          // The map already had an entry for this value, which may indicate
          // a folding opportunity.
          Interesting = true;
        }
      }
    } else {
      // An ordinary operand. Update the map.
      std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
          M.insert({Ops[i], Scale});
      if (Pair.second) {
        NewOps.push_back(Pair.first->first);
      } else {
        Pair.first->second += Scale;
        // The map already had an entry for this value, which may indicate
        // a folding opportunity.
        Interesting = true;
      }
    }
  }

  return Interesting;
}
2297 bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2298 const SCEV *LHS, const SCEV *RHS,
2299 const Instruction *CtxI) {
2300 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
                                            SCEV::NoWrapFlags, unsigned);
  switch (BinOp) {
  default:
    llvm_unreachable("Unsupported binary op");
  case Instruction::Add:
    Operation = &ScalarEvolution::getAddExpr;
    break;
  case Instruction::Sub:
    Operation = &ScalarEvolution::getMinusSCEV;
    break;
  case Instruction::Mul:
    Operation = &ScalarEvolution::getMulExpr;
    break;
  }

  const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2317 Signed ? &ScalarEvolution::getSignExtendExpr
2318 : &ScalarEvolution::getZeroExtendExpr;
2320 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2321 auto *NarrowTy = cast<IntegerType>(LHS->getType());
  auto *WideTy =
      IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2325 const SCEV *A = (this->*Extension)(
2326 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2327 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2328 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
  const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
  if (A == B)
    return true;
  // Can we use context to prove the fact we need?
  if (!CtxI)
    return false;
  // TODO: Support mul.
  if (BinOp == Instruction::Mul)
    return false;
  auto *RHSC = dyn_cast<SCEVConstant>(RHS);
  // TODO: Lift this limitation.
  if (!RHSC)
    return false;
2342 APInt C = RHSC->getAPInt();
2343 unsigned NumBits = C.getBitWidth();
2344 bool IsSub = (BinOp == Instruction::Sub);
2345 bool IsNegativeConst = (Signed && C.isNegative());
2346 // Compute the direction and magnitude by which we need to check overflow.
2347 bool OverflowDown = IsSub ^ IsNegativeConst;
2348 APInt Magnitude = C;
2349 if (IsNegativeConst) {
    if (C == APInt::getSignedMinValue(NumBits))
      // TODO: SINT_MIN on inversion gives the same negative value, we don't
      // want to deal with that.
      return false;
    Magnitude = -C;
  }

  ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
  if (OverflowDown) {
2359 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2360 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2361 : APInt::getMinValue(NumBits);
2362 APInt Limit = Min + Magnitude;
    return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
  }

2365 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2366 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2367 : APInt::getMaxValue(NumBits);
2368 APInt Limit = Max - Magnitude;
  return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
}
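// A hedged sketch of what the helper above reduces to (values invented for
// illustration): for an unsigned i8 add with RHS == 20, OverflowDown is
// false, so the final query is LHS <= 255 - 20 = 235, i.e.
// isKnownPredicateAt(ICMP_ULE, LHS, getConstant(235), CtxI).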
2373 std::optional<SCEV::NoWrapFlags>
2374 ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2375 const OverflowingBinaryOperator *OBO) {
2376 // It cannot be done any better.
2377 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2378 return std::nullopt;
2380 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2382 if (OBO->hasNoUnsignedWrap())
2383 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2384 if (OBO->hasNoSignedWrap())
2385 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2387 bool Deduced = false;
2389 if (OBO->getOpcode() != Instruction::Add &&
2390 OBO->getOpcode() != Instruction::Sub &&
2391 OBO->getOpcode() != Instruction::Mul)
2392 return std::nullopt;
2394 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2395 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2397 const Instruction *CtxI =
2398 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2399 if (!OBO->hasNoUnsignedWrap() &&
2400 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2401 /* Signed */ false, LHS, RHS, CtxI)) {
    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
    Deduced = true;
  }

  if (!OBO->hasNoSignedWrap() &&
      willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
                      /* Signed */ true, LHS, RHS, CtxI)) {
    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
    Deduced = true;
  }

  if (Deduced)
    return Flags;
  return std::nullopt;
}
2418 // We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2419 // `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2420 // can't-overflow flags for the operation if possible.
2421 static SCEV::NoWrapFlags
2422 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2423 const ArrayRef<const SCEV *> Ops,
2424 SCEV::NoWrapFlags Flags) {
2425 using namespace std::placeholders;
2427 using OBO = OverflowingBinaryOperator;
  bool CanAnalyze =
      Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
  (void)CanAnalyze;
  assert(CanAnalyze && "don't call from other places!");
2434 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2435 SCEV::NoWrapFlags SignOrUnsignWrap =
2436 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2438 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2439 auto IsKnownNonNegative = [&](const SCEV *S) {
2440 return SE->isKnownNonNegative(S);
2443 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
    Flags =
        ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2447 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2449 if (SignOrUnsignWrap != SignOrUnsignMask &&
2450 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
      isa<SCEVConstant>(Ops[0])) {

    auto Opcode = [&] {
      switch (Type) {
      case scAddExpr:
        return Instruction::Add;
      case scMulExpr:
        return Instruction::Mul;
      default:
        llvm_unreachable("Unexpected SCEV op.");
      }
    }();
2464 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2466 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2467 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2468 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2469 Opcode, C, OBO::NoSignedWrap);
2470 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
        Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
    }
2474 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2475 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2476 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2477 Opcode, C, OBO::NoUnsignedWrap);
2478 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
        Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
    }
  }
2483 // <0,+,nonnegative><nw> is also nuw
2484 // TODO: Add corresponding nsw case
2485 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2486 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2487 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2488 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2490 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
  if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
      Ops.size() == 2) {
2493 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2494 if (UDiv->getOperand(1) == Ops[1])
2495 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2496 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2497 if (UDiv->getOperand(1) == Ops[0])
        Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
  }

  return Flags;
}
2504 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
  return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
}
2508 /// Get a canonical add expression, or something simpler if possible.
2509 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
                                        SCEV::NoWrapFlags OrigFlags,
                                        unsigned Depth) {
2512 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2513 "only nuw or nsw allowed");
2514 assert(!Ops.empty() && "Cannot get empty add!");
2515 if (Ops.size() == 1) return Ops[0];
2517 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2518 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2519 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2520 "SCEVAddExpr operand types don't match!");
2521 unsigned NumPtrs = count_if(
2522 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2523 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2526 // Sort by complexity, this groups all similar expression types together.
2527 GroupByComplexity(Ops, &LI, DT);
2529 // If there are any constants, fold them together.
  unsigned Idx = 0;
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    ++Idx;
2533 assert(Idx < Ops.size());
2534 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2535 // We found two constants, fold them together!
2536 Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2537 if (Ops.size() == 2) return Ops[0];
2538 Ops.erase(Ops.begin()+1); // Erase the folded element
      LHSC = cast<SCEVConstant>(Ops[0]);
    }
2542 // If we are left with a constant zero being added, strip it off.
    if (LHSC->getValue()->isZero()) {
      Ops.erase(Ops.begin());
      --Idx;
    }

    if (Ops.size() == 1) return Ops[0];
  }
2551 // Delay expensive flag strengthening until necessary.
2552 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2553 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2556 // Limit recursion calls depth.
2557 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2558 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2560 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2561 // Don't strengthen flags if we have no new information.
2562 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2563 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
      Add->setNoWrapFlags(ComputeFlags(Ops));
    return S;
  }
2568 // Okay, check to see if the same value occurs in the operand list more than
  // once. If so, merge them together into a multiply expression. Since we
2570 // sorted the list, these values are required to be adjacent.
2571 Type *Ty = Ops[0]->getType();
2572 bool FoundMatch = false;
2573 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
    if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
      // Scan ahead to count how many equal operands there are.
      unsigned Count = 2;
      while (i+Count != e && Ops[i+Count] == Ops[i])
        ++Count;
      // Merge the values into a multiply.
      const SCEV *Scale = getConstant(Ty, Count);
      const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
      if (Ops.size() == Count)
        return Mul;
      Ops[i] = Mul;
      Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
      --i; e -= Count - 1;
      FoundMatch = true;
    }
  if (FoundMatch)
    return getAddExpr(Ops, OrigFlags, Depth + 1);
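  // e.g. (illustrative) %a + %b + %b + %b is rewritten here as %a + 3 * %b:
  // the scan finds Count == 3 adjacent equal operands and merges them into a
  // single multiply.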
2592 // Check for truncates. If all the operands are truncated from the same
2593 // type, see if factoring out the truncate would permit the result to be
2594 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2595 // if the contents of the resulting outer trunc fold to something simple.
2596 auto FindTruncSrcType = [&]() -> Type * {
2597 // We're ultimately looking to fold an addrec of truncs and muls of only
2598 // constants and truncs, so if we find any other types of SCEV
2599 // as operands of the addrec then we bail and return nullptr here.
2600 // Otherwise, we return the type of the operand of a trunc that we find.
2601 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2602 return T->getOperand()->getType();
2603 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2604 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2605 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
        return T->getOperand()->getType();
    }
    return nullptr;
  };
2610 if (auto *SrcType = FindTruncSrcType()) {
    SmallVector<const SCEV *, 8> LargeOps;
    bool Ok = true;
2613 // Check all the operands to see if they can be represented in the
2614 // source type of the truncate.
2615 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
2616 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
        if (T->getOperand()->getType() != SrcType) {
          Ok = false;
          break;
        }
2621 LargeOps.push_back(T->getOperand());
2622 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2623 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2624 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
2625 SmallVector<const SCEV *, 8> LargeMulOps;
2626 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2627 if (const SCEVTruncateExpr *T =
2628 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
            if (T->getOperand()->getType() != SrcType) {
              Ok = false;
              break;
            }
            LargeMulOps.push_back(T->getOperand());
          } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
            LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
          } else {
            Ok = false;
            break;
          }
        }
        if (Ok)
          LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
      } else {
        Ok = false;
        break;
      }
    }
    if (Ok) {
      // Evaluate the expression in the larger type.
2650 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2651 // If it folds to something simple, use it. Otherwise, don't.
2652 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
        return getTruncateExpr(Fold, Ty);
    }
  }
2657 if (Ops.size() == 2) {
2658 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
    // C2 can be folded in a way that allows retaining wrapping flags of (X +
    // C1).
2661 const SCEV *A = Ops[0];
2662 const SCEV *B = Ops[1];
2663 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2664 auto *C = dyn_cast<SCEVConstant>(A);
2665 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2666 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2667 auto C2 = C->getAPInt();
2668 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2670 APInt ConstAdd = C1 + C2;
2671 auto AddFlags = AddExpr->getNoWrapFlags();
2672 // Adding a smaller constant is NUW if the original AddExpr was NUW.
      if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
          ConstAdd.ule(C1)) {
        PreservedFlags =
            ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
      }
2679 // Adding a constant with the same sign and small magnitude is NSW, if the
2680 // original AddExpr was NSW.
2681 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2682 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2683 ConstAdd.abs().ule(C1.abs())) {
        PreservedFlags =
            ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
      }
2688 if (PreservedFlags != SCEV::FlagAnyWrap) {
2689 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2690 NewOps[0] = getConstant(ConstAdd);
        return getAddExpr(NewOps, PreservedFlags);
      }
    }
  }
2696 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2697 if (Ops.size() == 2) {
2698 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2699 if (Mul && Mul->getNumOperands() == 2 &&
        Mul->getOperand(0)->isAllOnesValue()) {
      const SCEV *X;
      const SCEV *Y;
      if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
        return getMulExpr(Y, getUDivExpr(X, Y));
      }
    }
  }
2709 // Skip past any other cast SCEVs.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
    ++Idx;
2713 // If there are add operands they would be next.
2714 if (Idx < Ops.size()) {
2715 bool DeletedAdd = false;
2716 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2717 // common NUW flag for expression after inlining. Other flags cannot be
2718 // preserved, because they may depend on the original order of operations.
2719 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2720 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2721 if (Ops.size() > AddOpsInlineThreshold ||
          Add->getNumOperands() > AddOpsInlineThreshold)
        break;
      // If we have an add, expand the add operands onto the end of the
      // operands list.
      Ops.erase(Ops.begin()+Idx);
      append_range(Ops, Add->operands());
      DeletedAdd = true;
      CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
    }
2732 // If we deleted at least one add, we added operands to the end of the list,
2733 // and they are not necessarily sorted. Recurse to resort and resimplify
2734 // any operands we just acquired.
    if (DeletedAdd)
      return getAddExpr(Ops, CommonFlags, Depth + 1);
  }
2739 // Skip over the add expression until we get to a multiply.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
    ++Idx;
2743 // Check to see if there are any folding opportunities present with
2744 // operands multiplied by constant values.
2745 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2746 uint64_t BitWidth = getTypeSizeInBits(Ty);
2747 DenseMap<const SCEV *, APInt> M;
2748 SmallVector<const SCEV *, 8> NewOps;
2749 APInt AccumulatedConstant(BitWidth, 0);
2750 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2751 Ops, APInt(BitWidth, 1), *this)) {
2752 struct APIntCompare {
2753 bool operator()(const APInt &LHS, const APInt &RHS) const {
          return LHS.ult(RHS);
        }
      };

      // Some interesting folding opportunity is present, so it's worthwhile to
2759 // re-generate the operands list. Group the operands by constant scale,
2760 // to avoid multiplying by the same constant scale multiple times.
2761 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2762 for (const SCEV *NewOp : NewOps)
2763 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2764 // Re-generate the operands list.
      Ops.clear();
      if (AccumulatedConstant != 0)
2767 Ops.push_back(getConstant(AccumulatedConstant));
2768 for (auto &MulOp : MulOpLists) {
2769 if (MulOp.first == 1) {
2770 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2771 } else if (MulOp.first != 0) {
2772 Ops.push_back(getMulExpr(
2773 getConstant(MulOp.first),
2774 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
              SCEV::FlagAnyWrap, Depth + 1));
        }
      }
      if (Ops.empty())
        return getZero(Ty);
      if (Ops.size() == 1)
        return Ops[0];
      return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
    }
  }
2786 // If we are adding something to a multiply expression, make sure the
  // something is not already an operand of the multiply. If so, merge it into
  // the multiply.
2789 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2790 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2791 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2792 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
      if (isa<SCEVConstant>(MulOpSCEV))
        continue;
2795 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2796 if (MulOpSCEV == Ops[AddOp]) {
2797 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2798 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2799 if (Mul->getNumOperands() != 2) {
            // If the multiply has more than two operands, we must get the
            // other operands.
2802 SmallVector<const SCEV *, 4> MulOps(
2803 Mul->operands().take_front(MulOp));
2804 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2805 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
          }
          SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2808 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2809 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2810 SCEV::FlagAnyWrap, Depth + 1);
2811 if (Ops.size() == 2) return OuterMul;
          if (AddOp < Idx) {
            Ops.erase(Ops.begin()+AddOp);
            Ops.erase(Ops.begin()+Idx-1);
          } else {
            Ops.erase(Ops.begin()+Idx);
            Ops.erase(Ops.begin()+AddOp-1);
          }
          Ops.push_back(OuterMul);
          return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
        }
2823 // Check this multiply against other multiplies being added together.
2824 for (unsigned OtherMulIdx = Idx+1;
           OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
           ++OtherMulIdx) {
        const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
        // If MulOp occurs in OtherMul, we can fold the two multiplies
        // together.
2830 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2831 OMulOp != e; ++OMulOp)
2832 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2833 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2834 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2835 if (Mul->getNumOperands() != 2) {
2836 SmallVector<const SCEV *, 4> MulOps(
2837 Mul->operands().take_front(MulOp));
2838 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
              InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
            }
2841 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2842 if (OtherMul->getNumOperands() != 2) {
2843 SmallVector<const SCEV *, 4> MulOps(
2844 OtherMul->operands().take_front(OMulOp));
2845 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
              InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
            }
2848 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2849 const SCEV *InnerMulSum =
2850 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2851 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2852 SCEV::FlagAnyWrap, Depth + 1);
2853 if (Ops.size() == 2) return OuterMul;
2854 Ops.erase(Ops.begin()+Idx);
2855 Ops.erase(Ops.begin()+OtherMulIdx-1);
2856 Ops.push_back(OuterMul);
            return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
          }
      }
    }
  }
2863 // If there are any add recurrences in the operands list, see if any other
  // added values are loop invariant. If so, we can fold them into the
  // recurrence.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
    ++Idx;
2869 // Scan over all recurrences, trying to fold loop invariants into them.
2870 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2871 // Scan all of the other operands to this add and add them to the vector if
2872 // they are loop invariant w.r.t. the recurrence.
2873 SmallVector<const SCEV *, 8> LIOps;
2874 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2875 const Loop *AddRecLoop = AddRec->getLoop();
2876 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2877 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2878 LIOps.push_back(Ops[i]);
        Ops.erase(Ops.begin()+i);
        --i; --e;
      }
2883 // If we found some loop invariants, fold them into the recurrence.
2884 if (!LIOps.empty()) {
2885 // Compute nowrap flags for the addition of the loop-invariant ops and
2886 // the addrec. Temporarily push it as an operand for that purpose. These
2887 // flags are valid in the scope of the addrec only.
2888 LIOps.push_back(AddRec);
      SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
      LIOps.pop_back();
2892 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2893 LIOps.push_back(AddRec->getStart());
2895 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2897 // It is not in general safe to propagate flags valid on an add within
2898 // the addrec scope to one outside it. We must prove that the inner
2899 // scope is guaranteed to execute if the outer one does to be able to
2900 // safely propagate. We know the program is undefined if poison is
2901 // produced on the inner scoped addrec. We also know that *for this use*
2902 // the outer scoped add can't overflow (because of the flags we just
2903 // computed for the inner scoped add) without the program being undefined.
      // Proving that entry to the outer scope necessitates entry to the inner
2905 // scope, thus proves the program undefined if the flags would be violated
2906 // in the outer scope.
2907 SCEV::NoWrapFlags AddFlags = Flags;
2908 if (AddFlags != SCEV::FlagAnyWrap) {
2909 auto *DefI = getDefiningScopeBound(LIOps);
2910 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2911 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
          AddFlags = SCEV::FlagAnyWrap;
      }
2914 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2916 // Build the new addrec. Propagate the NUW and NSW flags if both the
2917 // outer add and the inner addrec are guaranteed to have no overflow.
2918 // Always propagate NW.
2919 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2920 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2922 // If all of the other operands were loop invariant, we are done.
2923 if (Ops.size() == 1) return NewRec;
2925 // Otherwise, add the folded AddRec by the non-invariant parts.
      for (unsigned i = 0;; ++i)
        if (Ops[i] == AddRec) {
          Ops[i] = NewRec;
          break;
        }
      return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
    }
2934 // Okay, if there weren't any loop invariants to be folded, check to see if
2935 // there are multiple AddRec's with the same loop induction variable being
2936 // added together. If so, we can fold them.
2937 for (unsigned OtherIdx = Idx+1;
         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
         ++OtherIdx) {
2940 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2941 // so that the 1st found AddRecExpr is dominated by all others.
2942 assert(DT.dominates(
2943 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2944 AddRec->getLoop()->getHeader()) &&
2945 "AddRecExprs are not sorted in reverse dominance order?");
2946 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2947 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2948 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
        for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
             ++OtherIdx) {
2951 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2952 if (OtherAddRec->getLoop() == AddRecLoop) {
          for (unsigned i = 0, e = OtherAddRec->getNumOperands();
               i != e; ++i) {
            if (i >= AddRecOps.size()) {
              append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
              break;
            }
            SmallVector<const SCEV *, 2> TwoOps = {
                AddRecOps[i], OtherAddRec->getOperand(i)};
            AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
          }
          Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
        }
      }
2966 // Step size has changed, so we cannot guarantee no self-wraparound.
2967 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
        return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
      }
    }

    // Otherwise couldn't fold anything into this recurrence. Move onto the
    // next one.
  }
2976 // Okay, it looks like we really DO need an add expr. Check to see if we
2977 // already have one, otherwise create a new one.
  return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
}

const SCEV *
ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2983 SCEV::NoWrapFlags Flags) {
2984 FoldingSetNodeID ID;
2985 ID.AddInteger(scAddExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  SCEVAddExpr *S =
      static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
2992 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2993 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2994 S = new (SCEVAllocator)
2995 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2996 UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Ops);
  }
  S->setNoWrapFlags(Flags);
  return S;
}
const SCEV *
ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3005 const Loop *L, SCEV::NoWrapFlags Flags) {
3006 FoldingSetNodeID ID;
3007 ID.AddInteger(scAddRecExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  ID.AddPointer(L);
  void *IP = nullptr;
  SCEVAddRecExpr *S =
      static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
3015 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3016 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3017 S = new (SCEVAllocator)
3018 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3019 UniqueSCEVs.InsertNode(S, IP);
3020 LoopUsers[L].push_back(S);
    registerUser(S, Ops);
  }
  setNoWrapFlags(S, Flags);
  return S;
}
const SCEV *
ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3029 SCEV::NoWrapFlags Flags) {
3030 FoldingSetNodeID ID;
3031 ID.AddInteger(scMulExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  SCEVMulExpr *S =
      static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
3038 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3039 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
    S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
                                        O, Ops.size());
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Ops);
  }
  S->setNoWrapFlags(Flags);
  return S;
}
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  uint64_t k = i*j;
  if (j > 1 && k / j != i) Overflow = true;
  return k;
}
3055 /// Compute the result of "n choose k", the binomial coefficient. If an
3056 /// intermediate computation overflows, Overflow will be set and the return will
3057 /// be garbage. Overflow is not cleared on absence of overflow.
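/// For example (illustrative): Choose(6, 2, Overflow) computes r = 1*6/1 = 6,
/// then r = 6*5/2 = 15, leaving Overflow untouched.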
3058 static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3059 // We use the multiplicative formula:
3060 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
// At each iteration, we take the n-th term of the numerator and divide by the
3062 // (k-n)th term of the denominator. This division will always produce an
3063 // integral result, and helps reduce the chance of overflow in the
3064 // intermediate computations. However, we can still overflow even when the
3065 // final result would fit.
  if (n == 0 || n == k) return 1;
  if (k > n) return 0;

  if (k > n/2)
    k = n-k;

  uint64_t r = 1;
  for (uint64_t i = 1; i <= k; ++i) {
    r = umul_ov(r, n-(i-1), Overflow);
    r /= i;
  }
  return r;
}
3081 /// Determine if any of the operands in this SCEV are a constant or if
3082 /// any of the add or multiply expressions in this SCEV contain a constant.
3083 static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3084 struct FindConstantInAddMulChain {
3085 bool FoundConstant = false;
3087 bool follow(const SCEV *S) {
3088 FoundConstant |= isa<SCEVConstant>(S);
3089 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3092 bool isDone() const {
3093 return FoundConstant;
3097 FindConstantInAddMulChain F;
3098 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3099 ST.visitAll(StartExpr);
  return F.FoundConstant;
}
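// e.g. (illustrative) the traversal above returns true for (3 + %x) * %y,
// whose add-mul chain carries the constant 3, and false for %x * %y alone.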
3103 /// Get a canonical multiply expression, or something simpler if possible.
3104 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
                                        SCEV::NoWrapFlags OrigFlags,
                                        unsigned Depth) {
3107 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3108 "only nuw or nsw allowed");
3109 assert(!Ops.empty() && "Cannot get empty mul!");
3110 if (Ops.size() == 1) return Ops[0];
3112 Type *ETy = Ops[0]->getType();
3113 assert(!ETy->isPointerTy());
3114 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3115 assert(Ops[i]->getType() == ETy &&
3116 "SCEVMulExpr operand types don't match!");
3119 // Sort by complexity, this groups all similar expression types together.
3120 GroupByComplexity(Ops, &LI, DT);
3122 // If there are any constants, fold them together.
  unsigned Idx = 0;
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    ++Idx;
3126 assert(Idx < Ops.size());
3127 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3128 // We found two constants, fold them together!
3129 Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
3130 if (Ops.size() == 2) return Ops[0];
3131 Ops.erase(Ops.begin()+1); // Erase the folded element
      LHSC = cast<SCEVConstant>(Ops[0]);
    }

    // If we have a multiply of zero, it will always be zero.
    if (LHSC->getValue()->isZero())
      return LHSC;
3139 // If we are left with a constant one being multiplied, strip it off.
3140 if (LHSC->getValue()->isOne()) {
      Ops.erase(Ops.begin());
      --Idx;
    }

    if (Ops.size() == 1)
      return Ops[0];
  }
3149 // Delay expensive flag strengthening until necessary.
3150 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3151 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3154 // Limit recursion calls depth.
3155 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3156 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3158 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3159 // Don't strengthen flags if we have no new information.
3160 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3161 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
      Mul->setNoWrapFlags(ComputeFlags(Ops));
    return S;
  }
3166 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3167 if (Ops.size() == 2) {
3168 // C1*(C2+V) -> C1*C2 + C1*V
3169 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3170 // If any of Add's ops are Adds or Muls with a constant, apply this
3171 // transformation as well.
3173 // TODO: There are some cases where this transformation is not
3174 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3175 // this transformation should be narrowed down.
3176 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3177 const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3178 SCEV::FlagAnyWrap, Depth + 1);
3179 const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3180 SCEV::FlagAnyWrap, Depth + 1);
        return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
      }
3184 if (Ops[0]->isAllOnesValue()) {
      // If we have a mul by -1 of an add, try distributing the -1 among the
      // add operands.
3187 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3188 SmallVector<const SCEV *, 4> NewOps;
3189 bool AnyFolded = false;
3190 for (const SCEV *AddOp : Add->operands()) {
3191 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3193 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3194 NewOps.push_back(Mul);
}
if (AnyFolded)
  return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3198 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3199 // Negation preserves a recurrence's no self-wrap property.
3200 SmallVector<const SCEV *, 4> Operands;
3201 for (const SCEV *AddRecOp : AddRec->operands())
Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
                              Depth + 1));
3204 // Let M be the minimum representable signed value. AddRec with nsw
3205 // multiplied by -1 can have signed overflow if and only if it takes a
// value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
// maximum signed value. In all other cases signed overflow is
// absent.
3209 auto FlagsMask = SCEV::FlagNW;
3210 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3212 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3213 if (getSignedRangeMin(AddRec) != MinInt)
3214 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3216 return getAddRecExpr(Operands, AddRec->getLoop(),
3217 AddRec->getNoWrapFlags(FlagsMask));
3223 // Skip over the add expression until we get to a multiply.
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
  ++Idx;
3227 // If there are mul operands inline them all into this expression.
3228 if (Idx < Ops.size()) {
3229 bool DeletedMul = false;
3230 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
if (Ops.size() > MulOpsInlineThreshold)
  break;
// If we have a mul, expand the mul operands onto the end of the
// operands list.
3235 Ops.erase(Ops.begin()+Idx);
append_range(Ops, Mul->operands());
DeletedMul = true;
}
3240 // If we deleted at least one mul, we added operands to the end of the
3241 // list, and they are not necessarily sorted. Recurse to resort and
3242 // resimplify any operands we just acquired.
if (DeletedMul)
  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3247 // If there are any add recurrences in the operands list, see if any other
3248 // added values are loop invariant. If so, we can fold them into the
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
  ++Idx;
3253 // Scan over all recurrences, trying to fold loop invariants into them.
3254 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3255 // Scan all of the other operands to this mul and add them to the vector
3256 // if they are loop invariant w.r.t. the recurrence.
3257 SmallVector<const SCEV *, 8> LIOps;
3258 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3259 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3260 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3261 LIOps.push_back(Ops[i]);
Ops.erase(Ops.begin()+i);
--i; --e;
3266 // If we found some loop invariants, fold them into the recurrence.
3267 if (!LIOps.empty()) {
3268 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
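// For example, %v * 3 * {1,+,2}<L> --> %v * {3,+,6}<L> when 3 is
// invariant in L but %v is not.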
3269 SmallVector<const SCEV *, 4> NewOps;
3270 NewOps.reserve(AddRec->getNumOperands());
3271 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3272 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
3273 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3274 SCEV::FlagAnyWrap, Depth + 1));
3276 // Build the new addrec. Propagate the NUW and NSW flags if both the
3277 // outer mul and the inner addrec are guaranteed to have no overflow.
3279 // No self-wrap cannot be guaranteed after changing the step size, but
3280 // will be inferred if either NUW or NSW is true.
3281 SCEV::NoWrapFlags Flags = ComputeFlags({Scale, AddRec});
3282 const SCEV *NewRec = getAddRecExpr(
3283 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(Flags));
3285 // If all of the other operands were loop invariant, we are done.
3286 if (Ops.size() == 1) return NewRec;
3288 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3289 for (unsigned i = 0;; ++i)
if (Ops[i] == AddRec) {
  Ops[i] = NewRec;
  break;
}
3294 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3297 // Okay, if there weren't any loop invariants to be folded, check to see
3298 // if there are multiple AddRec's with the same loop induction variable
3299 // being multiplied together. If so, we can fold them.
3301 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3302 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3303 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3304 // ]]],+,...up to x=2n}.
3305 // Note that the arguments to choose() are always integers with values
3306 // known at compile time, never SCEV objects.
3308 // The implementation avoids pointless extra computations when the two
3309 // addrec's are of different length (mathematically, it's equivalent to
3310 // an infinite stream of zeros on the right).
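// For the common case of two affine recurrences, the formula above
// reduces to:
//   {A1,+,A2}<L> * {B1,+,B2}<L>
//   = {A1*B1,+,A1*B2 + A2*B1 + A2*B2,+,2*A2*B2}<L>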
3311 bool OpsModified = false;
3312 for (unsigned OtherIdx = Idx+1;
OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
++OtherIdx) {
3315 const SCEVAddRecExpr *OtherAddRec =
3316 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
  continue;
3320 // Limit max number of arguments to avoid creation of unreasonably big
3321 // SCEVAddRecs with very complex operands.
3322 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
  continue;
3326 bool Overflow = false;
3327 Type *Ty = AddRec->getType();
3328 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3329 SmallVector<const SCEV*, 7> AddRecOps;
3330 for (int x = 0, xe = AddRec->getNumOperands() +
3331 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3332 SmallVector <const SCEV *, 7> SumOps;
3333 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3334 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3335 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3336 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3337 z < ze && !Overflow; ++z) {
3338 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
uint64_t Coeff;
if (LargerThan64Bits)
  Coeff = umul_ov(Coeff1, Coeff2, Overflow);
else
  Coeff = Coeff1*Coeff2;
3344 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3345 const SCEV *Term1 = AddRec->getOperand(y-z);
3346 const SCEV *Term2 = OtherAddRec->getOperand(z);
3347 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3348 SCEV::FlagAnyWrap, Depth + 1));
if (SumOps.empty())
  SumOps.push_back(getZero(Ty));
3353 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
                                      SCEV::FlagAnyWrap);
3358 if (Ops.size() == 2) return NewAddRec;
3359 Ops[Idx] = NewAddRec;
3360 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
OpsModified = true;
AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
if (!AddRec)
  break;
}
if (OpsModified)
  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3370 // Otherwise couldn't fold anything into this recurrence. Move onto the
// Okay, it looks like we really DO need a mul expr. Check to see if we
3375 // already have one, otherwise create a new one.
3376 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3379 /// Represents an unsigned remainder expression based on unsigned division.
const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
                                         const SCEV *RHS) {
3382 assert(getEffectiveSCEVType(LHS->getType()) ==
3383 getEffectiveSCEVType(RHS->getType()) &&
3384 "SCEVURemExpr operand types don't match!");
3386 // Short-circuit easy cases
3387 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3388 // If constant is one, the result is trivial
3389 if (RHSC->getValue()->isOne())
3390 return getZero(LHS->getType()); // X urem 1 --> 0
3392 // If constant is a power of two, fold into a zext(trunc(LHS)).
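// For example, %x urem 8 --> (zext (trunc %x to i3) to iN), i.e. keep
// only the low log2(8) = 3 bits.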
3393 if (RHSC->getAPInt().isPowerOf2()) {
3394 Type *FullTy = LHS->getType();
Type *TruncTy =
    IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3397 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
// Fall back to the generic lowering:
//   %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3402 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3403 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3404 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3407 /// Get a canonical unsigned division expression, or something simpler if
const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
                                         const SCEV *RHS) {
3411 assert(!LHS->getType()->isPointerTy() &&
3412 "SCEVUDivExpr operand can't be pointer!");
3413 assert(LHS->getType() == RHS->getType() &&
3414 "SCEVUDivExpr operand types don't match!");
3416 FoldingSetNodeID ID;
ID.AddInteger(scUDivExpr);
ID.AddPointer(LHS);
ID.AddPointer(RHS);
void *IP = nullptr;
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
  return S;
3425 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3426 if (LHSC->getValue()->isZero())
3429 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3430 if (RHSC->getValue()->isOne())
3431 return LHS; // X udiv 1 --> x
3432 // If the denominator is zero, the result of the udiv is undefined. Don't
3433 // try to analyze it, because the resolution chosen here may differ from
3434 // the resolution chosen in other parts of the compiler.
3435 if (!RHSC->getValue()->isZero()) {
3436 // Determine if the division can be folded into the operands of
3438 // TODO: Generalize this to non-constants by using known-bits information.
3439 Type *Ty = LHS->getType();
3440 unsigned LZ = RHSC->getAPInt().countl_zero();
3441 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3442 // For non-power-of-two values, effectively round the value up to the
3443 // nearest power of two.
if (!RHSC->getAPInt().isPowerOf2())
  ++MaxShiftAmt;
3446 IntegerType *ExtTy =
3447 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3448 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3449 if (const SCEVConstant *Step =
3450 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3451 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
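// For example, {0,+,8}<L> /u 4 --> {0,+,2}<L>, provided the
// zero-extended forms match, i.e. no intermediate value wraps.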
3452 const APInt &StepInt = Step->getAPInt();
3453 const APInt &DivInt = RHSC->getAPInt();
3454 if (!StepInt.urem(DivInt) &&
3455 getZeroExtendExpr(AR, ExtTy) ==
3456 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3457 getZeroExtendExpr(Step, ExtTy),
3458 AR->getLoop(), SCEV::FlagAnyWrap)) {
3459 SmallVector<const SCEV *, 4> Operands;
3460 for (const SCEV *Op : AR->operands())
3461 Operands.push_back(getUDivExpr(Op, RHS));
3462 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
// Get a canonical UDivExpr for a recurrence.
// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3466 // We can currently only fold X%N if X is constant.
3467 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3468 if (StartC && !DivInt.urem(StepInt) &&
3469 getZeroExtendExpr(AR, ExtTy) ==
3470 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3471 getZeroExtendExpr(Step, ExtTy),
3472 AR->getLoop(), SCEV::FlagAnyWrap)) {
3473 const APInt &StartInt = StartC->getAPInt();
3474 const APInt &StartRem = StartInt.urem(StepInt);
3475 if (StartRem != 0) {
3476 const SCEV *NewLHS =
3477 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3478 AR->getLoop(), SCEV::FlagNW);
3479 if (LHS != NewLHS) {
3482 // Reset the ID to include the new LHS, and check if it is
3485 ID.AddInteger(scUDivExpr);
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
  return S;
3495 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3496 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3497 SmallVector<const SCEV *, 4> Operands;
3498 for (const SCEV *Op : M->operands())
3499 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3500 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3501 // Find an operand that's safely divisible.
3502 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3503 const SCEV *Op = M->getOperand(i);
3504 const SCEV *Div = getUDivExpr(Op, RHSC);
3505 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
Operands = SmallVector<const SCEV *, 4>(M->operands());
Operands[i] = Div;
3508 return getMulExpr(Operands);
3513 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
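// For example, (%x /u 4) /u 2 --> %x /u 8, as long as 4 * 2 does not
// overflow the type.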
3514 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3515 if (auto *DivisorConstant =
3516 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3517 bool Overflow = false;
APInt NewRHS =
    DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
if (Overflow) {
  return getConstant(RHSC->getType(), 0, false);
}
3523 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3527 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
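// For example, (8 + (4 * %x)) /u 4 --> 2 + %x when both addends divide
// exactly in the zero-extended type.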
3528 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3529 SmallVector<const SCEV *, 4> Operands;
3530 for (const SCEV *Op : A->operands())
3531 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3532 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3534 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3535 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3536 if (isa<SCEVUDivExpr>(Op) ||
3537 getMulExpr(Op, RHS) != A->getOperand(i))
3539 Operands.push_back(Op);
3541 if (Operands.size() == A->getNumOperands())
3542 return getAddExpr(Operands);
3546 // Fold if both operands are constant.
3547 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3548 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3552 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3553 // changes). Make sure we get a new one.
3555 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
                                           LHS, RHS);
3558 UniqueSCEVs.InsertNode(S, IP);
registerUser(S, {LHS, RHS});
return S;
}
3563 APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3564 APInt A = C1->getAPInt().abs();
3565 APInt B = C2->getAPInt().abs();
3566 uint32_t ABW = A.getBitWidth();
uint32_t BBW = B.getBitWidth();
if (ABW > BBW)
  B = B.zext(ABW);
else if (ABW < BBW)
  A = A.zext(BBW);
3574 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3577 /// Get a canonical unsigned division expression, or something simpler if
3578 /// possible. There is no representation for an exact udiv in SCEV IR, but we
3579 /// can attempt to remove factors from the LHS and RHS. We can't do this when
3580 /// it's not exact because the udiv may be clearing bits.
3581 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3583 // TODO: we could try to find factors in all sorts of things, but for now we
3584 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3585 // end of this file for inspiration.
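// For example, given exact (6 * %x)<nuw> /u 4, gcd(6, 4) = 2 lets us
// rewrite the query as (3 * %x) /u 2 before recursing.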
3587 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3588 if (!Mul || !Mul->hasNoUnsignedWrap())
3589 return getUDivExpr(LHS, RHS);
3591 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3592 // If the mulexpr multiplies by a constant, then that constant must be the
3593 // first element of the mulexpr.
3594 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3595 if (LHSCst == RHSCst) {
3596 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3597 return getMulExpr(Operands);
3600 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3601 // that there's a factor provided by one of the other terms. We need to
3603 APInt Factor = gcd(LHSCst, RHSCst);
3604 if (!Factor.isIntN(1)) {
LHSCst =
    cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
RHSCst =
    cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3609 SmallVector<const SCEV *, 2> Operands;
3610 Operands.push_back(LHSCst);
3611 append_range(Operands, Mul->operands().drop_front());
LHS = getMulExpr(Operands);
RHS = RHSCst;
Mul = dyn_cast<SCEVMulExpr>(LHS);
if (!Mul)
  return getUDivExactExpr(LHS, RHS);
3621 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3622 if (Mul->getOperand(i) == RHS) {
3623 SmallVector<const SCEV *, 2> Operands;
3624 append_range(Operands, Mul->operands().take_front(i));
3625 append_range(Operands, Mul->operands().drop_front(i + 1));
3626 return getMulExpr(Operands);
3630 return getUDivExpr(LHS, RHS);
3633 /// Get an add recurrence expression for the specified loop. Simplify the
3634 /// expression as much as possible.
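/// A step that is itself a recurrence over the same loop is flattened,
/// e.g. {S,+,{Y,+,Z}<L>}<L> becomes {S,+,Y,+,Z}<L>.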
3635 const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
const Loop *L, SCEV::NoWrapFlags Flags) {
3638 SmallVector<const SCEV *, 4> Operands;
3639 Operands.push_back(Start);
3640 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3641 if (StepChrec->getLoop() == L) {
3642 append_range(Operands, StepChrec->operands());
3643 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3646 Operands.push_back(Step);
3647 return getAddRecExpr(Operands, L, Flags);
3650 /// Get an add recurrence expression for the specified loop. Simplify the
3651 /// expression as much as possible.
3653 ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3654 const Loop *L, SCEV::NoWrapFlags Flags) {
3655 if (Operands.size() == 1) return Operands[0];
3657 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3658 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
3659 assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
3660 "SCEVAddRecExpr operand types don't match!");
3661 assert(!Operands[i]->getType()->isPointerTy() && "Step must be integer");
3663 for (unsigned i = 0, e = Operands.size(); i != e; ++i)
3664 assert(isLoopInvariant(Operands[i], L) &&
3665 "SCEVAddRecExpr operand is not loop-invariant!");
3668 if (Operands.back()->isZero()) {
3669 Operands.pop_back();
3670 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
// It's tempting to call getConstantMaxBackedgeTakenCount here and
3674 // use that information to infer NUW and NSW flags. However, computing a
3675 // BE count requires calling getAddRecExpr, so we may not yet have a
3676 // meaningful BE count at this point (and if we don't, we'd be stuck
3677 // with a SCEVCouldNotCompute as the cached BE count).
3679 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
// Canonicalize nested AddRecs by nesting them in order of loop depth.
3682 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3683 const Loop *NestedLoop = NestedAR->getLoop();
3684 if (L->contains(NestedLoop)
3685 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3686 : (!NestedLoop->contains(L) &&
3687 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3688 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3689 Operands[0] = NestedAR->getStart();
3690 // AddRecs require their operands be loop-invariant with respect to their
// loops. Don't perform this transformation if it would break this
// requirement.
3693 bool AllInvariant = all_of(
Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
if (AllInvariant) {
3697 // Create a recurrence for the outer loop with the same step size.
3699 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3700 // inner recurrence has the same property.
3701 SCEV::NoWrapFlags OuterFlags =
3702 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3704 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3705 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
return isLoopInvariant(Op, NestedLoop);
});
if (AllInvariant) {
3710 // Ok, both add recurrences are valid after the transformation.
3712 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3713 // the outer recurrence has the same property.
3714 SCEV::NoWrapFlags InnerFlags =
3715 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3716 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3719 // Reset Operands to its original state.
3720 Operands[0] = NestedAR;
3724 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3725 // already have one, otherwise create a new one.
3726 return getOrCreateAddRecExpr(Operands, L, Flags);
3730 ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3731 const SmallVectorImpl<const SCEV *> &IndexExprs) {
3732 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3733 // getSCEV(Base)->getType() has the same address space as Base->getType()
3734 // because SCEV::getType() preserves the address space.
3735 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3736 const bool AssumeInBoundsFlags = [&]() {
if (!GEP->isInBounds())
  return false;
3740 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3741 // but to do that, we have to ensure that said flag is valid in the entire
3742 // defined scope of the SCEV.
3743 auto *GEPI = dyn_cast<Instruction>(GEP);
3744 // TODO: non-instructions have global scope. We might be able to prove
3745 // some global scope cases
3746 return GEPI && isSCEVExprNeverPoison(GEPI);
3749 SCEV::NoWrapFlags OffsetWrap =
3750 AssumeInBoundsFlags ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
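// For example (assuming a typical 64-bit data layout), the offsets for
// "getelementptr inbounds {i32, i32}, ptr %p, i64 %i, i32 1" are
// {8 * %i, 4}, and the result is %p + (8 * %i + 4).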
3752 Type *CurTy = GEP->getType();
3753 bool FirstIter = true;
3754 SmallVector<const SCEV *, 4> Offsets;
3755 for (const SCEV *IndexExpr : IndexExprs) {
3756 // Compute the (potentially symbolic) offset in bytes for this index.
3757 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3758 // For a struct, add the member offset.
3759 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3760 unsigned FieldNo = Index->getZExtValue();
3761 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3762 Offsets.push_back(FieldOffset);
3764 // Update CurTy to the type of the field at Index.
3765 CurTy = STy->getTypeAtIndex(Index);
} else {
  // Update CurTy to its element type.
  if (FirstIter) {
    assert(isa<PointerType>(CurTy) &&
           "The first index of a GEP indexes a pointer");
    CurTy = GEP->getSourceElementType();
    FirstIter = false;
  } else {
    CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
  }
3776 // For an array, add the element offset, explicitly scaled.
3777 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3778 // Getelementptr indices are signed.
3779 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3781 // Multiply the index by the element size to compute the element offset.
3782 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3783 Offsets.push_back(LocalOffset);
3787 // Handle degenerate case of GEP without offsets.
if (Offsets.empty())
  return BaseExpr;
3791 // Add the offsets together, assuming nsw if inbounds.
3792 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3793 // Add the base address and the offset. We cannot use the nsw flag, as the
3794 // base address is unsigned. However, if we know that the offset is
3795 // non-negative, we can use nuw.
3796 SCEV::NoWrapFlags BaseWrap = AssumeInBoundsFlags && isKnownNonNegative(Offset)
3797 ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3798 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3799 assert(BaseExpr->getType() == GEPExpr->getType() &&
3800 "GEP should not change type mid-flight.");
3804 SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3805 ArrayRef<const SCEV *> Ops) {
3806 FoldingSetNodeID ID;
3807 ID.AddInteger(SCEVType);
for (const SCEV *Op : Ops)
  ID.AddPointer(Op);
void *IP = nullptr;
3811 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3814 const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3815 SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3816 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3819 const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3820 SmallVectorImpl<const SCEV *> &Ops) {
3821 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3822 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3823 if (Ops.size() == 1) return Ops[0];
3825 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3826 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3827 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3828 "Operand types don't match!");
3829 assert(Ops[0]->getType()->isPointerTy() ==
3830 Ops[i]->getType()->isPointerTy() &&
3831 "min/max should be consistently pointerish");
3835 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3836 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3838 // Sort by complexity, this groups all similar expression types together.
3839 GroupByComplexity(Ops, &LI, DT);
3841 // Check if we have created the same expression before.
3842 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3846 // If there are any constants, fold them together.
3848 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3850 assert(Idx < Ops.size());
3851 auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
switch (Kind) {
case scSMaxExpr:
  return APIntOps::smax(LHS, RHS);
case scSMinExpr:
  return APIntOps::smin(LHS, RHS);
case scUMaxExpr:
  return APIntOps::umax(LHS, RHS);
case scUMinExpr:
  return APIntOps::umin(LHS, RHS);
default:
  llvm_unreachable("Unknown SCEV min/max opcode");
}
};
3866 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3867 // We found two constants, fold them together!
3868 ConstantInt *Fold = ConstantInt::get(
3869 getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3870 Ops[0] = getConstant(Fold);
3871 Ops.erase(Ops.begin()+1); // Erase the folded element
3872 if (Ops.size() == 1) return Ops[0];
3873 LHSC = cast<SCEVConstant>(Ops[0]);
3876 bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3877 bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3879 if (IsMax ? IsMinV : IsMaxV) {
3880 // If we are left with a constant minimum(/maximum)-int, strip it off.
Ops.erase(Ops.begin());
--Idx;
3883 } else if (IsMax ? IsMaxV : IsMinV) {
3884 // If we have a max(/min) with a constant maximum(/minimum)-int,
// it will always be the extremum.
return LHSC;
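// (For example, smax(%x, SIGNED_MAX) is SIGNED_MAX and umin(%x, 0)
// is 0.)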
3889 if (Ops.size() == 1) return Ops[0];
3892 // Find the first operation of the same kind
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
  ++Idx;
3896 // Check to see if one of the operands is of the same kind. If so, expand its
3897 // operands onto our operand list, and recurse to simplify.
3898 if (Idx < Ops.size()) {
3899 bool DeletedAny = false;
3900 while (Ops[Idx]->getSCEVType() == Kind) {
3901 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3902 Ops.erase(Ops.begin()+Idx);
3903 append_range(Ops, SMME->operands());
3908 return getMinMaxExpr(Kind, Ops);
3911 // Okay, check to see if the same value occurs in the operand list twice. If
3912 // so, delete one. Since we sorted the list, these values are required to
3914 llvm::CmpInst::Predicate GEPred =
3915 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3916 llvm::CmpInst::Predicate LEPred =
3917 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3918 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3919 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3920 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3921 if (Ops[i] == Ops[i + 1] ||
3922 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3923 // X op Y op Y --> X op Y
3924 // X op Y --> X, if we know X, Y are ordered appropriately
Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
--i;
--e;
3928 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3930 // X op Y --> Y, if we know X, Y are ordered appropriately
Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
--i;
--e;
3937 if (Ops.size() == 1) return Ops[0];
3939 assert(!Ops.empty() && "Reduced smax down to nothing!");
3941 // Okay, it looks like we really DO need an expr. Check to see if we
3942 // already have one, otherwise create a new one.
3943 FoldingSetNodeID ID;
3944 ID.AddInteger(Kind);
3945 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3946 ID.AddPointer(Ops[i]);
void *IP = nullptr;
const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
if (ExistingSCEV)
  return ExistingSCEV;
3951 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3952 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3953 SCEV *S = new (SCEVAllocator)
3954 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3956 UniqueSCEVs.InsertNode(S, IP);
3957 registerUser(S, Ops);
3963 class SCEVSequentialMinMaxDeduplicatingVisitor final
3964 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3965 std::optional<const SCEV *>> {
3966 using RetVal = std::optional<const SCEV *>;
3967 using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
3969 ScalarEvolution &SE;
3970 const SCEVTypes RootKind; // Must be a sequential min/max expression.
3971 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3972 SmallPtrSet<const SCEV *, 16> SeenOps;
3974 bool canRecurseInto(SCEVTypes Kind) const {
3975 // We can only recurse into the SCEV expression of the same effective type
3976 // as the type of our root SCEV expression.
3977 return RootKind == Kind || NonSequentialRootKind == Kind;
3980 RetVal visitAnyMinMaxExpr(const SCEV *S) {
3981 assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
3982 "Only for min/max expressions.");
3983 SCEVTypes Kind = S->getSCEVType();
3985 if (!canRecurseInto(Kind))
3988 auto *NAry = cast<SCEVNAryExpr>(S);
3989 SmallVector<const SCEV *> NewOps;
3990 bool Changed = visit(Kind, NAry->operands(), NewOps);
3995 return std::nullopt;
3997 return isa<SCEVSequentialMinMaxExpr>(S)
3998 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
3999 : SE.getMinMaxExpr(Kind, NewOps);
4002 RetVal visit(const SCEV *S) {
4003 // Has the whole operand been seen already?
4004 if (!SeenOps.insert(S).second)
4005 return std::nullopt;
4006 return Base::visit(S);
4010 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4012 : SE(SE), RootKind(RootKind),
4013 NonSequentialRootKind(
4014 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4017 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4018 SmallVectorImpl<const SCEV *> &NewOps) {
4019 bool Changed = false;
4020 SmallVector<const SCEV *> Ops;
4021 Ops.reserve(OrigOps.size());
4023 for (const SCEV *Op : OrigOps) {
4024 RetVal NewOp = visit(Op);
4028 Ops.emplace_back(*NewOp);
4032 NewOps = std::move(Ops);
4036 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4038 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4040 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4042 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4044 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4046 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4048 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4050 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4052 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4054 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4056 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4057 return visitAnyMinMaxExpr(Expr);
4060 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4061 return visitAnyMinMaxExpr(Expr);
4064 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4065 return visitAnyMinMaxExpr(Expr);
4068 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4069 return visitAnyMinMaxExpr(Expr);
4072 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4073 return visitAnyMinMaxExpr(Expr);
4076 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4078 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4083 static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
4100 // If any operand is poison, the whole expression is poison.
4102 case scSequentialUMinExpr:
4103 // FIXME: if the *first* operand is poison, the whole expression is poison.
4104 return false; // Pessimistically, say that it does not propagate poison.
4105 case scCouldNotCompute:
4106 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4108 llvm_unreachable("Unknown SCEV kind!");
4111 /// Return true if V is poison given that AssumedPoison is already poison.
4112 static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4113 // The only way poison may be introduced in a SCEV expression is from a
4114 // poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4115 // not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4116 // introduce poison -- they encode guaranteed, non-speculated knowledge.
4118 // Additionally, all SCEV nodes propagate poison from inputs to outputs,
4119 // with the notable exception of umin_seq, where only poison from the first
4120 // operand is (unconditionally) propagated.
4121 struct SCEVPoisonCollector {
4122 bool LookThroughMaybePoisonBlocking;
4123 SmallPtrSet<const SCEV *, 4> MaybePoison;
4124 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4125 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4127 bool follow(const SCEV *S) {
4128 if (!LookThroughMaybePoisonBlocking &&
4129 !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4132 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4133 if (!isGuaranteedNotToBePoison(SU->getValue()))
4134 MaybePoison.insert(S);
4138 bool isDone() const { return false; }
4141 // First collect all SCEVs that might result in AssumedPoison to be poison.
4142 // We need to look through potentially poison-blocking operations here,
4143 // because we want to find all SCEVs that *might* result in poison, not only
4144 // those that are *required* to.
4145 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4146 visitAll(AssumedPoison, PC1);
4148 // AssumedPoison is never poison. As the assumption is false, the implication
4149 // is true. Don't bother walking the other SCEV in this case.
4150 if (PC1.MaybePoison.empty())
4153 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4154 // as well. We cannot look through potentially poison-blocking operations
4155 // here, as their arguments only *may* make the result poison.
SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
visitAll(S, PC2);
4159 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4160 // it will also make S poison by being part of PC2.MaybePoison.
4161 return all_of(PC1.MaybePoison,
4162 [&](const SCEV *S) { return PC2.MaybePoison.contains(S); });
4166 ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4167 SmallVectorImpl<const SCEV *> &Ops) {
4168 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4169 "Not a SCEVSequentialMinMaxExpr!");
4170 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
if (Ops.size() == 1)
  return Ops[0];
4174 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4175 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4176 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4177 "Operand types don't match!");
4178 assert(Ops[0]->getType()->isPointerTy() ==
4179 Ops[i]->getType()->isPointerTy() &&
4180 "min/max should be consistently pointerish");
4184 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4185 // so we can *NOT* do any kind of sorting of the expressions!
4187 // Check if we have created the same expression before.
4188 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4191 // FIXME: there are *some* simplifications that we can do here.
4193 // Keep only the first instance of an operand.
4195 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4196 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
if (Changed)
  return getSequentialMinMaxExpr(Kind, Ops);
4201 // Check to see if one of the operands is of the same kind. If so, expand its
4202 // operands onto our operand list, and recurse to simplify.
4205 bool DeletedAny = false;
4206 while (Idx < Ops.size()) {
4207 if (Ops[Idx]->getSCEVType() != Kind) {
4211 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4212 Ops.erase(Ops.begin() + Idx);
4213 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4214 SMME->operands().end());
4219 return getSequentialMinMaxExpr(Kind, Ops);
4222 const SCEV *SaturationPoint;
4223 ICmpInst::Predicate Pred;
4225 case scSequentialUMinExpr:
4226 SaturationPoint = getZero(Ops[0]->getType());
4227 Pred = ICmpInst::ICMP_ULE;
4230 llvm_unreachable("Not a sequential min/max type.");
4233 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4234 // We can replace %x umin_seq %y with %x umin %y if either:
4235 // * %y being poison implies %x is also poison.
4236 // * %x cannot be the saturating value (e.g. zero for umin).
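// For example, (1 + zext(%b)) umin_seq %y --> (1 + zext(%b)) umin %y,
// because the first operand is known non-zero.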
4237 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4238 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4240 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4241 Ops[i - 1] = getMinMaxExpr(
4242 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4244 Ops.erase(Ops.begin() + i);
4245 return getSequentialMinMaxExpr(Kind, Ops);
4247 // Fold %x umin_seq %y to %x if %x ule %y.
4248 // TODO: We might be able to prove the predicate for a later operand.
4249 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4250 Ops.erase(Ops.begin() + i);
4251 return getSequentialMinMaxExpr(Kind, Ops);
4255 // Okay, it looks like we really DO need an expr. Check to see if we
4256 // already have one, otherwise create a new one.
4257 FoldingSetNodeID ID;
4258 ID.AddInteger(Kind);
4259 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
4260 ID.AddPointer(Ops[i]);
void *IP = nullptr;
const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
if (ExistingSCEV)
  return ExistingSCEV;
4266 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4267 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4268 SCEV *S = new (SCEVAllocator)
4269 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4271 UniqueSCEVs.InsertNode(S, IP);
registerUser(S, Ops);
return S;
}
4276 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4277 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4278 return getSMaxExpr(Ops);
4281 const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4282 return getMinMaxExpr(scSMaxExpr, Ops);
4285 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4286 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4287 return getUMaxExpr(Ops);
4290 const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4291 return getMinMaxExpr(scUMaxExpr, Ops);
4294 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4296 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4297 return getSMinExpr(Ops);
4300 const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4301 return getMinMaxExpr(scSMinExpr, Ops);
4304 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4306 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4307 return getUMinExpr(Ops, Sequential);
4310 const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4312 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4313 : getMinMaxExpr(scUMinExpr, Ops);
4317 ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
4318 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4319 if (Size.isScalable())
4320 Res = getMulExpr(Res, getVScale(IntTy));
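// For example, the size of <vscale x 4 x i32>, whose known minimum size
// is 16 bytes, is the SCEV (16 * vscale).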
4324 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4325 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4328 const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4329 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4332 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4335 // We can bypass creating a target-independent constant expression and then
// folding it back into a ConstantInt. This is just a compile-time
// optimization.
4338 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4339 assert(!SL->getSizeInBits().isScalable() &&
4340 "Cannot get offset for structure containing scalable vector types");
4341 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4344 const SCEV *ScalarEvolution::getUnknown(Value *V) {
4345 // Don't attempt to do anything other than create a SCEVUnknown object
4346 // here. createSCEV only calls getUnknown after checking for all other
4347 // interesting possibilities, and any other code that calls getUnknown
4348 // is doing so in order to hide a value from SCEV canonicalization.
4350 FoldingSetNodeID ID;
4351 ID.AddInteger(scUnknown);
4354 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4355 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4356 "Stale SCEVUnknown in uniquing map!");
SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
                                          FirstUnknown);
4361 FirstUnknown = cast<SCEVUnknown>(S);
UniqueSCEVs.InsertNode(S, IP);
return S;
}
4366 //===----------------------------------------------------------------------===//
4367 // Basic SCEV Analysis and PHI Idiom Recognition Code
4370 /// Test if values of the given type are analyzable within the SCEV
4371 /// framework. This primarily includes integer types, and it can optionally
4372 /// include pointer types if the ScalarEvolution class has access to
4373 /// target-specific information.
4374 bool ScalarEvolution::isSCEVable(Type *Ty) const {
4375 // Integers and pointers are always SCEVable.
4376 return Ty->isIntOrPtrTy();
4379 /// Return the size in bits of the specified type, for which isSCEVable must
4381 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4382 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4383 if (Ty->isPointerTy())
4384 return getDataLayout().getIndexTypeSizeInBits(Ty);
4385 return getDataLayout().getTypeSizeInBits(Ty);
4388 /// Return a type with the same bitwidth as the given type and which represents
4389 /// how SCEV will treat the given type, for which isSCEVable must return
4390 /// true. For pointer types, this is the pointer index sized integer type.
4391 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
4392 assert(isSCEVable(Ty) && "Type is not SCEVable!");
if (Ty->isIntegerTy())
  return Ty;
// The only other supported type is pointer.
4398 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4399 return getDataLayout().getIndexType(Ty);
4402 Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4403 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
4408 /// For a valid use point to exist, the defining scope of one operand
4409 /// must dominate the other.
4410 bool PreciseA, PreciseB;
4411 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4412 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4413 if (!PreciseA || !PreciseB)
4416 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4417 DT.dominates(ScopeB, ScopeA);
4421 const SCEV *ScalarEvolution::getCouldNotCompute() {
4422 return CouldNotCompute.get();
4425 bool ScalarEvolution::checkValidity(const SCEV *S) const {
4426 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4427 auto *SU = dyn_cast<SCEVUnknown>(S);
4428 return SU && SU->getValue() == nullptr;
4431 return !ContainsNulls;
4434 bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4435 HasRecMapType::iterator I = HasRecMap.find(S);
if (I != HasRecMap.end())
  return I->second;
bool FoundAddRec =
    SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
HasRecMap.insert({S, FoundAddRec});
return FoundAddRec;
}
4445 /// Return the ValueOffsetPair set for \p S. \p S can be represented
4446 /// by the value and offset from any ValueOffsetPair in the set.
4447 ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4448 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4449 if (SI == ExprValueMap.end())
4450 return std::nullopt;
4451 return SI->second.getArrayRef();
4454 /// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4455 /// cannot be used separately. eraseValueFromMap should be used to remove
4456 /// V from ValueExprMap and ExprValueMap at the same time.
4457 void ScalarEvolution::eraseValueFromMap(Value *V) {
4458 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4459 if (I != ValueExprMap.end()) {
4460 auto EVIt = ExprValueMap.find(I->second);
4461 bool Removed = EVIt->second.remove(V);
4463 assert(Removed && "Value not in ExprValueMap?");
4464 ValueExprMap.erase(I);
4468 void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4469 // A recursive query may have already computed the SCEV. It should be
4470 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4471 // inferred nowrap flags.
4472 auto It = ValueExprMap.find_as(V);
4473 if (It == ValueExprMap.end()) {
4474 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4475 ExprValueMap[S].insert(V);
4479 /// Determine whether this instruction is either not SCEVable or will always
4480 /// produce a SCEVUnknown. We do not have to walk past such instructions when
4482 static bool isAlwaysUnknown(const Instruction *I) {
4483 switch (I->getOpcode()) {
case Instruction::Load:
  return true;
default:
  return false;
}
}
4491 /// Return an existing SCEV if it exists, otherwise analyze the expression and
4492 /// create a new one.
4493 const SCEV *ScalarEvolution::getSCEV(Value *V) {
4494 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
if (const SCEV *S = getExistingSCEV(V))
  return S;
4498 const SCEV *S = createSCEVIter(V);
4499 assert((!isa<Instruction>(V) || !isAlwaysUnknown(cast<Instruction>(V)) ||
4500 isa<SCEVUnknown>(S)) &&
4501 "isAlwaysUnknown() instruction is not SCEVUnknown");
4505 const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4506 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4508 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4509 if (I != ValueExprMap.end()) {
4510 const SCEV *S = I->second;
4511 assert(checkValidity(S) &&
4512 "existing SCEV has not been properly invalidated");
4518 /// Return a SCEV corresponding to -V = -1*V
4519 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4520 SCEV::NoWrapFlags Flags) {
4521 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
return getConstant(
    cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4525 Type *Ty = V->getType();
4526 Ty = getEffectiveSCEVType(Ty);
4527 return getMulExpr(V, getMinusOne(Ty), Flags);
4530 /// If Expr computes ~A, return A else return nullptr
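/// SCEV has no dedicated "not" node; ~A is canonically represented as
/// (-1 + (-1 * A)), which is the shape matched below.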
4531 static const SCEV *MatchNotExpr(const SCEV *Expr) {
4532 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4533 if (!Add || Add->getNumOperands() != 2 ||
4534 !Add->getOperand(0)->isAllOnesValue())
4537 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4538 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4539 !AddRHS->getOperand(0)->isAllOnesValue())
4542 return AddRHS->getOperand(1);
4545 /// Return a SCEV corresponding to ~V = -1-V
4546 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4547 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4549 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
return getConstant(
    cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4553 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4554 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4555 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4556 SmallVector<const SCEV *, 2> MatchedOperands;
4557 for (const SCEV *Operand : MME->operands()) {
4558 const SCEV *Matched = MatchNotExpr(Operand);
4560 return (const SCEV *)nullptr;
4561 MatchedOperands.push_back(Matched);
4563 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4566 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4570 Type *Ty = V->getType();
4571 Ty = getEffectiveSCEVType(Ty);
4572 return getMinusSCEV(getMinusOne(Ty), V);
4575 const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4576 assert(P->getType()->isPointerTy());
4578 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4579 // The base of an AddRec is the first operand.
4580 SmallVector<const SCEV *> Ops{AddRec->operands()};
4581 Ops[0] = removePointerBase(Ops[0]);
4582 // Don't try to transfer nowrap flags for now. We could in some cases
4583 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4584 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4586 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4587 // The base of an Add is the pointer operand.
4588 SmallVector<const SCEV *> Ops{Add->operands()};
4589 const SCEV **PtrOp = nullptr;
4590 for (const SCEV *&AddOp : Ops) {
4591 if (AddOp->getType()->isPointerTy()) {
4592 assert(!PtrOp && "Cannot have multiple pointer ops");
4596 *PtrOp = removePointerBase(*PtrOp);
4597 // Don't try to transfer nowrap flags for now. We could in some cases
4598 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4599 return getAddExpr(Ops);
4601 // Any other expression must be a pointer base.
4602 return getZero(P->getType());
4605 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4606 SCEV::NoWrapFlags Flags,
4608 // Fast path: X - X --> 0.
if (LHS == RHS)
  return getZero(LHS->getType());
4612 // If we subtract two pointers with different pointer bases, bail.
4613 // Eventually, we're going to add an assertion to getMulExpr that we
4614 // can't multiply by a pointer.
4615 if (RHS->getType()->isPointerTy()) {
4616 if (!LHS->getType()->isPointerTy() ||
4617 getPointerBase(LHS) != getPointerBase(RHS))
4618 return getCouldNotCompute();
4619 LHS = removePointerBase(LHS);
4620 RHS = removePointerBase(RHS);
4623 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4624 // makes it so that we cannot make much use of NUW.
4625 auto AddFlags = SCEV::FlagAnyWrap;
4626 const bool RHSIsNotMinSigned =
4627 !getSignedRangeMin(RHS).isMinSignedValue();
4628 if (hasFlags(Flags, SCEV::FlagNSW)) {
4629 // Let M be the minimum representable signed value. Then (-1)*RHS
4630 // signed-wraps if and only if RHS is M. That can happen even for
4631 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4632 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4633 // (-1)*RHS, we need to prove that RHS != M.
4635 // If LHS is non-negative and we know that LHS - RHS does not
4636 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4637 // either by proving that RHS > M or that LHS >= 0.
4638 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4639 AddFlags = SCEV::FlagNSW;
4643 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4644 // RHS is NSW and LHS >= 0.
4646 // The difficulty here is that the NSW flag may have been proven
4647 // relative to a loop that is to be found in a recurrence in LHS and
4648 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4649 // larger scope than intended.
4650 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4652 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4655 const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4657 Type *SrcTy = V->getType();
4658 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4659 "Cannot truncate or zero extend with non-integer arguments!");
4660 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4661 return V; // No conversion
4662 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4663 return getTruncateExpr(V, Ty, Depth);
4664 return getZeroExtendExpr(V, Ty, Depth);
4667 const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4669 Type *SrcTy = V->getType();
4670 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4671 "Cannot truncate or zero extend with non-integer arguments!");
4672 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4673 return V; // No conversion
4674 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4675 return getTruncateExpr(V, Ty, Depth);
4676 return getSignExtendExpr(V, Ty, Depth);
4680 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4681 Type *SrcTy = V->getType();
4682 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4683 "Cannot noop or zero extend with non-integer arguments!");
4684 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4685 "getNoopOrZeroExtend cannot truncate!");
4686 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4687 return V; // No conversion
4688 return getZeroExtendExpr(V, Ty);
4692 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4693 Type *SrcTy = V->getType();
4694 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4695 "Cannot noop or sign extend with non-integer arguments!");
4696 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4697 "getNoopOrSignExtend cannot truncate!");
4698 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4699 return V; // No conversion
4700 return getSignExtendExpr(V, Ty);
4704 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4705 Type *SrcTy = V->getType();
4706 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4707 "Cannot noop or any extend with non-integer arguments!");
4708 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4709 "getNoopOrAnyExtend cannot truncate!");
4710 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4711 return V; // No conversion
4712 return getAnyExtendExpr(V, Ty);
4716 ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4717 Type *SrcTy = V->getType();
4718 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4719 "Cannot truncate or noop with non-integer arguments!");
4720 assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
4721 "getTruncateOrNoop cannot extend!");
4722 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4723 return V; // No conversion
4724 return getTruncateExpr(V, Ty);
4727 const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4729 const SCEV *PromotedLHS = LHS;
4730 const SCEV *PromotedRHS = RHS;
4732 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4733 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
else
  PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4737 return getUMaxExpr(PromotedLHS, PromotedRHS);
4740 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4743 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4744 return getUMinFromMismatchedTypes(Ops, Sequential);
4748 ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4750 assert(!Ops.empty() && "At least one operand must be!");
if (Ops.size() == 1)
  return Ops[0];
4755 // Find the max type first.
4756 Type *MaxType = nullptr;
4757 for (const auto *S : Ops)
if (MaxType)
  MaxType = getWiderType(MaxType, S->getType());
else
  MaxType = S->getType();
4762 assert(MaxType && "Failed to find maximum type!");
4764 // Extend all ops to max type.
4765 SmallVector<const SCEV *, 2> PromotedOps;
4766 for (const auto *S : Ops)
4767 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4770 return getUMinExpr(PromotedOps, Sequential);
4773 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4774 // A pointer operand may evaluate to a nonpointer expression, such as null.
if (!V->getType()->isPointerTy())
  return V;
4779 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4780 V = AddRec->getStart();
4781 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4782 const SCEV *PtrOp = nullptr;
4783 for (const SCEV *AddOp : Add->operands()) {
4784 if (AddOp->getType()->isPointerTy()) {
4785 assert(!PtrOp && "Cannot have multiple pointer ops");
assert(PtrOp && "Must have pointer op");
V = PtrOp;
} else // Not something we can look further into.
  return V;
4796 /// Push users of the given Instruction onto the given Worklist.
4797 static void PushDefUseChildren(Instruction *I,
4798 SmallVectorImpl<Instruction *> &Worklist,
4799 SmallPtrSetImpl<Instruction *> &Visited) {
4800 // Push the def-use children onto the Worklist stack.
4801 for (User *U : I->users()) {
4802 auto *UserInsn = cast<Instruction>(U);
4803 if (isAlwaysUnknown(UserInsn))
4805 if (Visited.insert(UserInsn).second)
4806 Worklist.push_back(UserInsn);
4812 /// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4813 /// expression in case its Loop is L. If it is not L then
4814 /// if IgnoreOtherLoops is true then use AddRec itself
4815 /// otherwise rewrite cannot be done.
4816 /// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
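/// For example, rewriting (%inv + {5,+,1}<L>) with respect to loop L
/// yields (%inv + 5).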
4817 class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4819 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4820 bool IgnoreOtherLoops = true) {
4821 SCEVInitRewriter Rewriter(L, SE);
4822 const SCEV *Result = Rewriter.visit(S);
4823 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4824 return SE.getCouldNotCompute();
4825 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4826 ? SE.getCouldNotCompute()
4830 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4831 if (!SE.isLoopInvariant(Expr, L))
4832 SeenLoopVariantSCEVUnknown = true;
4836 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4837 // Only re-write AddRecExprs for this loop.
4838 if (Expr->getLoop() == L)
4839 return Expr->getStart();
4840 SeenOtherLoops = true;
4844 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4846 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4849 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4850 : SCEVRewriteVisitor(SE), L(L) {}
4853 bool SeenLoopVariantSCEVUnknown = false;
4854 bool SeenOtherLoops = false;
4857 /// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
4858 /// increment expression if its loop is L; otherwise use the AddRec itself.
4859 /// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be
4860 /// done.
4861 class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4863 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4864 SCEVPostIncRewriter Rewriter(L, SE);
4865 const SCEV *Result = Rewriter.visit(S);
4866 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4867 ? SE.getCouldNotCompute()
4871 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4872 if (!SE.isLoopInvariant(Expr, L))
4873 SeenLoopVariantSCEVUnknown = true;
4877 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4878 // Only re-write AddRecExprs for this loop.
4879 if (Expr->getLoop() == L)
4880 return Expr->getPostIncExpr(SE);
4881 SeenOtherLoops = true;
4885 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4887 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4890 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4891 : SCEVRewriteVisitor(SE), L(L) {}
4894 bool SeenLoopVariantSCEVUnknown = false;
4895 bool SeenOtherLoops = false;
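// For illustration, both rewriters applied to the AddRec {0,+,1}<L> for the
// same loop L (a sketch, not from the original source):
//
//   SCEVInitRewriter::rewrite({0,+,1}, L, SE)    ==> 0        (start value)
//   SCEVPostIncRewriter::rewrite({0,+,1}, L, SE) ==> {1,+,1}  (post-inc value)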
4898 /// This class evaluates the compare condition by matching it against the
4899 /// condition of the loop latch. If there is a match, we assume a true value
4900 /// for the condition while building SCEV nodes.
4901 class SCEVBackedgeConditionFolder
4902 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4904 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4905 ScalarEvolution &SE) {
4906 bool IsPosBECond = false;
4907 Value *BECond = nullptr;
4908 if (BasicBlock *Latch = L->getLoopLatch()) {
4909 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4910 if (BI && BI->isConditional()) {
4911 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4912 "Both outgoing branches should not target same header!");
4913 BECond = BI->getCondition();
4914 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4919 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4920 return Rewriter.visit(S);
4923 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4924 const SCEV *Result = Expr;
4925 bool InvariantF = SE.isLoopInvariant(Expr, L);
4928 Instruction *I = cast<Instruction>(Expr->getValue());
4929 switch (I->getOpcode()) {
4930 case Instruction::Select: {
4931 SelectInst *SI = cast<SelectInst>(I);
4932 std::optional<const SCEV *> Res =
4933 compareWithBackedgeCondition(SI->getCondition());
4935 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
4936 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
4941 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
4952 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
4953 bool IsPosBECond, ScalarEvolution &SE)
4954 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
4955 IsPositiveBECond(IsPosBECond) {}
4957 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
4960 /// Loop back condition.
4961 Value *BackedgeCond = nullptr;
4962 /// Set to true if the backedge is taken when the branch condition is true.
4963 bool IsPositiveBECond;
4966 std::optional<const SCEV *>
4967 SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
4969 // If the value matches the backedge condition for the loop latch,
4970 // then return a constant evolution node reflecting whether the
4971 // loop-back branch is taken on the true or the false edge.
4972 if (BackedgeCond == IC)
4973 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
4974 : SE.getZero(Type::getInt1Ty(SE.getContext()));
4975 return std::nullopt;
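// Illustrative sketch: given a latch terminated by
//   br i1 %c, label %header, label %exit
// BackedgeCond is %c and IsPositiveBECond is true, so the folder treats %c
// as the constant i1 1 and rewrites a value such as
//   select i1 %c, i32 %a, i32 %b
// to getSCEV(%a).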
4978 class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
4980 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4981 ScalarEvolution &SE) {
4982 SCEVShiftRewriter Rewriter(L, SE);
4983 const SCEV *Result = Rewriter.visit(S);
4984 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
4987 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4988 // Only allow AddRecExprs for this loop.
4989 if (!SE.isLoopInvariant(Expr, L))
4994 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4995 if (Expr->getLoop() == L && Expr->isAffine())
4996 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5001 bool isValid() { return Valid; }
5004 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5005 : SCEVRewriteVisitor(SE), L(L) {}
5011 } // end anonymous namespace
5014 ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5015 if (!AR->isAffine())
5016 return SCEV::FlagAnyWrap;
5018 using OBO = OverflowingBinaryOperator;
5020 SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
5022 if (!AR->hasNoSelfWrap()) {
5023 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5024 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5025 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5026 const APInt &BECountAP = BECountMax->getAPInt();
5027 unsigned NoOverflowBitWidth =
5028 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5029 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5030 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNW);
5034 if (!AR->hasNoSignedWrap()) {
5035 ConstantRange AddRecRange = getSignedRange(AR);
5036 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5038 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5039 Instruction::Add, IncRange, OBO::NoSignedWrap);
5040 if (NSWRegion.contains(AddRecRange))
5041 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
5044 if (!AR->hasNoUnsignedWrap()) {
5045 ConstantRange AddRecRange = getUnsignedRange(AR);
5046 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5048 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5049 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5050 if (NUWRegion.contains(AddRecRange))
5051 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
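// Worked example (illustrative): for an i32 AddRec whose constant maximum
// backedge-taken count is 1000 (10 active bits) and whose step needs at most
// 8 signed bits, 10 + 8 = 18 <= 32, so the total increment cannot cross the
// full 32-bit range and FlagNW can be set.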
5058 ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5059 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5061 if (AR->hasNoSignedWrap())
5064 if (!AR->isAffine())
5067 // This function can be expensive, only try to prove NSW once per AddRec.
5068 if (!SignedWrapViaInductionTried.insert(AR).second)
5071 const SCEV *Step = AR->getStepRecurrence(*this);
5072 const Loop *L = AR->getLoop();
5074 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5075 // Note that this serves two purposes: It filters out loops that are
5076 // simply not analyzable, and it covers the case where this code is
5077 // being called from within backedge-taken count analysis, such that
5078 // attempting to ask for the backedge-taken count would likely result
5079 // in infinite recursion. In the latter case, the analysis code will
5080 // cope with a conservative value, and it will take care to purge
5081 // that value once it has finished.
5082 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5084 // Normally, in the cases we can prove no-overflow via a
5085 // backedge guarding condition, we can also compute a backedge
5086 // taken count for the loop. The exceptions are assumptions and
5087 // guards present in the loop -- SCEV is not great at exploiting
5088 // these to compute max backedge taken counts, but can still use
5089 // these to prove lack of overflow. Use this fact to avoid
5090 // doing extra work that may not pay off.
5092 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5093 AC.assumptions().empty())
5096 // If the backedge is guarded by a comparison with the pre-inc value the
5097 // addrec is safe. Also, if the entry is guarded by a comparison with the
5098 // start value and the backedge is guarded by a comparison with the post-inc
5099 // value, the addrec is safe.
5100 ICmpInst::Predicate Pred;
5101 const SCEV *OverflowLimit =
5102 getSignedOverflowLimitForStep(Step, &Pred, this);
5103 if (OverflowLimit &&
5104 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5105 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5106 Result = setFlags(Result, SCEV::FlagNSW);
5111 ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5112 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5114 if (AR->hasNoUnsignedWrap())
5117 if (!AR->isAffine())
5120 // This function can be expensive, only try to prove NUW once per AddRec.
5121 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5124 const SCEV *Step = AR->getStepRecurrence(*this);
5125 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5126 const Loop *L = AR->getLoop();
5128 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5129 // Note that this serves two purposes: It filters out loops that are
5130 // simply not analyzable, and it covers the case where this code is
5131 // being called from within backedge-taken count analysis, such that
5132 // attempting to ask for the backedge-taken count would likely result
5133 // in infinite recursion. In the latter case, the analysis code will
5134 // cope with a conservative value, and it will take care to purge
5135 // that value once it has finished.
5136 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5138 // Normally, in the cases we can prove no-overflow via a
5139 // backedge guarding condition, we can also compute a backedge
5140 // taken count for the loop. The exceptions are assumptions and
5141 // guards present in the loop -- SCEV is not great at exploiting
5142 // these to compute max backedge taken counts, but can still use
5143 // these to prove lack of overflow. Use this fact to avoid
5144 // doing extra work that may not pay off.
5146 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5147 AC.assumptions().empty())
5150 // If the backedge is guarded by a comparison with the pre-inc value the
5151 // addrec is safe. Also, if the entry is guarded by a comparison with the
5152 // start value and the backedge is guarded by a comparison with the post-inc
5153 // value, the addrec is safe.
5154 if (isKnownPositive(Step)) {
5155 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5156 getUnsignedRangeMax(Step));
5157 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5158 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5159 Result = setFlags(Result, SCEV::FlagNUW);
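// Worked example (illustrative): for an i8 AddRec whose step is known to be
// 1, N = (0 - umax(Step)) = 255, so a backedge guarded by (AR <u 255)
// guarantees that AR + 1 never wraps unsigned and FlagNUW can be set.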
5168 /// Represents an abstract binary operation. This may exist as a
5169 /// normal instruction or constant expression, or may have been
5170 /// derived from an expression tree.
5178 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5179 /// constant expression.
5180 Operator *Op = nullptr;
5182 explicit BinaryOp(Operator *Op)
5183 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5185 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5186 IsNSW = OBO->hasNoSignedWrap();
5187 IsNUW = OBO->hasNoUnsignedWrap();
5191 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5193 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5196 } // end anonymous namespace
5198 /// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5199 static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5200 AssumptionCache &AC,
5201 const DominatorTree &DT,
5202 const Instruction *CxtI) {
5203 auto *Op = dyn_cast<Operator>(V);
5205 return std::nullopt;
5207 // Implementation detail: all the cleverness here should happen without
5208 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5209 // SCEV expressions when possible, and we should not break that.
5211 switch (Op->getOpcode()) {
5212 case Instruction::Add:
5213 case Instruction::Sub:
5214 case Instruction::Mul:
5215 case Instruction::UDiv:
5216 case Instruction::URem:
5217 case Instruction::And:
5218 case Instruction::AShr:
5219 case Instruction::Shl:
5220 return BinaryOp(Op);
5222 case Instruction::Or: {
5223 // LLVM loves to convert `add` of operands with no common bits
5224 // into an `or`. But SCEV really doesn't deal with `or` that well,
5225 // so try extra hard to recognize this `or` as an `add`.
5226 if (haveNoCommonBitsSet(Op->getOperand(0), Op->getOperand(1), DL, &AC, CxtI,
5227 &DT, /*UseInstrInfo=*/true))
5228 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5229 /*IsNSW=*/true, /*IsNUW=*/true);
5230 return BinaryOp(Op);
5233 case Instruction::Xor:
5234 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5235 // If the RHS of the xor is a signmask, then this is just an add.
5236 // Instcombine turns add of signmask into xor as a strength reduction step.
5237 if (RHSC->getValue().isSignMask())
5238 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5239 // For i1 operands, binary `xor` is the same as `add` (addition modulo two).
5240 if (V->getType()->isIntegerTy(1))
5241 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5242 return BinaryOp(Op);
5244 case Instruction::LShr:
5245 // Turn a logical shift right by a constant into an unsigned divide.
5246 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5247 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5249 // If the shift count is not less than the bitwidth, the result of
5250 // the shift is undefined. Don't try to analyze it, because the
5251 // resolution chosen here may differ from the resolution chosen in
5252 // other parts of the compiler.
5253 if (SA->getValue().ult(BitWidth)) {
5255 ConstantInt::get(SA->getContext(),
5256 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5257 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5260 return BinaryOp(Op);
5262 case Instruction::ExtractValue: {
5263 auto *EVI = cast<ExtractValueInst>(Op);
5264 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5267 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5271 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5272 bool Signed = WO->isSigned();
5273 // TODO: Should add nuw/nsw flags for mul as well.
5274 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5275 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5277 // Now that we know that all uses of the arithmetic-result component of
5278 // CI are guarded by the overflow check, we can go ahead and pretend
5279 // that the arithmetic is non-overflowing.
5280 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5281 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5288 // Recognize the intrinsic loop.decrement.reg; as it has exactly the same
5289 // semantics as a Sub, return a binary sub expression.
5290 if (auto *II = dyn_cast<IntrinsicInst>(V))
5291 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5292 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5294 return std::nullopt;
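// Illustrative examples of the normalization above (a sketch):
//
//   or i32 %a, %b (no common bits)        ==> add nuw nsw i32 %a, %b
//   xor i32 %a, 0x80000000 (sign mask)    ==> add i32 %a, 0x80000000
//   lshr i32 %a, 3                        ==> udiv i32 %a, 8
//   llvm.loop.decrement.reg(%x, %y)       ==> sub %x, %y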
5297 /// Helper function to createAddRecFromPHIWithCasts. We have a phi
5298 /// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5299 /// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5300 /// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5301 /// follows one of the following patterns:
5302 /// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5303 /// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5304 /// If the SCEV expression of \p Op conforms with one of the expected patterns
5305 /// we return the type of the truncation operation, and indicate whether the
5306 /// truncated type should be treated as signed/unsigned by setting
5307 /// \p Signed to true/false, respectively.
5308 static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5309 bool &Signed, ScalarEvolution &SE) {
5310 // The case where Op == SymbolicPHI (that is, with no type conversions on
5311 // the way) is handled by the regular add recurrence creating logic and
5312 // would have already been triggered in createAddRecForPHI. Reaching it here
5313 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5314 // because one of the other operands of the SCEVAddExpr updating this PHI is
5317 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5318 // this case predicates that allow us to prove that Op == SymbolicPHI will
5320 if (Op == SymbolicPHI)
5323 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5324 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5325 if (SourceBits != NewBits)
5328 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5329 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5332 const SCEVTruncateExpr *Trunc =
5333 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5334 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5337 const SCEV *X = Trunc->getOperand();
5338 if (X != SymbolicPHI)
5340 Signed = SExt != nullptr;
5341 return Trunc->getType();
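// For illustration (a sketch, with %phi an i64 loop-header phi):
//
//   Op = (SExt i32 (Trunc i64 (%phi) to i32) to i64)
//
// matches the first pattern, so the function returns the truncated type i32
// and sets Signed to true; the zext form would set Signed to false.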
5344 static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5345 if (!PN->getType()->isIntegerTy())
5347 const Loop *L = LI.getLoopFor(PN->getParent());
5348 if (!L || L->getHeader() != PN->getParent())
5353 // Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5354 // computation that updates the phi matches the following pattern:
5355 // (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5356 // which corresponds to a phi->trunc->sext/zext->add->phi update chain.
5357 // If so, try to see if it can be rewritten as an AddRecExpr under some
5358 // Predicates. If successful, return them as a pair. Also cache the results
5359 // of the analysis.
5361 // Example usage scenario:
5362 // Say the Rewriter is called for the following SCEV:
5363 // 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5365 // %X = phi i64 (%Start, %BEValue)
5366 // It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5367 // and call this function with %SymbolicPHI = %X.
5369 // The analysis will find that the value coming around the backedge has
5370 // the following SCEV:
5371 // BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5372 // Upon concluding that this matches the desired pattern, the function
5373 // will return the pair {NewAddRec, SmallPredsVec} where:
5374 // NewAddRec = {%Start,+,%Step}
5375 // SmallPredsVec = {P1, P2, P3} as follows:
5376 // P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5377 // P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5378 // P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5379 // The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5380 // under the predicates {P1,P2,P3}.
5381 // This predicated rewrite will be cached in PredicatedSCEVRewrites:
5382 // PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
5386 // 1) Extend the Induction descriptor to also support inductions that involve
5387 // casts: When needed (namely, when we are called in the context of the
5388 // vectorizer induction analysis), a Set of cast instructions will be
5389 // populated by this method, and provided back to isInductionPHI. This is
5390 // needed to allow the vectorizer to properly record them to be ignored by
5391 // the cost model and to avoid vectorizing them (otherwise these casts,
5392 // which are redundant under the runtime overflow checks, will be
5393 // vectorized, which can be costly).
5395 // 2) Support additional induction/PHISCEV patterns: We also want to support
5396 // inductions where the sext-trunc / zext-trunc operations (partly) occur
5397 // after the induction update operation (the induction increment):
5399 // (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5400 // which corresponds to a phi->add->trunc->sext/zext->phi update chain.
5402 // (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5403 // which corresponds to a phi->trunc->add->sext/zext->phi update chain.
5405 // 3) Outline common code with createAddRecFromPHI to avoid duplication.
5406 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5407 ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5408 SmallVector<const SCEVPredicate *, 3> Predicates;
5410 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5411 // return an AddRec expression under some predicate.
5413 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5414 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5415 assert(L && "Expecting an integer loop header phi");
5417 // The loop may have multiple entrances or multiple exits; we can analyze
5418 // this phi as an addrec if it has a unique entry value and a unique
5419 // backedge value.
5420 Value *BEValueV = nullptr, *StartValueV = nullptr;
5421 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5422 Value *V = PN->getIncomingValue(i);
5423 if (L->contains(PN->getIncomingBlock(i))) {
5426 } else if (BEValueV != V) {
5430 } else if (!StartValueV) {
5432 } else if (StartValueV != V) {
5433 StartValueV = nullptr;
5437 if (!BEValueV || !StartValueV)
5438 return std::nullopt;
5440 const SCEV *BEValue = getSCEV(BEValueV);
5442 // If the value coming around the backedge is an add with the symbolic
5443 // value we just inserted, possibly with casts that we can ignore under
5444 // an appropriate runtime guard, then we found a simple induction variable!
5445 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5447 return std::nullopt;
5449 // If there is a single occurrence of the symbolic value, possibly
5450 // casted, replace it with a recurrence.
5451 unsigned FoundIndex = Add->getNumOperands();
5452 Type *TruncTy = nullptr;
5454 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5456 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5457 if (FoundIndex == e) {
5462 if (FoundIndex == Add->getNumOperands())
5463 return std::nullopt;
5465 // Create an add with everything but the specified operand.
5466 SmallVector<const SCEV *, 8> Ops;
5467 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5468 if (i != FoundIndex)
5469 Ops.push_back(Add->getOperand(i));
5470 const SCEV *Accum = getAddExpr(Ops);
5472 // The runtime checks will not be valid if the step amount is
5473 // varying inside the loop.
5474 if (!isLoopInvariant(Accum, L))
5475 return std::nullopt;
5477 // *** Part2: Create the predicates
5479 // Analysis was successful: we have a phi-with-cast pattern for which we
5480 // can return an AddRec expression under the following predicates:
5482 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5483 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5484 // P2: An Equal predicate that guarantees that
5485 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5486 // P3: An Equal predicate that guarantees that
5487 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5489 // As we next prove, the above predicates guarantee that:
5490 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5493 // More formally, we want to prove that:
5494 // Expr(i+1) = Start + (i+1) * Accum
5495 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5498 // 1) Expr(0) = Start
5499 // 2) Expr(1) = Start + Accum
5500 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5501 // 3) Induction hypothesis (step i):
5502 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5506 // = Start + (i+1)*Accum
5507 // = (Start + i*Accum) + Accum
5508 // = Expr(i) + Accum
5509 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5512 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5514 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5515 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5516 // + Accum :: from P3
5518 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5519 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5521 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5522 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5524 // By induction, the same applies to all iterations 1<=i<n:
5527 // Create a truncated addrec for which we will add a no overflow check (P1).
5528 const SCEV *StartVal = getSCEV(StartValueV);
5529 const SCEV *PHISCEV =
5530 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5531 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5533 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5534 // e.g.: If the truncated Accum is 0 and StartVal is a constant, then PHISCEV
5535 // will be constant.
5537 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5539 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5540 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5541 Signed ? SCEVWrapPredicate::IncrementNSSW
5542 : SCEVWrapPredicate::IncrementNUSW;
5543 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5544 Predicates.push_back(AddRecPred);
5547 // Create the Equal Predicates P2,P3:
5549 // It is possible that the predicates P2 and/or P3 are computable at
5550 // compile time due to StartVal and/or Accum being constants.
5551 // If either one is, then we can check that now and escape if either P2
5552 // or P3 is false.
5554 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5555 // for each of StartVal and Accum
5556 auto getExtendedExpr = [&](const SCEV *Expr,
5557 bool CreateSignExtend) -> const SCEV * {
5558 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5559 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5560 const SCEV *ExtendedExpr =
5561 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5562 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5563 return ExtendedExpr;
5567 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy)
5568 // = getExtendedExpr(Expr)
5569 // Determine whether the predicate P: Expr == ExtendedExpr
5570 // is known to be false at compile time
5571 auto PredIsKnownFalse = [&](const SCEV *Expr,
5572 const SCEV *ExtendedExpr) -> bool {
5573 return Expr != ExtendedExpr &&
5574 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5577 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5578 if (PredIsKnownFalse(StartVal, StartExtended)) {
5579 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5580 return std::nullopt;
5583 // The Step is always Signed (because the overflow checks are either
5584 // NSSW or NUSW).
5585 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5586 if (PredIsKnownFalse(Accum, AccumExtended)) {
5587 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5588 return std::nullopt;
5591 auto AppendPredicate = [&](const SCEV *Expr,
5592 const SCEV *ExtendedExpr) -> void {
5593 if (Expr != ExtendedExpr &&
5594 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5595 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5596 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5597 Predicates.push_back(Pred);
5601 AppendPredicate(StartVal, StartExtended);
5602 AppendPredicate(Accum, AccumExtended);
5604 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5605 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5606 // into NewAR if it will also add the runtime overflow checks specified in
5608 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5610 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5611 std::make_pair(NewAR, Predicates);
5612 // Remember the result of the analysis for this SCEV at this location.
5613 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5617 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5618 ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5619 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5620 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5622 return std::nullopt;
5624 // Check to see if we already analyzed this PHI.
5625 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5626 if (I != PredicatedSCEVRewrites.end()) {
5627 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5629 // Analysis was done before and failed to create an AddRec:
5630 if (Rewrite.first == SymbolicPHI)
5631 return std::nullopt;
5632 // Analysis was done before and succeeded in creating an AddRec under
5634 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5635 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5639 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5640 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5642 // Record in the cache that the analysis failed
5644 SmallVector<const SCEVPredicate *, 3> Predicates;
5645 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5646 return std::nullopt;
5652 // FIXME: This utility is currently required because the Rewriter currently
5653 // does not rewrite this expression:
5654 // {0, +, (sext ix (trunc iy to ix) to iy)}
5655 // into {0, +, %step},
5656 // even when the following Equal predicate exists:
5657 // "%step == (sext ix (trunc iy to ix) to iy)".
5658 bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5659 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5663 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5664 if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5665 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5670 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5671 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5676 /// A helper function for createAddRecFromPHI to handle simple cases.
5678 /// This function tries to find an AddRec expression for the simplest (yet most
5679 /// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5680 /// If it fails, createAddRecFromPHI will use a more general, but slow,
5681 /// technique for finding the AddRec expression.
5682 const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5684 Value *StartValueV) {
5685 const Loop *L = LI.getLoopFor(PN->getParent());
5686 assert(L && L->getHeader() == PN->getParent());
5687 assert(BEValueV && StartValueV);
5689 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5693 if (BO->Opcode != Instruction::Add)
5696 const SCEV *Accum = nullptr;
5697 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5698 Accum = getSCEV(BO->RHS);
5699 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5700 Accum = getSCEV(BO->LHS);
5705 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5707 Flags = setFlags(Flags, SCEV::FlagNUW);
5709 Flags = setFlags(Flags, SCEV::FlagNSW);
5711 const SCEV *StartVal = getSCEV(StartValueV);
5712 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5713 insertValueToMap(PN, PHISCEV);
5715 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5716 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5717 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5718 proveNoWrapViaConstantRanges(AR)));
5721 // We can add Flags to the post-inc expression only if we
5722 // know that it is *undefined behavior* for BEValueV to
5723 // overflow.
5724 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5725 assert(isLoopInvariant(Accum, L) &&
5726 "Accum is defined outside L, but is not invariant?");
5727 if (isAddRecNeverPoison(BEInst, L))
5728 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
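// Illustrative sketch of the simple affine case handled above:
//
//   %iv      = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
//   %iv.next = add nuw nsw i32 %iv, 1
//
// matches PN = PHI(Start, PN + Invariant) and yields {0,+,1}<nuw><nsw><%loop>.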
5734 const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5735 const Loop *L = LI.getLoopFor(PN->getParent());
5736 if (!L || L->getHeader() != PN->getParent())
5739 // The loop may have multiple entrances or multiple exits; we can analyze
5740 // this phi as an addrec if it has a unique entry value and a unique
5742 Value *BEValueV = nullptr, *StartValueV = nullptr;
5743 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5744 Value *V = PN->getIncomingValue(i);
5745 if (L->contains(PN->getIncomingBlock(i))) {
5748 } else if (BEValueV != V) {
5752 } else if (!StartValueV) {
5754 } else if (StartValueV != V) {
5755 StartValueV = nullptr;
5759 if (!BEValueV || !StartValueV)
5762 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5763 "PHI node already processed?");
5765 // First, try to find an AddRec expression without creating a fictitious symbolic
5766 // value (SymbolicName).
5767 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5770 // Handle PHI node value symbolically.
5771 const SCEV *SymbolicName = getUnknown(PN);
5772 insertValueToMap(PN, SymbolicName);
5774 // Using this symbolic name for the PHI, analyze the value coming around
5775 // the loop backedge.
5776 const SCEV *BEValue = getSCEV(BEValueV);
5778 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5779 // has a special value for the first iteration of the loop.
5781 // If the value coming around the backedge is an add with the symbolic
5782 // value we just inserted, then we found a simple induction variable!
5783 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5784 // If there is a single occurrence of the symbolic value, replace it
5785 // with a recurrence.
5786 unsigned FoundIndex = Add->getNumOperands();
5787 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5788 if (Add->getOperand(i) == SymbolicName)
5789 if (FoundIndex == e) {
5794 if (FoundIndex != Add->getNumOperands()) {
5795 // Create an add with everything but the specified operand.
5796 SmallVector<const SCEV *, 8> Ops;
5797 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5798 if (i != FoundIndex)
5799 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5801 const SCEV *Accum = getAddExpr(Ops);
5803 // This is not a valid addrec if the step amount is varying each
5804 // loop iteration, but is not itself an addrec in this loop.
5805 if (isLoopInvariant(Accum, L) ||
5806 (isa<SCEVAddRecExpr>(Accum) &&
5807 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5808 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5810 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5811 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5813 Flags = setFlags(Flags, SCEV::FlagNUW);
5815 Flags = setFlags(Flags, SCEV::FlagNSW);
5817 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5818 // If the increment is an inbounds GEP, then we know the address
5819 // space cannot be wrapped around. We cannot make any guarantee
5820 // about signed or unsigned overflow because pointers are
5821 // unsigned but we may have a negative index from the base
5822 // pointer. We can guarantee that no unsigned wrap occurs if the
5823 // indices form a positive value.
5824 if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
5825 Flags = setFlags(Flags, SCEV::FlagNW);
5826 if (isKnownPositive(Accum))
5827 Flags = setFlags(Flags, SCEV::FlagNUW);
5830 // We cannot transfer nuw and nsw flags from subtraction
5831 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5835 const SCEV *StartVal = getSCEV(StartValueV);
5836 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5838 // Okay, for the entire analysis of this edge we assumed the PHI
5839 // to be symbolic. We now need to go back and purge all of the
5840 // entries for the scalars that use the symbolic expression.
5841 forgetMemoizedResults(SymbolicName);
5842 insertValueToMap(PN, PHISCEV);
5844 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5845 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5846 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5847 proveNoWrapViaConstantRanges(AR)));
5850 // We can add Flags to the post-inc expression only if we
5851 // know that it is *undefined behavior* for BEValueV to
5852 // overflow.
5853 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5854 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5855 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5861 // Otherwise, this could be a loop like this:
5862 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5863 // In this case, j = {1,+,1} and BEValue is j.
5864 // Because the other in-value of i (0) fits the evolution of BEValue,
5865 // i really is an addrec evolution.
5867 // We can generalize this by saying that i is the shifted value of BEValue
5868 // by one iteration:
5869 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5870 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5871 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5872 if (Shifted != getCouldNotCompute() &&
5873 Start != getCouldNotCompute()) {
5874 const SCEV *StartVal = getSCEV(StartValueV);
5875 if (Start == StartVal) {
5876 // Okay, for the entire analysis of this edge we assumed the PHI
5877 // to be symbolic. We now need to go back and purge all of the
5878 // entries for the scalars that use the symbolic expression.
5879 forgetMemoizedResults(SymbolicName);
5880 insertValueToMap(PN, Shifted);
5886 // Remove the temporary PHI node SCEV that has been inserted while intending
5887 // to create an AddRecExpr for this PHI node. We cannot keep this temporary
5888 // as it would prevent later (possibly simpler) SCEV expressions from being added
5889 // to the ValueExprMap.
5890 eraseValueFromMap(PN);
5895 // Try to match a control flow sequence that branches out at BI and merges back
5896 // at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5897 // match.
5898 static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5899 Value *&C, Value *&LHS, Value *&RHS) {
5900 C = BI->getCondition();
5902 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5903 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5905 if (!LeftEdge.isSingleEdge())
5908 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5910 Use &LeftUse = Merge->getOperandUse(0);
5911 Use &RightUse = Merge->getOperandUse(1);
5913 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5919 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5928 const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
5930 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
5931 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
5934 // br %cond, label %left, label %right
5940 // V = phi [ %x, %left ], [ %y, %right ]
5942 // as "select %cond, %x, %y"
5944 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
5945 assert(IDom && "At least the entry block should dominate PN");
5947 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
5948 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
5950 if (BI && BI->isConditional() &&
5951 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
5952 properlyDominates(getSCEV(LHS), PN->getParent()) &&
5953 properlyDominates(getSCEV(RHS), PN->getParent()))
5954 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
5960 const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
5961 if (const SCEV *S = createAddRecFromPHI(PN))
5964 if (Value *V = simplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC}))
5967 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
5970 // If it's not a loop phi, we can't handle it yet.
5971 return getUnknown(PN);
5974 bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
5975 SCEVTypes RootKind) {
5976 struct FindClosure {
5977 const SCEV *OperandToFind;
5978 const SCEVTypes RootKind; // Must be a sequential min/max expression.
5979 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
5983 bool canRecurseInto(SCEVTypes Kind) const {
5984 // We can only recurse into the SCEV expression of the same effective type
5985 // as the type of our root SCEV expression, and into zero-extensions.
5986 return RootKind == Kind || NonSequentialRootKind == Kind ||
5987 scZeroExtend == Kind;
5990 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
5991 : OperandToFind(OperandToFind), RootKind(RootKind),
5992 NonSequentialRootKind(
5993 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
5996 bool follow(const SCEV *S) {
5997 Found = S == OperandToFind;
5999 return !isDone() && canRecurseInto(S->getSCEVType());
6002 bool isDone() const { return Found; }
6005 FindClosure FC(OperandToFind, RootKind);
6010 std::optional<const SCEV *>
6011 ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6015 // Try to match some simple smax or umax patterns.
6018 Value *LHS = ICI->getOperand(0);
6019 Value *RHS = ICI->getOperand(1);
6021 switch (ICI->getPredicate()) {
6022 case ICmpInst::ICMP_SLT:
6023 case ICmpInst::ICMP_SLE:
6024 case ICmpInst::ICMP_ULT:
6025 case ICmpInst::ICMP_ULE:
6026 std::swap(LHS, RHS);
6028 case ICmpInst::ICMP_SGT:
6029 case ICmpInst::ICMP_SGE:
6030 case ICmpInst::ICMP_UGT:
6031 case ICmpInst::ICMP_UGE:
6032 // a > b ? a+x : b+x -> max(a, b)+x
6033 // a > b ? b+x : a+x -> min(a, b)+x
6034 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) {
6035 bool Signed = ICI->isSigned();
6036 const SCEV *LA = getSCEV(TrueVal);
6037 const SCEV *RA = getSCEV(FalseVal);
6038 const SCEV *LS = getSCEV(LHS);
6039 const SCEV *RS = getSCEV(RHS);
6040 if (LA->getType()->isPointerTy()) {
6041 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6042 // Need to make sure we can't produce weird expressions involving
6043 // negated pointers.
6044 if (LA == LS && RA == RS)
6045 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6046 if (LA == RS && RA == LS)
6047 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6049 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6050 if (Op->getType()->isPointerTy()) {
6051 Op = getLosslessPtrToIntExpr(Op);
6052 if (isa<SCEVCouldNotCompute>(Op))
6056 Op = getNoopOrSignExtend(Op, Ty);
6058 Op = getNoopOrZeroExtend(Op, Ty);
6061 LS = CoerceOperand(LS);
6062 RS = CoerceOperand(RS);
6063 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6065 const SCEV *LDiff = getMinusSCEV(LA, LS);
6066 const SCEV *RDiff = getMinusSCEV(RA, RS);
6068 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6070 LDiff = getMinusSCEV(LA, RS);
6071 RDiff = getMinusSCEV(RA, LS);
6073 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6077 case ICmpInst::ICMP_NE:
6078 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6079 std::swap(TrueVal, FalseVal);
6081 case ICmpInst::ICMP_EQ:
6082 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6083 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6084 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6085 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6086 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6087 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6088 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6089 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6090 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6091 return getAddExpr(getUMaxExpr(X, C), Y);
6093 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6094 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6095 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6096 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6097 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6098 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6099 const SCEV *X = getSCEV(LHS);
6100 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6101 X = ZExt->getOperand();
6102 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6103 const SCEV *FalseValExpr = getSCEV(FalseVal);
6104 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6105 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6106 /*Sequential=*/true);
6114 return std::nullopt;
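// Illustrative examples of the matches above (a sketch):
//
//   %c = icmp sgt i32 %a, %b
//   select i1 %c, i32 %a, i32 %b  ==> smax(%a, %b)
//   select i1 %c, i32 %b, i32 %a  ==> smin(%a, %b)
//
//   %z = icmp eq i32 %x, 0
//   select i1 %z, i32 1, i32 %x   ==> umax(%x, 1)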
6117 static std::optional<const SCEV *>
6118 createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6119 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6120 assert(CondExpr->getType()->isIntegerTy(1) &&
6121 TrueExpr->getType() == FalseExpr->getType() &&
6122 TrueExpr->getType()->isIntegerTy(1) &&
6123 "Unexpected operands of a select.");
6125 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6126 // --> C + (umin_seq cond, x - C)
6128 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6129 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6130 // --> C + (umin_seq ~cond, x - C)
6132 // FIXME: while we can't legally model the case where both of the hands
6133 // are fully variable, we only require that the *difference* is constant.
6134 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6135 return std::nullopt;
6138 if (isa<SCEVConstant>(TrueExpr)) {
6139 CondExpr = SE->getNotSCEV(CondExpr);
6146 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6147 /*Sequential=*/true));
6150 static std::optional<const SCEV *>
6151 createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal,
6153 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6154 return std::nullopt;
6156 const auto *SECond = SE->getSCEV(Cond);
6157 const auto *SETrue = SE->getSCEV(TrueVal);
6158 const auto *SEFalse = SE->getSCEV(FalseVal);
6159 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
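// Illustrative example (a sketch): for i1 operands,
//
//   select i1 %c, i1 %x, i1 false  ==> umin_seq(%c, %x)
//
// i.e. the poison-safe form of "and": when %c is 0 the result is 0 no matter
// what %x evaluates to.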
6162 const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6163 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6164 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6165 assert(TrueVal->getType() == FalseVal->getType() &&
6166 V->getType() == TrueVal->getType() &&
6167 "Types of select hands and of the result must match.");
6169 // For now, only deal with i1-typed `select`s.
6170 if (!V->getType()->isIntegerTy(1))
6171 return getUnknown(V);
6173 if (std::optional<const SCEV *> S =
6174 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6177 return getUnknown(V);
6180 const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6183 // Handle "constant" branch or select. This can occur for instance when a
6184 // loop pass transforms an inner loop and moves on to process the outer loop.
6185 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6186 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6188 if (auto *I = dyn_cast<Instruction>(V)) {
6189 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6190 if (std::optional<const SCEV *> S =
6191 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6197 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6200 /// Expand GEP instructions into add and multiply operations. This allows them
6201 /// to be analyzed by regular SCEV code.
6202 const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6203 assert(GEP->getSourceElementType()->isSized() &&
6204 "GEP source element type must be sized");
6206 SmallVector<const SCEV *, 4> IndexExprs;
6207 for (Value *Index : GEP->indices())
6208 IndexExprs.push_back(getSCEV(Index));
6209 return getGEPExpr(GEP, IndexExprs);
6212 APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6213 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6214 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6215 return TrailingZeros >= BitWidth
6216 ? APInt::getZero(BitWidth)
6217 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6219 auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6220 // The result is GCD of all operands results.
6221 APInt Res = getConstantMultiple(N->getOperand(0));
6222 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6223 Res = APIntOps::GreatestCommonDivisor(
6224 Res, getConstantMultiple(N->getOperand(I)));
6228 switch (S->getSCEVType()) {
6230 return cast<SCEVConstant>(S)->getAPInt();
6232 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6235 return APInt(BitWidth, 1);
6237 // Only multiples that are a power of 2 will hold after truncation.
6238 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6239 uint32_t TZ = getMinTrailingZeros(T->getOperand());
6240 return GetShiftedByZeros(TZ);
6242 case scZeroExtend: {
6243 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6244 return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6246 case scSignExtend: {
6247 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6248 return getConstantMultiple(E->getOperand()).sext(BitWidth);
6251 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6252 if (M->hasNoUnsignedWrap()) {
6253 // The result is the product of all operand results.
6254 APInt Res = getConstantMultiple(M->getOperand(0));
6255 for (const SCEV *Operand : M->operands().drop_front())
6256 Res = Res * getConstantMultiple(Operand);
6260 // If there are no wrap guarantees, find the trailing zeros, which is the
6261 // sum of trailing zeros for all its operands.
6263 for (const SCEV *Operand : M->operands())
6264 TZ += getMinTrailingZeros(Operand);
6265 return GetShiftedByZeros(TZ);
6268 case scAddRecExpr: {
6269 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6270 if (N->hasNoUnsignedWrap())
6271 return GetGCDMultiple(N);
6272 // Find the trailing bits, which is the minimum of its operands.
6273 uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6274 for (const SCEV *Operand : N->operands().drop_front())
6275 TZ = std::min(TZ, getMinTrailingZeros(Operand));
6276 return GetShiftedByZeros(TZ);
6282 case scSequentialUMinExpr:
6283 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6285 // Ask ValueTracking for known bits.
6286 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6288 computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT)
6289 .countMinTrailingZeros();
6290 return GetShiftedByZeros(Known);
6292 case scCouldNotCompute:
6293 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6295 llvm_unreachable("Unknown SCEV kind!");
6298 APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
6299 auto I = ConstantMultipleCache.find(S);
6300 if (I != ConstantMultipleCache.end())
6303 APInt Result = getConstantMultipleImpl(S);
6304 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6305 assert(InsertPair.second && "Should insert a new key");
6306 return InsertPair.first->second;
6309 APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6310 APInt Multiple = getConstantMultiple(S);
6311 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6314 uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
6315 return std::min(getConstantMultiple(S).countTrailingZeros(),
6316 (unsigned)getTypeSizeInBits(S->getType()));
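// Worked example (illustrative): if %x has a constant multiple of 2 and
// S = (4 * %x)<nuw>, then getConstantMultiple(S) is 4 * 2 = 8 and
// getMinTrailingZeros(S) is 3, i.e. S is always a multiple of 8.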
6319 /// Helper method to assign a range to V from metadata present in the IR.
6320 static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6321 if (Instruction *I = dyn_cast<Instruction>(V))
6322 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6323 return getConstantRangeFromMetadata(*MD);
6325 return std::nullopt;
6328 void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6329 SCEV::NoWrapFlags Flags) {
6330 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6331 AddRec->setNoWrapFlags(Flags);
6332 UnsignedRanges.erase(AddRec);
6333 SignedRanges.erase(AddRec);
6334 ConstantMultipleCache.erase(AddRec);
6338 ConstantRange ScalarEvolution::
6339 getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6340 const DataLayout &DL = getDataLayout();
6342 unsigned BitWidth = getTypeSizeInBits(U->getType());
6343 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6345 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6346 // use information about the trip count to improve our available range. Note
6347 // that the trip count independent cases are already handled by known bits.
6348 // WARNING: The definition of recurrence used here is subtly different than
6349 // the one used by AddRec (and thus most of this file). Step is allowed to
6350 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6351 // and other addrecs in the same loop (for non-affine addrecs). The code
6352 // below intentionally handles the case where step is not loop invariant.
6353 auto *P = dyn_cast<PHINode>(U->getValue());
6357 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6358 // even the values that are not available in these blocks may come from them,
6359 // and this leads to a false-positive recurrence test.
6360 for (auto *Pred : predecessors(P->getParent()))
6361 if (!DT.isReachableFromEntry(Pred))
6365 Value *Start, *Step;
6366 if (!matchSimpleRecurrence(P, BO, Start, Step))
6369 // If we found a recurrence in reachable code, we must be in a loop. Note
6370 // that BO might be in some subloop of L, and that's completely okay.
6371 auto *L = LI.getLoopFor(P->getParent());
6372 assert(L && L->getHeader() == P->getParent());
6373 if (!L->contains(BO->getParent()))
6374 // NOTE: This bailout should be an assert instead. However, asserting
6375 // the condition here exposes a case where LoopFusion is querying SCEV
6376 // with malformed loop information during the midst of the transform.
6377 // There doesn't appear to be an obvious fix, so for the moment bailout
6378 // until the caller issue can be fixed. PR49566 tracks the bug.
6381 // TODO: Extend to other opcodes such as mul and div.
6382 switch (BO->getOpcode()) {
6385 case Instruction::AShr:
6386 case Instruction::LShr:
6387 case Instruction::Shl:
6391 if (BO->getOperand(0) != P)
6392 // TODO: Handle the power function forms some day.
6395 unsigned TC = getSmallConstantMaxTripCount(L);
6396 if (!TC || TC >= BitWidth)
6399 auto KnownStart = computeKnownBits(Start, DL, 0, &AC, nullptr, &DT);
6400 auto KnownStep = computeKnownBits(Step, DL, 0, &AC, nullptr, &DT);
6401 assert(KnownStart.getBitWidth() == BitWidth &&
6402 KnownStep.getBitWidth() == BitWidth);
6404 // Compute total shift amount, being careful of overflow and bitwidths.
6405 auto MaxShiftAmt = KnownStep.getMaxValue();
6406 APInt TCAP(BitWidth, TC-1);
6407 bool Overflow = false;
6408 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6412 switch (BO->getOpcode()) {
6414 llvm_unreachable("filtered out above");
6415 case Instruction::AShr: {
6416 // For each ashr, three cases:
6417 // shift = 0 => unchanged value
6418 // saturation => 0 or -1
6419 // other => a value closer to zero (of the same sign)
6420 // Thus, the end value is closer to zero than the start.
6421 auto KnownEnd = KnownBits::ashr(KnownStart,
6422 KnownBits::makeConstant(TotalShift));
6423 if (KnownStart.isNonNegative())
6424 // Analogous to lshr (simply not yet canonicalized)
6425 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6426 KnownStart.getMaxValue() + 1);
6427 if (KnownStart.isNegative())
6428 // End >=u Start && End <=s Start
6429 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6430 KnownEnd.getMaxValue() + 1);
6433 case Instruction::LShr: {
6434 // For each lshr, three cases:
6435 // shift = 0 => unchanged value
6436 // saturation => 0
6437 // other => a smaller positive number
6438 // Thus, the low end of the unsigned range is the last value produced.
6439 auto KnownEnd = KnownBits::lshr(KnownStart,
6440 KnownBits::makeConstant(TotalShift));
6441 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6442 KnownStart.getMaxValue() + 1);
6444 case Instruction::Shl: {
6445 // Iff no bits are shifted out, value increases on every shift.
6446 auto KnownEnd = KnownBits::shl(KnownStart,
6447 KnownBits::makeConstant(TotalShift));
6448 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6449 return ConstantRange(KnownStart.getMinValue(),
6450 KnownEnd.getMaxValue() + 1);
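// Worked example (illustrative): for an i32 lshr recurrence with Start known
// to be exactly 64, Step known to be exactly 1, and a maximum trip count of
// 4, TotalShift = 1 * (4 - 1) = 3; the last value is 64 >> 3 = 8, so the
// computed unsigned range is [8, 65).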
const ConstantRange &
ScalarEvolution::getRangeRefIter(const SCEV *S,
                                 ScalarEvolution::RangeSignHint SignHint) {
  DenseMap<const SCEV *, ConstantRange> &Cache =
      SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
                                                       : SignedRanges;
  SmallVector<const SCEV *> WorkList;
  SmallPtrSet<const SCEV *, 8> Seen;

  // Add Expr to the worklist, if Expr is either an N-ary expression or a
  // SCEVUnknown PHI node.
  auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
    if (!Seen.insert(Expr).second)
      return;
    if (Cache.contains(Expr))
      return;
    switch (Expr->getSCEVType()) {
    case scUnknown:
      if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
        break;
      [[fallthrough]];
    case scConstant:
    case scVScale:
    case scTruncate:
    case scZeroExtend:
    case scSignExtend:
    case scPtrToInt:
    case scAddExpr:
    case scMulExpr:
    case scUDivExpr:
    case scAddRecExpr:
    case scUMaxExpr:
    case scSMaxExpr:
    case scUMinExpr:
    case scSMinExpr:
    case scSequentialUMinExpr:
      WorkList.push_back(Expr);
      break;
    case scCouldNotCompute:
      llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
    }
  };
  AddToWorklist(S);

  // Build worklist by queuing operands of N-ary expressions and phi nodes.
  for (unsigned I = 0; I != WorkList.size(); ++I) {
    const SCEV *P = WorkList[I];
    auto *UnknownS = dyn_cast<SCEVUnknown>(P);
    // If it is not a `SCEVUnknown`, just recurse into operands.
    if (!UnknownS) {
      for (const SCEV *Op : P->operands())
        AddToWorklist(Op);
      continue;
    }
    // `SCEVUnknown`'s require special treatment.
    if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
      if (!PendingPhiRangesIter.insert(P).second)
        continue;
      for (auto &Op : reverse(P->operands()))
        AddToWorklist(getSCEV(Op));
    }
  }

  if (!WorkList.empty()) {
    // Use getRangeRef to compute ranges for items in the worklist in reverse
    // order. This will force ranges for earlier operands to be computed before
    // their users in most cases.
    for (const SCEV *P :
         reverse(make_range(WorkList.begin() + 1, WorkList.end()))) {
      getRangeRef(P, SignHint);

      if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
        if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
          PendingPhiRangesIter.erase(P);
    }
  }

  return getRangeRef(S, SignHint, 0);
}
/// Determine the range for a particular SCEV. If SignHint is
/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
/// with a "cleaner" unsigned (resp. signed) representation.
const ConstantRange &ScalarEvolution::getRangeRef(
    const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
  DenseMap<const SCEV *, ConstantRange> &Cache =
      SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
                                                       : SignedRanges;
  ConstantRange::PreferredRangeType RangeType =
      SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
                                                       : ConstantRange::Signed;

  // See if we've computed this range already.
  DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
  if (I != Cache.end())
    return I->second;

  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
    return setRange(C, SignHint, ConstantRange(C->getAPInt()));

  // Switch to iteratively computing the range for S, if it is part of a deeply
  // nested expression.
  if (Depth > RangeIterThreshold)
    return getRangeRefIter(S, SignHint);

  unsigned BitWidth = getTypeSizeInBits(S->getType());
  ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
  using OBO = OverflowingBinaryOperator;

  // If the value has known zeros, the maximum value will have those known
  // zeros as well.
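  // For illustration (i8, not from the source): if S is known to be a multiple
  // of 4, then 255 urem 4 is 3, so the unsigned range is clipped to [0, 253)
  // and the largest achievable value is 252.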
  if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
    APInt Multiple = getNonZeroConstantMultiple(S);
    APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
    if (!Remainder.isZero())
      ConservativeResult =
          ConstantRange(APInt::getMinValue(BitWidth),
                        APInt::getMaxValue(BitWidth) - Remainder + 1);
  } else {
    uint32_t TZ = getMinTrailingZeros(S);
    if (TZ != 0) {
      ConservativeResult = ConstantRange(
          APInt::getSignedMinValue(BitWidth),
          APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
    }
  }
  switch (S->getSCEVType()) {
  case scConstant:
    llvm_unreachable("Already handled above.");
  case scVScale:
    return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
  case scTruncate: {
    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
    ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
    return setRange(
        Trunc, SignHint,
        ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
  }
  case scZeroExtend: {
    const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
    ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
    return setRange(
        ZExt, SignHint,
        ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
  }
  case scSignExtend: {
    const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
    ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
    return setRange(
        SExt, SignHint,
        ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
  }
  case scPtrToInt: {
    const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
    ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
    return setRange(PtrToInt, SignHint, X);
  }
  case scAddExpr: {
    const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
    ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
    unsigned WrapType = OBO::AnyWrap;
    if (Add->hasNoSignedWrap())
      WrapType |= OBO::NoSignedWrap;
    if (Add->hasNoUnsignedWrap())
      WrapType |= OBO::NoUnsignedWrap;
    for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
      X = X.addWithNoWrap(getRangeRef(Add->getOperand(i), SignHint, Depth + 1),
                          WrapType, RangeType);
    return setRange(Add, SignHint,
                    ConservativeResult.intersectWith(X, RangeType));
  }
  case scMulExpr: {
    const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
    ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
    for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
      X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint, Depth + 1));
    return setRange(Mul, SignHint,
                    ConservativeResult.intersectWith(X, RangeType));
  }
  case scUDivExpr: {
    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
    ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
    ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
    return setRange(UDiv, SignHint,
                    ConservativeResult.intersectWith(X.udiv(Y), RangeType));
  }
  case scAddRecExpr: {
    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
    // If there's no unsigned wrap, the value will never be less than its
    // initial value.
    if (AddRec->hasNoUnsignedWrap()) {
      APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
      if (!UnsignedMinValue.isZero())
        ConservativeResult = ConservativeResult.intersectWith(
            ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
    }

    // If there's no signed wrap, and all the operands except the initial value
    // have the same sign or are zero, the value won't ever be:
    // 1: smaller than the initial value if the operands are non-negative,
    // 2: bigger than the initial value if the operands are non-positive.
    // In both cases, the value cannot cross the signed min/max boundary.
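    // For illustration (not from the source): {10,+,1}<nsw> in i8 has a
    // non-negative step, so its value never drops below 10 and, with no signed
    // wrap, never crosses past the signed max back to the signed min.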
    if (AddRec->hasNoSignedWrap()) {
      bool AllNonNeg = true;
      bool AllNonPos = true;
      for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
        if (!isKnownNonNegative(AddRec->getOperand(i)))
          AllNonNeg = false;
        if (!isKnownNonPositive(AddRec->getOperand(i)))
          AllNonPos = false;
      }
      if (AllNonNeg)
        ConservativeResult = ConservativeResult.intersectWith(
            ConstantRange::getNonEmpty(getSignedRangeMin(AddRec->getStart()),
                                       APInt::getSignedMinValue(BitWidth)),
            RangeType);
      else if (AllNonPos)
        ConservativeResult = ConservativeResult.intersectWith(
            ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
                                       getSignedRangeMax(AddRec->getStart()) +
                                           1),
            RangeType);
    }

    // TODO: non-affine addrec
    if (AddRec->isAffine()) {
      const SCEV *MaxBEScev =
          getConstantMaxBackedgeTakenCount(AddRec->getLoop());
      if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
        APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();

        // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
        // MaxBECount's active bits are all <= AddRec's bit width.
        if (MaxBECount.getBitWidth() > BitWidth &&
            MaxBECount.getActiveBits() <= BitWidth)
          MaxBECount = MaxBECount.trunc(BitWidth);
        else if (MaxBECount.getBitWidth() < BitWidth)
          MaxBECount = MaxBECount.zext(BitWidth);

        if (MaxBECount.getBitWidth() == BitWidth) {
          auto RangeFromAffine = getRangeForAffineAR(
              AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
          ConservativeResult =
              ConservativeResult.intersectWith(RangeFromAffine, RangeType);

          auto RangeFromFactoring = getRangeViaFactoring(
              AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
          ConservativeResult =
              ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
        }
      }

      // Now try symbolic BE count and more powerful methods.
      if (UseExpensiveRangeSharpening) {
        const SCEV *SymbolicMaxBECount =
            getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
        if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
            getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
            AddRec->hasNoSelfWrap()) {
          auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
              AddRec, SymbolicMaxBECount, BitWidth, SignHint);
          ConservativeResult =
              ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
        }
      }
    }

    return setRange(AddRec, SignHint, std::move(ConservativeResult));
  }
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
  case scSequentialUMinExpr: {
    Intrinsic::ID ID;
    switch (S->getSCEVType()) {
    case scUMaxExpr:
      ID = Intrinsic::umax;
      break;
    case scSMaxExpr:
      ID = Intrinsic::smax;
      break;
    case scUMinExpr:
    case scSequentialUMinExpr:
      ID = Intrinsic::umin;
      break;
    case scSMinExpr:
      ID = Intrinsic::smin;
      break;
    default:
      llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
    }

    const auto *NAry = cast<SCEVNAryExpr>(S);
    ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
    for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
      X = X.intrinsic(
          ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
    return setRange(S, SignHint,
                    ConservativeResult.intersectWith(X, RangeType));
  }
  case scUnknown: {
    const SCEVUnknown *U = cast<SCEVUnknown>(S);
    Value *V = U->getValue();

    // Check if the IR explicitly contains !range metadata.
    std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
    if (MDRange)
      ConservativeResult =
          ConservativeResult.intersectWith(*MDRange, RangeType);

    // Use facts about recurrences in the underlying IR. Note that add
    // recurrences are AddRecExprs and thus don't hit this path. This
    // primarily handles shift recurrences.
    auto CR = getRangeForUnknownRecurrence(U);
    ConservativeResult = ConservativeResult.intersectWith(CR);

    // See if ValueTracking can give us a useful range.
    const DataLayout &DL = getDataLayout();
    KnownBits Known = computeKnownBits(V, DL, 0, &AC, nullptr, &DT);
    if (Known.getBitWidth() != BitWidth)
      Known = Known.zextOrTrunc(BitWidth);

    // ValueTracking may be able to compute a tighter result for the number of
    // sign bits than for the value of those sign bits.
    unsigned NS = ComputeNumSignBits(V, DL, 0, &AC, nullptr, &DT);
    if (U->getType()->isPointerTy()) {
      // If the pointer size is larger than the index type size, this can cause
      // NS to be larger than BitWidth. So compensate for this.
      unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
      int ptrIdxDiff = ptrSize - BitWidth;
      if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
        NS -= ptrIdxDiff;
    }

    if (NS > 1) {
      // If we know any of the sign bits, we know all of the sign bits.
      if (!Known.Zero.getHiBits(NS).isZero())
        Known.Zero.setHighBits(NS);
      if (!Known.One.getHiBits(NS).isZero())
        Known.One.setHighBits(NS);
    }

    if (Known.getMinValue() != Known.getMaxValue() + 1)
      ConservativeResult = ConservativeResult.intersectWith(
          ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
          RangeType);
    if (NS > 1)
      ConservativeResult = ConservativeResult.intersectWith(
          ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
                        APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
          RangeType);
    if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
      // Strengthen the range if the underlying IR value is a
      // global/alloca/heap allocation using the size of the object.
      ObjectSizeOpts Opts;
      Opts.RoundToAlign = false;
      Opts.NullIsUnknownSize = true;
      uint64_t ObjSize;
      if ((isa<GlobalVariable>(V) || isa<AllocaInst>(V) ||
           isAllocationFn(V, &TLI)) &&
          getObjectSize(V, ObjSize, DL, &TLI, Opts) && ObjSize > 1) {
        // The highest address the object can start at is ObjSize bytes before
        // the end (unsigned max value). If this value is not a multiple of the
        // alignment, the last possible start value is the next lowest multiple
        // of the alignment. Note: The computations below cannot overflow,
        // because if they would there's no possible start address for the
        // object.
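        // For illustration (i64, not from the source): with ObjSize = 16 and
        // 8-byte alignment, the last valid start is (2^64 - 17) rounded down
        // to a multiple of 8, and a nonnull value must start at or above the
        // alignment.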
        APInt MaxVal = APInt::getMaxValue(BitWidth) - APInt(BitWidth, ObjSize);
        uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
        uint64_t Rem = MaxVal.urem(Align);
        MaxVal -= APInt(BitWidth, Rem);
        APInt MinVal = APInt::getZero(BitWidth);
        if (llvm::isKnownNonZero(V, DL))
          MinVal = Align;
        ConservativeResult = ConservativeResult.intersectWith(
            {MinVal, MaxVal + 1}, RangeType);
      }
    }

    // A range of Phi is a subset of the union of the ranges of its inputs.
    if (PHINode *Phi = dyn_cast<PHINode>(V)) {
      // Make sure that we do not run over cycled Phis.
      if (PendingPhiRanges.insert(Phi).second) {
        ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);

        for (const auto &Op : Phi->operands()) {
          auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
          RangeFromOps = RangeFromOps.unionWith(OpRange);
          // No point to continue if we already have a full set.
          if (RangeFromOps.isFullSet())
            break;
        }
        ConservativeResult =
            ConservativeResult.intersectWith(RangeFromOps, RangeType);
        bool Erased = PendingPhiRanges.erase(Phi);
        assert(Erased && "Failed to erase Phi properly?");
        (void)Erased;
      }
    }

    // vscale can't be equal to zero
    if (const auto *II = dyn_cast<IntrinsicInst>(V))
      if (II->getIntrinsicID() == Intrinsic::vscale) {
        ConstantRange Disallowed = APInt::getZero(BitWidth);
        ConservativeResult = ConservativeResult.difference(Disallowed);
      }

    return setRange(U, SignHint, std::move(ConservativeResult));
  }
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }

  return setRange(S, SignHint, std::move(ConservativeResult));
}
// Given a StartRange, Step and MaxBECount for an expression, compute a range
// of values that the expression can take. Initially, the expression has a
// value from StartRange and is then changed by Step up to MaxBECount times.
// The Signed argument defines whether we treat Step as signed or unsigned.
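// For illustration (i8, not from the source): StartRange = [0, 10), Step = 2,
// MaxBECount = 3. The value can grow by at most 2 * 3 = 6, so the resulting
// range is [0, 9 + 6 + 1) = [0, 16).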
static ConstantRange getRangeForAffineARHelper(APInt Step,
                                               const ConstantRange &StartRange,
                                               const APInt &MaxBECount,
                                               bool Signed) {
  unsigned BitWidth = Step.getBitWidth();
  assert(BitWidth == StartRange.getBitWidth() &&
         BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
  // If either Step or MaxBECount is 0, then the expression won't change, and
  // we just need to return the initial range.
  if (Step == 0 || MaxBECount == 0)
    return StartRange;

  // If we don't know anything about the initial value (i.e. StartRange is
  // FullRange), then we don't know anything about the final range either.
  // Return FullRange.
  if (StartRange.isFullSet())
    return ConstantRange::getFull(BitWidth);

  // If Step is signed and negative, then we use its absolute value, but we
  // also note that we're moving in the opposite direction.
  bool Descending = Signed && Step.isNegative();

  if (Signed)
    // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
    // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
    // These equations hold true due to the well-defined wrap-around behavior
    // of APInt.
    Step = Step.abs();

  // Check if Offset is more than the full span of BitWidth. If it is, the
  // expression is guaranteed to overflow.
  if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
    return ConstantRange::getFull(BitWidth);

  // Offset is by how much the expression can change. The checks above
  // guarantee no overflow here.
  APInt Offset = Step * MaxBECount;

  // The minimum value of the final range will match the minimal value of
  // StartRange if the expression is increasing and will be decreased by
  // Offset otherwise. The maximum value of the final range will match the
  // maximal value of StartRange if the expression is decreasing and will be
  // increased by Offset otherwise.
  APInt StartLower = StartRange.getLower();
  APInt StartUpper = StartRange.getUpper() - 1;
  APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
                                   : (StartUpper + std::move(Offset));

  // It's possible that the new minimum/maximum value will fall into the
  // initial range (due to wrap-around). This means that the expression can
  // take any value in this bitwidth, and we have to return the full range.
  if (StartRange.contains(MovedBoundary))
    return ConstantRange::getFull(BitWidth);

  APInt NewLower =
      Descending ? std::move(MovedBoundary) : std::move(StartLower);
  APInt NewUpper =
      Descending ? std::move(StartUpper) : std::move(MovedBoundary);
  NewUpper += 1;

  // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
  return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
}
ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
                                                   const SCEV *Step,
                                                   const APInt &MaxBECount) {
  assert(getTypeSizeInBits(Start->getType()) ==
             getTypeSizeInBits(Step->getType()) &&
         getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
         "mismatched bit widths");

  // First, consider step signed.
  ConstantRange StartSRange = getSignedRange(Start);
  ConstantRange StepSRange = getSignedRange(Step);

  // If Step can be both positive and negative, we need to find ranges for the
  // maximum absolute step values in both directions and union them.
  ConstantRange SR = getRangeForAffineARHelper(
      StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
  SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
                                              StartSRange, MaxBECount,
                                              /* Signed = */ true));

  // Next, consider step unsigned.
  ConstantRange UR = getRangeForAffineARHelper(
      getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
      /* Signed = */ false);

  // Finally, intersect signed and unsigned ranges.
  return SR.intersectWith(UR, ConstantRange::Smallest);
}
ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
    const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
    ScalarEvolution::RangeSignHint SignHint) {
  assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
  assert(AddRec->hasNoSelfWrap() &&
         "This only works for non-self-wrapping AddRecs!");
  const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
  const SCEV *Step = AddRec->getStepRecurrence(*this);
  // Only deal with constant step to save compile time.
  if (!isa<SCEVConstant>(Step))
    return ConstantRange::getFull(BitWidth);
  // Let's make sure that we can prove that we do not self-wrap during
  // MaxBECount iterations. We need this because MaxBECount is a maximum
  // iteration count estimate, and we might infer nw from some exit for which
  // we do not know the max exit count (or any other side reasoning).
  // TODO: Turn into assert at some point.
  if (getTypeSizeInBits(MaxBECount->getType()) >
      getTypeSizeInBits(AddRec->getType()))
    return ConstantRange::getFull(BitWidth);
  MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
  const SCEV *RangeWidth = getMinusOne(AddRec->getType());
  const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
  const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
  if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
                                         MaxItersWithoutWrap))
    return ConstantRange::getFull(BitWidth);

  ICmpInst::Predicate LEPred =
      IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
  ICmpInst::Predicate GEPred =
      IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
  const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);

  // We know that there is no self-wrap. Let's take Start and End values and
  // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
  // the iteration. They either lie inside the range [Min(Start, End),
  // Max(Start, End)] or outside it:
  //
  // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
  // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
  //
  // The no-self-wrap flag guarantees that the intermediate values cannot be
  // BOTH outside and inside the range [Min(Start, End), Max(Start, End)].
  // Using that knowledge, let's try to prove that we are dealing with Case 1.
  // It is so if Start <= End and the step is positive, or Start >= End and
  // the step is negative.
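  // For illustration (not from the source): for {0,+,1} with a max backedge
  // count of 10, Start = 0 and End = 10; the step is positive and 0 <= 10, so
  // every intermediate value lies in [0, 10] and RangeBetween is [0, 11).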
  const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
  ConstantRange StartRange = getRangeRef(Start, SignHint);
  ConstantRange EndRange = getRangeRef(End, SignHint);
  ConstantRange RangeBetween = StartRange.unionWith(EndRange);
  // If they already cover the full iteration space, we will learn nothing
  // useful even if we prove what we want to prove.
  if (RangeBetween.isFullSet())
    return RangeBetween;
  // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
  bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
                               : RangeBetween.isWrappedSet();
  if (IsWrappedSet)
    return ConstantRange::getFull(BitWidth);

  if (isKnownPositive(Step) &&
      isKnownPredicateViaConstantRanges(LEPred, Start, End))
    return RangeBetween;
  if (isKnownNegative(Step) &&
      isKnownPredicateViaConstantRanges(GEPred, Start, End))
    return RangeBetween;
  return ConstantRange::getFull(BitWidth);
}

ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
                                                    const SCEV *Step,
                                                    const APInt &MaxBECount) {
  // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
  //                          == RangeOf({A,+,P}) union RangeOf({B,+,Q})
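  // For illustration (not from the source): {c ? 0 : 10, +, c ? 1 : 2} is
  // handled as the union of the ranges of {0,+,1} and {10,+,2}, since the same
  // condition selects the start and the step consistently.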
  unsigned BitWidth = MaxBECount.getBitWidth();
  assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
         getTypeSizeInBits(Step->getType()) == BitWidth &&
         "mismatched bit widths");

  struct SelectPattern {
    Value *Condition = nullptr;
    APInt TrueValue;
    APInt FalseValue;

    explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
                           const SCEV *S) {
      std::optional<unsigned> CastOp;
      APInt Offset(BitWidth, 0);

      assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
             "Should be!");

      // Peel off a constant offset:
      if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
        // In the future we could consider being smarter here and handle
        // {Start+Step,+,Step} too.
        if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
          return;

        Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
        S = SA->getOperand(1);
      }

      // Peel off a cast operation
      if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
        CastOp = SCast->getSCEVType();
        S = SCast->getOperand();
      }

      using namespace llvm::PatternMatch;

      auto *SU = dyn_cast<SCEVUnknown>(S);
      const APInt *TrueVal, *FalseVal;
      if (!SU ||
          !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
                                          m_APInt(FalseVal)))) {
        Condition = nullptr;
        return;
      }

      TrueValue = *TrueVal;
      FalseValue = *FalseVal;

      // Re-apply the cast we peeled off earlier
      if (CastOp)
        switch (*CastOp) {
        default:
          llvm_unreachable("Unknown SCEV cast type!");

        case scTruncate:
          TrueValue = TrueValue.trunc(BitWidth);
          FalseValue = FalseValue.trunc(BitWidth);
          break;
        case scZeroExtend:
          TrueValue = TrueValue.zext(BitWidth);
          FalseValue = FalseValue.zext(BitWidth);
          break;
        case scSignExtend:
          TrueValue = TrueValue.sext(BitWidth);
          FalseValue = FalseValue.sext(BitWidth);
          break;
        }

      // Re-apply the constant offset we peeled off earlier
      TrueValue += Offset;
      FalseValue += Offset;
    }

    bool isRecognized() { return Condition != nullptr; }
  };

  SelectPattern StartPattern(*this, BitWidth, Start);
  if (!StartPattern.isRecognized())
    return ConstantRange::getFull(BitWidth);

  SelectPattern StepPattern(*this, BitWidth, Step);
  if (!StepPattern.isRecognized())
    return ConstantRange::getFull(BitWidth);

  if (StartPattern.Condition != StepPattern.Condition) {
    // We don't handle this case today; but we could, by considering four
    // possibilities below instead of two. I'm not sure if there are cases
    // where that will help over what getRange already does, though.
    return ConstantRange::getFull(BitWidth);
  }

  // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
  // construct arbitrary general SCEV expressions here. This function is called
  // from deep in the call stack, and calling getSCEV (on a sext instruction,
  // say) can end up caching a suboptimal value.

  // FIXME: without the explicit `this` receiver below, MSVC errors out with
  // C2352 and C2512 (otherwise it isn't needed).

  const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
  const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
  const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
  const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);

  ConstantRange TrueRange =
      this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
  ConstantRange FalseRange =
      this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);

  return TrueRange.unionWith(FalseRange);
}
SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
  if (isa<ConstantExpr>(V))
    return SCEV::FlagAnyWrap;
  const BinaryOperator *BinOp = cast<BinaryOperator>(V);

  // Return early if there are no flags to propagate to the SCEV.
  SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
  if (BinOp->hasNoUnsignedWrap())
    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
  if (BinOp->hasNoSignedWrap())
    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
  if (Flags == SCEV::FlagAnyWrap)
    return SCEV::FlagAnyWrap;

  return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
}

const Instruction *
ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
  if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
    return &*AddRec->getLoop()->getHeader()->begin();
  if (auto *U = dyn_cast<SCEVUnknown>(S))
    if (auto *I = dyn_cast<Instruction>(U->getValue()))
      return I;
  return nullptr;
}

const Instruction *
ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
                                       bool &Precise) {
  Precise = true;
  // Do a bounded search of the def relation of the requested SCEVs.
  SmallSet<const SCEV *, 16> Visited;
  SmallVector<const SCEV *> Worklist;
  auto pushOp = [&](const SCEV *S) {
    if (!Visited.insert(S).second)
      return;
    // Threshold of 30 here is arbitrary.
    if (Visited.size() > 30) {
      Precise = false;
      return;
    }
    Worklist.push_back(S);
  };

  for (const auto *S : Ops)
    pushOp(S);

  const Instruction *Bound = nullptr;
  while (!Worklist.empty()) {
    auto *S = Worklist.pop_back_val();
    if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
      if (!Bound || DT.dominates(Bound, DefI))
        Bound = DefI;
    } else {
      for (const auto *Op : S->operands())
        pushOp(Op);
    }
  }
  return Bound ? Bound : &*F.getEntryBlock().begin();
}

const Instruction *
ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
  bool Discard;
  return getDefiningScopeBound(Ops, Discard);
}
bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
                                                        const Instruction *B) {
  if (A->getParent() == B->getParent() &&
      isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
                                                 B->getIterator()))
    return true;

  auto *BLoop = LI.getLoopFor(B->getParent());
  if (BLoop && BLoop->getHeader() == B->getParent() &&
      BLoop->getLoopPreheader() == A->getParent() &&
      isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
                                                 A->getParent()->end()) &&
      isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
                                                 B->getIterator()))
    return true;
  return false;
}

bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
  // Only proceed if we can prove that I does not yield poison.
  if (!programUndefinedIfPoison(I))
    return false;

  // At this point we know that if I is executed, then it does not wrap
  // according to at least one of NSW or NUW. If I is not executed, then we do
  // not know if the calculation that I represents would wrap. Multiple
  // instructions can map to the same SCEV. If we apply NSW or NUW from I to
  // the SCEV, we must guarantee no wrapping for that SCEV also when it is
  // derived from other instructions that map to the same SCEV. We cannot make
  // that guarantee for cases where I is not executed. So we need to find an
  // upper bound on the defining scope for the SCEV, and prove that I is
  // executed every time we enter that scope. When the bounding scope is a
  // loop (the common case), this is equivalent to proving that I executes on
  // every iteration of that loop.
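  // For illustration (not from the source): two syntactically distinct adds of
  // the same operands share one SCEV; if only the nsw-flagged one executes
  // conditionally, its flag may be transferred only once we prove it runs
  // whenever the SCEV's defining scope is entered.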
  SmallVector<const SCEV *> SCEVOps;
  for (const Use &Op : I->operands()) {
    // I could be an extractvalue from a call to an overflow intrinsic.
    // TODO: We can do better here in some cases.
    if (isSCEVable(Op->getType()))
      SCEVOps.push_back(getSCEV(Op));
  }
  auto *DefI = getDefiningScopeBound(SCEVOps);
  return isGuaranteedToTransferExecutionTo(DefI, I);
}

bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
  // If we know that \c I can never be poison period, then that's enough.
  if (isSCEVExprNeverPoison(I))
    return true;

  // If the loop only has one exit, then we know that, if the loop is entered,
  // any instruction dominating that exit will be executed. If any such
  // instruction would result in UB, the addrec cannot be poison.
  //
  // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
  // also handles uses outside the loop header (they just need to dominate the
  // exit block).
  auto *ExitingBB = L->getExitingBlock();
  if (!ExitingBB || !loopHasNoAbnormalExits(L))
    return false;

  SmallPtrSet<const Value *, 16> KnownPoison;
  SmallVector<const Instruction *, 8> Worklist;

  // We start by assuming \c I, the post-inc add recurrence, is poison. Only
  // things that are known to be poison under that assumption go on the
  // Worklist.
  KnownPoison.insert(I);
  Worklist.push_back(I);

  while (!Worklist.empty()) {
    const Instruction *Poison = Worklist.pop_back_val();

    for (const Use &U : Poison->uses()) {
      const Instruction *PoisonUser = cast<Instruction>(U.getUser());
      if (mustTriggerUB(PoisonUser, KnownPoison) &&
          DT.dominates(PoisonUser->getParent(), ExitingBB))
        return true;

      if (propagatesPoison(U) && L->contains(PoisonUser))
        if (KnownPoison.insert(PoisonUser).second)
          Worklist.push_back(PoisonUser);
    }
  }

  return false;
}

ScalarEvolution::LoopProperties
ScalarEvolution::getLoopProperties(const Loop *L) {
  using LoopProperties = ScalarEvolution::LoopProperties;

  auto Itr = LoopPropertiesCache.find(L);
  if (Itr == LoopPropertiesCache.end()) {
    auto HasSideEffects = [](Instruction *I) {
      if (auto *SI = dyn_cast<StoreInst>(I))
        return !SI->isSimple();

      return I->mayThrow() || I->mayWriteToMemory();
    };

    LoopProperties LP = {/* HasNoAbnormalExits */ true,
                         /*HasNoSideEffects*/ true};

    for (auto *BB : L->getBlocks())
      for (auto &I : *BB) {
        if (!isGuaranteedToTransferExecutionToSuccessor(&I))
          LP.HasNoAbnormalExits = false;
        if (HasSideEffects(&I))
          LP.HasNoSideEffects = false;
        if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
          break; // We're already as pessimistic as we can get.
      }

    auto InsertPair = LoopPropertiesCache.insert({L, LP});
    assert(InsertPair.second && "We just checked!");
    Itr = InsertPair.first;
  }

  return Itr->second;
}

bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
  // A mustprogress loop without side effects must be finite.
  // TODO: The check used here is very conservative. It's only *specific*
  // side effects which are well defined in infinite loops.
  return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
}
const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
  // Worklist item with a Value and a bool indicating whether all operands have
  // been visited already.
  using PointerTy = PointerIntPair<Value *, 1, bool>;
  SmallVector<PointerTy> Stack;

  Stack.emplace_back(V, true);
  Stack.emplace_back(V, false);
  while (!Stack.empty()) {
    auto E = Stack.pop_back_val();
    Value *CurV = E.getPointer();

    if (getExistingSCEV(CurV))
      continue;

    SmallVector<Value *> Ops;
    const SCEV *CreatedSCEV = nullptr;
    // If all operands have been visited already, create the SCEV.
    if (E.getInt())
      CreatedSCEV = createSCEV(CurV);
    else
      // Otherwise get the operands we need to create SCEV's for before
      // creating the SCEV for CurV. If the SCEV for CurV can be constructed
      // trivially, avoid queuing it and directly create the SCEV.
      CreatedSCEV = getOperandsToCreate(CurV, Ops);

    if (CreatedSCEV) {
      insertValueToMap(CurV, CreatedSCEV);
    } else {
      // Queue CurV for SCEV creation, followed by its operands, which need to
      // be constructed first.
      Stack.emplace_back(CurV, true);
      for (Value *Op : Ops)
        Stack.emplace_back(Op, false);
    }
  }

  return getExistingSCEV(V);
}

const SCEV *
ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
  if (!isSCEVable(V->getType()))
    return getUnknown(V);

  if (Instruction *I = dyn_cast<Instruction>(V)) {
    // Don't attempt to analyze instructions in blocks that aren't
    // reachable. Such instructions don't matter, and they aren't required
    // to obey basic rules for definitions dominating uses which this
    // analysis depends on.
    if (!DT.isReachableFromEntry(I->getParent()))
      return getUnknown(PoisonValue::get(V->getType()));
  } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
    return getConstant(CI);
  else if (isa<GlobalAlias>(V))
    return getUnknown(V);
  else if (!isa<ConstantExpr>(V))
    return getUnknown(V);

  Operator *U = cast<Operator>(V);
  if (auto BO =
          MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
    bool IsConstArg = isa<ConstantInt>(BO->RHS);
    switch (BO->Opcode) {
    case Instruction::Add:
    case Instruction::Mul: {
      // For additions and multiplications, traverse add/mul chains for which
      // we can potentially create a single SCEV, to reduce the number of
      // get{Add,Mul}Expr calls.
      do {
        if (BO->Op) {
          if (BO->Op != V && getExistingSCEV(BO->Op)) {
            Ops.push_back(BO->Op);
            break;
          }
        }
        Ops.push_back(BO->RHS);
        auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
                                   dyn_cast<Instruction>(V));
        if (!NewBO ||
            (BO->Opcode == Instruction::Add &&
             (NewBO->Opcode != Instruction::Add &&
              NewBO->Opcode != Instruction::Sub)) ||
            (BO->Opcode == Instruction::Mul &&
             NewBO->Opcode != Instruction::Mul)) {
          Ops.push_back(BO->LHS);
          break;
        }
        // CreateSCEV calls getNoWrapFlagsFromUB, which under certain
        // conditions requires a SCEV for the LHS.
        if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
          auto *I = dyn_cast<Instruction>(BO->Op);
          if (I && programUndefinedIfPoison(I)) {
            Ops.push_back(BO->LHS);
            break;
          }
        }
        BO = *NewBO;
      } while (true);
      return nullptr;
    }
    case Instruction::Sub:
    case Instruction::UDiv:
    case Instruction::URem:
      break;
    case Instruction::AShr:
    case Instruction::Shl:
    case Instruction::Xor:
      if (!IsConstArg)
        return getUnknown(V);
      break;
    case Instruction::And:
    case Instruction::Or:
      if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
        return getUnknown(V);
      break;
    case Instruction::LShr:
      return getUnknown(V);
    default:
      llvm_unreachable("Unhandled binop");
      break;
    }

    Ops.push_back(BO->LHS);
    Ops.push_back(BO->RHS);
    return nullptr;
  }

  switch (U->getOpcode()) {
  case Instruction::Trunc:
  case Instruction::ZExt:
  case Instruction::SExt:
  case Instruction::PtrToInt:
    Ops.push_back(U->getOperand(0));
    break;

  case Instruction::BitCast:
    if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
      Ops.push_back(U->getOperand(0));
      break;
    }
    return getUnknown(V);

  case Instruction::SDiv:
  case Instruction::SRem:
    Ops.push_back(U->getOperand(0));
    Ops.push_back(U->getOperand(1));
    break;

  case Instruction::GetElementPtr:
    assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
           "GEP source element type must be sized");
    for (Value *Index : U->operands())
      Ops.push_back(Index);
    break;

  case Instruction::IntToPtr:
    return getUnknown(V);

  case Instruction::PHI:
    // Keep constructing SCEVs for phis recursively for now.
    return nullptr;

  case Instruction::Select: {
    // Check if U is a select that can be simplified to a SCEVUnknown.
    auto CanSimplifyToUnknown = [this, U]() {
      if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
        return false;

      auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
      if (!ICI)
        return false;
      Value *LHS = ICI->getOperand(0);
      Value *RHS = ICI->getOperand(1);
      if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
          ICI->getPredicate() == CmpInst::ICMP_NE) {
        if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
          return true;
      } else if (getTypeSizeInBits(LHS->getType()) >
                 getTypeSizeInBits(U->getType()))
        return true;
      return false;
    };
    if (CanSimplifyToUnknown())
      return getUnknown(U);

    for (Value *Inc : U->operands())
      Ops.push_back(Inc);
    break;
  }
  case Instruction::Call:
  case Instruction::Invoke:
    if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
      Ops.push_back(RV);
      break;
    }

    if (auto *II = dyn_cast<IntrinsicInst>(U)) {
      switch (II->getIntrinsicID()) {
      case Intrinsic::abs:
        Ops.push_back(II->getArgOperand(0));
        break;
      case Intrinsic::umax:
      case Intrinsic::umin:
      case Intrinsic::smax:
      case Intrinsic::smin:
      case Intrinsic::usub_sat:
      case Intrinsic::uadd_sat:
        Ops.push_back(II->getArgOperand(0));
        Ops.push_back(II->getArgOperand(1));
        break;
      case Intrinsic::start_loop_iterations:
      case Intrinsic::annotation:
      case Intrinsic::ptr_annotation:
        Ops.push_back(II->getArgOperand(0));
        break;
      default:
        break;
      }
    }
    break;
  }

  return nullptr;
}
const SCEV *ScalarEvolution::createSCEV(Value *V) {
  if (!isSCEVable(V->getType()))
    return getUnknown(V);

  if (Instruction *I = dyn_cast<Instruction>(V)) {
    // Don't attempt to analyze instructions in blocks that aren't
    // reachable. Such instructions don't matter, and they aren't required
    // to obey basic rules for definitions dominating uses which this
    // analysis depends on.
    if (!DT.isReachableFromEntry(I->getParent()))
      return getUnknown(PoisonValue::get(V->getType()));
  } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
    return getConstant(CI);
  else if (isa<GlobalAlias>(V))
    return getUnknown(V);
  else if (!isa<ConstantExpr>(V))
    return getUnknown(V);

  const SCEV *LHS;
  const SCEV *RHS;

  Operator *U = cast<Operator>(V);
  if (auto BO =
          MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
    switch (BO->Opcode) {
    case Instruction::Add: {
      // The simple thing to do would be to just call getSCEV on both operands
      // and call getAddExpr with the result. However if we're looking at a
      // bunch of things all added together, this can be quite inefficient,
      // because it leads to N-1 getAddExpr calls for N ultimate operands.
      // Instead, gather up all the operands and make a single getAddExpr call.
      // LLVM IR canonical form means we need only traverse the left operands.
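      // For illustration (not from the source): for x = ((a + b) + c) + d, the
      // loop below walks the left-hand chain and collects a, b, c and d, so
      // one getAddExpr call builds the SCEV instead of three nested ones.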
      SmallVector<const SCEV *, 4> AddOps;
      do {
        if (BO->Op) {
          if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
            AddOps.push_back(OpSCEV);
            break;
          }

          // If a NUW or NSW flag can be applied to the SCEV for this
          // addition, then compute the SCEV for this addition by itself
          // with a separate call to getAddExpr. We need to do that
          // instead of pushing the operands of the addition onto AddOps,
          // since the flags are only known to apply to this particular
          // addition - they may not apply to other additions that can be
          // formed with operands from AddOps.
          const SCEV *RHS = getSCEV(BO->RHS);
          SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
          if (Flags != SCEV::FlagAnyWrap) {
            const SCEV *LHS = getSCEV(BO->LHS);
            if (BO->Opcode == Instruction::Sub)
              AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
            else
              AddOps.push_back(getAddExpr(LHS, RHS, Flags));
            break;
          }
        }

        if (BO->Opcode == Instruction::Sub)
          AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
        else
          AddOps.push_back(getSCEV(BO->RHS));

        auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
                                   dyn_cast<Instruction>(V));
        if (!NewBO || (NewBO->Opcode != Instruction::Add &&
                       NewBO->Opcode != Instruction::Sub)) {
          AddOps.push_back(getSCEV(BO->LHS));
          break;
        }
        BO = *NewBO;
      } while (true);

      return getAddExpr(AddOps);
    }
    case Instruction::Mul: {
      SmallVector<const SCEV *, 4> MulOps;
      do {
        if (BO->Op) {
          if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
            MulOps.push_back(OpSCEV);
            break;
          }

          SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
          if (Flags != SCEV::FlagAnyWrap) {
            LHS = getSCEV(BO->LHS);
            RHS = getSCEV(BO->RHS);
            MulOps.push_back(getMulExpr(LHS, RHS, Flags));
            break;
          }
        }

        MulOps.push_back(getSCEV(BO->RHS));
        auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
                                   dyn_cast<Instruction>(V));
        if (!NewBO || NewBO->Opcode != Instruction::Mul) {
          MulOps.push_back(getSCEV(BO->LHS));
          break;
        }
        BO = *NewBO;
      } while (true);

      return getMulExpr(MulOps);
    }
    case Instruction::UDiv:
      LHS = getSCEV(BO->LHS);
      RHS = getSCEV(BO->RHS);
      return getUDivExpr(LHS, RHS);
    case Instruction::URem:
      LHS = getSCEV(BO->LHS);
      RHS = getSCEV(BO->RHS);
      return getURemExpr(LHS, RHS);
    case Instruction::Sub: {
      SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
      if (BO->Op)
        Flags = getNoWrapFlagsFromUB(BO->Op);
      LHS = getSCEV(BO->LHS);
      RHS = getSCEV(BO->RHS);
      return getMinusSCEV(LHS, RHS, Flags);
    }
    case Instruction::And:
      // For an expression like x&255 that merely masks off the high bits,
      // use zext(trunc(x)) as the SCEV expression.
      if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
        if (CI->isZero())
          return getSCEV(BO->RHS);
        if (CI->isMinusOne())
          return getSCEV(BO->LHS);
        const APInt &A = CI->getValue();

        // Instcombine's ShrinkDemandedConstant may strip bits out of
        // constants, obscuring what would otherwise be a low-bits mask.
        // Use computeKnownBits to compute what ShrinkDemandedConstant
        // knew about to reconstruct a low-bits mask value.
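        // For illustration (i8, not from the source): if x has its two low
        // bits known zero, instcombine may shrink a mask of 0xFF to 0xFC; the
        // known bits let us recover that x & 0xFC behaves as (x udiv 4)
        // truncated to 6 bits, zero-extended, and multiplied by 4.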
        unsigned LZ = A.countl_zero();
        unsigned TZ = A.countr_zero();
        unsigned BitWidth = A.getBitWidth();
        KnownBits Known(BitWidth);
        computeKnownBits(BO->LHS, Known, getDataLayout(), 0, &AC, nullptr,
                         &DT);

        APInt EffectiveMask =
            APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
        if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
          const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
          const SCEV *LHS = getSCEV(BO->LHS);
          const SCEV *ShiftedLHS = nullptr;
          if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
            if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
              // For an expression like (x * 8) & 8, simplify the multiply.
              unsigned MulZeros = OpC->getAPInt().countr_zero();
              unsigned GCD = std::min(MulZeros, TZ);
              APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
              SmallVector<const SCEV *, 4> MulOps;
              MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD)));
              append_range(MulOps, LHSMul->operands().drop_front());
              auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
              ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
            }
          }
          if (!ShiftedLHS)
            ShiftedLHS = getUDivExpr(LHS, MulCount);
          return getMulExpr(
              getZeroExtendExpr(
                  getTruncateExpr(ShiftedLHS,
                      IntegerType::get(getContext(), BitWidth - LZ - TZ)),
                  BO->LHS->getType()),
              MulCount);
        }
      }
      // Binary `and` is a bit-wise `umin`.
      if (BO->LHS->getType()->isIntegerTy(1)) {
        LHS = getSCEV(BO->LHS);
        RHS = getSCEV(BO->RHS);
        return getUMinExpr(LHS, RHS);
      }
      break;
    case Instruction::Or:
      // Binary `or` is a bit-wise `umax`.
      if (BO->LHS->getType()->isIntegerTy(1)) {
        LHS = getSCEV(BO->LHS);
        RHS = getSCEV(BO->RHS);
        return getUMaxExpr(LHS, RHS);
      }
      break;
    case Instruction::Xor:
      if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
        // If the RHS of xor is -1, then this is a not operation.
        if (CI->isMinusOne())
          return getNotSCEV(getSCEV(BO->LHS));

        // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
        // This is a variant of the check for xor with -1, and it handles
        // the case where instcombine has trimmed non-demanded bits out
        // of an xor with -1.
        if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
          if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
            if (LBO->getOpcode() == Instruction::And &&
                LCI->getValue() == CI->getValue())
              if (const SCEVZeroExtendExpr *Z =
                      dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
                Type *UTy = BO->LHS->getType();
                const SCEV *Z0 = Z->getOperand();
                Type *Z0Ty = Z0->getType();
                unsigned Z0TySize = getTypeSizeInBits(Z0Ty);

                // If C is a low-bits mask, the zero extend is serving to
                // mask off the high bits. Complement the operand and
                // re-apply the zext.
                if (CI->getValue().isMask(Z0TySize))
                  return getZeroExtendExpr(getNotSCEV(Z0), UTy);

                // If C is a single bit, it may be in the sign-bit position
                // before the zero-extend. In this case, represent the xor
                // using an add, which is equivalent, and re-apply the zext.
                APInt Trunc = CI->getValue().trunc(Z0TySize);
                if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
                    Trunc.isSignMask())
                  return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
                                           UTy);
              }
      }
      break;
    case Instruction::Shl:
      // Turn shift left of a constant amount into a multiply.
      if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
        uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();

        // If the shift count is not less than the bitwidth, the result of
        // the shift is undefined. Don't try to analyze it, because the
        // resolution chosen here may differ from the resolution chosen in
        // other parts of the compiler.
        if (SA->getValue().uge(BitWidth))
          break;

        // We can safely preserve the nuw flag in all cases. It's also safe to
        // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
        // requires special handling. It can be preserved as long as we're not
        // left shifting by bitwidth - 1.
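        // For illustration (i8, not from the source): -1 <<nsw 7 is well
        // defined (-128; the shifted-out bits match the result's sign bit),
        // but the equivalent mul nsw -1, -128 overflows to +128, so blindly
        // transferring a lone nsw onto the multiply would introduce poison.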
        auto Flags = SCEV::FlagAnyWrap;
        if (BO->Op) {
          auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
          if ((MulFlags & SCEV::FlagNSW) &&
              ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
            Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNSW);
          if (MulFlags & SCEV::FlagNUW)
            Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNUW);
        }

        ConstantInt *X = ConstantInt::get(
            getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
        return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
      }
      break;

    case Instruction::AShr: {
      // AShr X, C, where C is a constant.
      ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
      if (!CI)
        break;

      Type *OuterTy = BO->LHS->getType();
      uint64_t BitWidth = getTypeSizeInBits(OuterTy);
      // If the shift count is not less than the bitwidth, the result of
      // the shift is undefined. Don't try to analyze it, because the
      // resolution chosen here may differ from the resolution chosen in
      // other parts of the compiler.
      if (CI->getValue().uge(BitWidth))
        break;

      if (CI->isZero())
        return getSCEV(BO->LHS); // shift by zero --> noop

      uint64_t AShrAmt = CI->getZExtValue();
      Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);

      Operator *L = dyn_cast<Operator>(BO->LHS);
      if (L && L->getOpcode() == Instruction::Shl) {
        // X = Shl A, n
        // Y = AShr X, m
        // Both n and m are constant.

        const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
        if (L->getOperand(1) == BO->RHS)
          // For a two-shift sext-inreg, i.e. n = m,
          // use sext(trunc(x)) as the SCEV expression.
          return getSignExtendExpr(
              getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy);

        ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
        if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) {
          uint64_t ShlAmt = ShlAmtCI->getZExtValue();
          if (ShlAmt > AShrAmt) {
            // When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
            // expression. We already checked that ShlAmt < BitWidth, so
            // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
            // ShlAmt - AShrAmt < Amt.
            APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
                                            ShlAmt - AShrAmt);
            return getSignExtendExpr(
                getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy),
                getConstant(Mul)), OuterTy);
          }
        }
      }
      break;
    }
    }
  }
  switch (U->getOpcode()) {
  case Instruction::Trunc:
    return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());

  case Instruction::ZExt:
    return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());

  case Instruction::SExt:
    if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
                                dyn_cast<Instruction>(V))) {
      // The NSW flag of a subtract does not always survive the conversion to
      // A + (-1)*B. By pushing sign extension onto its operands we are much
      // more likely to preserve NSW and allow later AddRec optimisations.
      //
      // NOTE: This is effectively duplicating this logic from getSignExtend:
      //   sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
      // but by that point the NSW information has potentially been lost.
      if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
        Type *Ty = U->getType();
        auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
        auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
        return getMinusSCEV(V1, V2, SCEV::FlagNSW);
      }
    }
    return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());

  case Instruction::BitCast:
    // BitCasts are no-op casts so we just eliminate the cast.
    if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
      return getSCEV(U->getOperand(0));
    break;

  case Instruction::PtrToInt: {
    // Pointer to integer cast is straightforward, so we do model it.
    const SCEV *Op = getSCEV(U->getOperand(0));
    Type *DstIntTy = U->getType();
    // But only if the effective SCEV (integer) type is wide enough to
    // represent all possible pointer values.
    const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
    if (isa<SCEVCouldNotCompute>(IntOp))
      return getUnknown(V);
    return IntOp;
  }
  case Instruction::IntToPtr:
    // Just don't deal with inttoptr casts.
    return getUnknown(V);

  case Instruction::SDiv:
    // If both operands are non-negative, this is just an udiv.
    if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
        isKnownNonNegative(getSCEV(U->getOperand(1))))
      return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
    break;

  case Instruction::SRem:
    // If both operands are non-negative, this is just an urem.
    if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
        isKnownNonNegative(getSCEV(U->getOperand(1))))
      return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
    break;

  case Instruction::GetElementPtr:
    return createNodeForGEP(cast<GEPOperator>(U));

  case Instruction::PHI:
    return createNodeForPHI(cast<PHINode>(U));

  case Instruction::Select:
    return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
                                    U->getOperand(2));

  case Instruction::Call:
  case Instruction::Invoke:
    if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
      return getSCEV(RV);

    if (auto *II = dyn_cast<IntrinsicInst>(U)) {
      switch (II->getIntrinsicID()) {
      case Intrinsic::abs:
        return getAbsExpr(
            getSCEV(II->getArgOperand(0)),
            /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
      case Intrinsic::umax:
        LHS = getSCEV(II->getArgOperand(0));
        RHS = getSCEV(II->getArgOperand(1));
        return getUMaxExpr(LHS, RHS);
      case Intrinsic::umin:
        LHS = getSCEV(II->getArgOperand(0));
        RHS = getSCEV(II->getArgOperand(1));
        return getUMinExpr(LHS, RHS);
      case Intrinsic::smax:
        LHS = getSCEV(II->getArgOperand(0));
        RHS = getSCEV(II->getArgOperand(1));
        return getSMaxExpr(LHS, RHS);
      case Intrinsic::smin:
        LHS = getSCEV(II->getArgOperand(0));
        RHS = getSCEV(II->getArgOperand(1));
        return getSMinExpr(LHS, RHS);
      case Intrinsic::usub_sat: {
        const SCEV *X = getSCEV(II->getArgOperand(0));
        const SCEV *Y = getSCEV(II->getArgOperand(1));
        const SCEV *ClampedY = getUMinExpr(X, Y);
        return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
      }
      case Intrinsic::uadd_sat: {
        const SCEV *X = getSCEV(II->getArgOperand(0));
        const SCEV *Y = getSCEV(II->getArgOperand(1));
        const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
        return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
      }
      case Intrinsic::start_loop_iterations:
      case Intrinsic::annotation:
      case Intrinsic::ptr_annotation:
        // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation
        // is just equivalent to the first operand for SCEV purposes.
        return getSCEV(II->getArgOperand(0));
      case Intrinsic::vscale:
        return getVScale(II->getType());
      default:
        break;
      }
    }
    break;
  }

  return getUnknown(V);
}
//===----------------------------------------------------------------------===//
//                   Iteration Count Computation Code
//===----------------------------------------------------------------------===//
const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
  if (isa<SCEVCouldNotCompute>(ExitCount))
    return getCouldNotCompute();

  auto *ExitCountType = ExitCount->getType();
  assert(ExitCountType->isIntegerTy());
  auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
                                 1 + ExitCountType->getScalarSizeInBits());
  return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
}

const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
                                                       Type *EvalTy,
                                                       const Loop *L) {
  if (isa<SCEVCouldNotCompute>(ExitCount))
    return getCouldNotCompute();

  unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
  unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();

  auto CanAddOneWithoutOverflow = [&]() {
    ConstantRange ExitCountRange =
        getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
    if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
      return true;

    return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
                                         getMinusOne(ExitCount->getType()));
  };

  // If we need to zero extend the backedge count, check if we can add one to
  // it prior to zero extending without overflow. Provided this is safe, it
  // allows better simplification of the +1.
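  // For illustration (not from the source): with an i32 exit count known not
  // to be UINT32_MAX, widening to i64 as zext(ExitCount + 1) keeps the +1 in
  // the narrow type, where it folds more readily than zext(ExitCount) + 1.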
  if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
    return getZeroExtendExpr(
        getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);

  // Get the total trip count from the count by adding 1. This may wrap.
  return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
}

static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
  if (!ExitCount)
    return 0;

  ConstantInt *ExitConst = ExitCount->getValue();

  // Guard against huge trip counts.
  if (ExitConst->getValue().getActiveBits() > 32)
    return 0;

  // In case of integer overflow, this returns 0, which is correct.
  return ((unsigned)ExitConst->getZExtValue()) + 1;
}

unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
  auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
  return getConstantTripCount(ExitCount);
}

unsigned
ScalarEvolution::getSmallConstantTripCount(const Loop *L,
                                           const BasicBlock *ExitingBlock) {
  assert(ExitingBlock && "Must pass a non-null exiting block!");
  assert(L->isLoopExiting(ExitingBlock) &&
         "Exiting block must actually branch out of the loop!");
  const SCEVConstant *ExitCount =
      dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
  return getConstantTripCount(ExitCount);
}

unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
  const auto *MaxExitCount =
      dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
  return getConstantTripCount(MaxExitCount);
}

unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
  SmallVector<BasicBlock *, 8> ExitingBlocks;
  L->getExitingBlocks(ExitingBlocks);

  std::optional<unsigned> Res;
  for (auto *ExitingBB : ExitingBlocks) {
    unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
    if (!Res)
      Res = Multiple;
    Res = (unsigned)std::gcd(*Res, Multiple);
  }
  return Res.value_or(1);
}

unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
                                                       const SCEV *ExitCount) {
  if (ExitCount == getCouldNotCompute())
    return 1;

  // Get the trip count
  const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));

  APInt Multiple = getNonZeroConstantMultiple(TCExpr);
  // If a trip multiple is huge (>=2^32), the trip count is still divisible by
  // the greatest power of 2 divisor less than 2^32.
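  // For illustration (not from the source): if the trip count is known to be
  // a multiple of 2^40, its active-bit count exceeds 32, so the function
  // reports the largest expressible power of two, 2^31, instead.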
8136 return Multiple.getActiveBits() > 32
8137 ? 1U << std::min((unsigned)31, Multiple.countTrailingZeros())
8138 : (unsigned)Multiple.zextOrTrunc(32).getZExtValue();
8141 /// Returns the largest constant divisor of the trip count of this loop as a
8142 /// normal unsigned value, if possible. This means that the actual trip count is
8143 /// always a multiple of the returned value (don't forget the trip count could
8144 /// very well be zero as well!).
8146 /// Returns 1 if the trip count is unknown or not guaranteed to be the
8147 /// multiple of a constant (which is also the case if the trip count is simply
8148 /// constant, use getSmallConstantTripCount for that case), Will also return 1
8149 /// if the trip count is very large (>= 2^32).
8151 /// As explained in the comments for getSmallConstantTripCount, this assumes
8152 /// that control exits the loop via ExitingBlock.
8153 unsigned
8154 ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8155 const BasicBlock *ExitingBlock) {
8156 assert(ExitingBlock && "Must pass a non-null exiting block!");
8157 assert(L->isLoopExiting(ExitingBlock) &&
8158 "Exiting block must actually branch out of the loop!");
8159 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8160 return getSmallConstantTripMultiple(L, ExitCount);
8163 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8164 const BasicBlock *ExitingBlock,
8165 ExitCountKind Kind) {
8166 switch (Kind) {
8167 case Exact:
8168 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8169 case SymbolicMaximum:
8170 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8171 case ConstantMaximum:
8172 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8173 }
8174 llvm_unreachable("Invalid ExitCountKind!");
8175 }
8177 const SCEV *
8178 ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
8179 SmallVector<const SCEVPredicate *, 4> &Preds) {
8180 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8183 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8184 ExitCountKind Kind) {
8185 switch (Kind) {
8186 case Exact:
8187 return getBackedgeTakenInfo(L).getExact(L, this);
8188 case ConstantMaximum:
8189 return getBackedgeTakenInfo(L).getConstantMax(this);
8190 case SymbolicMaximum:
8191 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8192 }
8193 llvm_unreachable("Invalid ExitCountKind!");
8194 }
8196 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8197 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8200 /// Push PHI nodes in the header of the given loop onto the given Worklist.
8201 static void PushLoopPHIs(const Loop *L,
8202 SmallVectorImpl<Instruction *> &Worklist,
8203 SmallPtrSetImpl<Instruction *> &Visited) {
8204 BasicBlock *Header = L->getHeader();
8206 // Push all Loop-header PHIs onto the Worklist stack.
8207 for (PHINode &PN : Header->phis())
8208 if (Visited.insert(&PN).second)
8209 Worklist.push_back(&PN);
8212 const ScalarEvolution::BackedgeTakenInfo &
8213 ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8214 auto &BTI = getBackedgeTakenInfo(L);
8215 if (BTI.hasFullInfo())
8216 return BTI;
8218 auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8220 if (!Pair.second)
8221 return Pair.first->second;
8223 BackedgeTakenInfo Result =
8224 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8226 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8229 ScalarEvolution::BackedgeTakenInfo &
8230 ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8231 // Initially insert an invalid entry for this loop. If the insertion
8232 // succeeds, proceed to actually compute a backedge-taken count and
8233 // update the value. The temporary CouldNotCompute value tells SCEV
8234 // code elsewhere that it shouldn't attempt to request a new
8235 // backedge-taken count, which could result in infinite recursion.
8236 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8237 BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8238 if (!Pair.second)
8239 return Pair.first->second;
8241 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8242 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8243 // must be cleared in this scope.
8244 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8246 // Now that we know more about the trip count for this loop, forget any
8247 // existing SCEV values for PHI nodes in this loop since they are only
8248 // conservative estimates made without the benefit of trip count
8249 // information. This invalidation is not necessary for correctness, and is
8250 // only done to produce more precise results.
8251 if (Result.hasAnyInfo()) {
8252 // Invalidate any expression using an addrec in this loop.
8253 SmallVector<const SCEV *, 8> ToForget;
8254 auto LoopUsersIt = LoopUsers.find(L);
8255 if (LoopUsersIt != LoopUsers.end())
8256 append_range(ToForget, LoopUsersIt->second);
8257 forgetMemoizedResults(ToForget);
8259 // Invalidate constant-evolved loop header phis.
8260 for (PHINode &PN : L->getHeader()->phis())
8261 ConstantEvolutionLoopExitValue.erase(&PN);
8264 // Re-lookup the insert position, since the call to
8265 // computeBackedgeTakenCount above could result in a
8266 // recursive call to getBackedgeTakenInfo (on a different
8267 // loop), which would invalidate the iterator computed
8268 // earlier.
8269 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8270 }
8272 void ScalarEvolution::forgetAllLoops() {
8273 // This method is intended to forget all info about loops. It should
8274 // invalidate caches as if the following happened:
8275 // - The trip counts of all loops have changed arbitrarily
8276 // - Every llvm::Value has been updated in place to produce a different
8277 // result.
8278 BackedgeTakenCounts.clear();
8279 PredicatedBackedgeTakenCounts.clear();
8280 BECountUsers.clear();
8281 LoopPropertiesCache.clear();
8282 ConstantEvolutionLoopExitValue.clear();
8283 ValueExprMap.clear();
8284 ValuesAtScopes.clear();
8285 ValuesAtScopesUsers.clear();
8286 LoopDispositions.clear();
8287 BlockDispositions.clear();
8288 UnsignedRanges.clear();
8289 SignedRanges.clear();
8290 ExprValueMap.clear();
8291 HasRecMap.clear();
8292 ConstantMultipleCache.clear();
8293 PredicatedSCEVRewrites.clear();
8294 FoldCache.clear();
8295 FoldCacheUser.clear();
8296 }
8297 void ScalarEvolution::visitAndClearUsers(
8298 SmallVectorImpl<Instruction *> &Worklist,
8299 SmallPtrSetImpl<Instruction *> &Visited,
8300 SmallVectorImpl<const SCEV *> &ToForget) {
8301 while (!Worklist.empty()) {
8302 Instruction *I = Worklist.pop_back_val();
8303 if (!isSCEVable(I->getType()))
8304 continue;
8306 ValueExprMapType::iterator It =
8307 ValueExprMap.find_as(static_cast<Value *>(I));
8308 if (It != ValueExprMap.end()) {
8309 eraseValueFromMap(It->first);
8310 ToForget.push_back(It->second);
8311 if (PHINode *PN = dyn_cast<PHINode>(I))
8312 ConstantEvolutionLoopExitValue.erase(PN);
8313 }
8315 PushDefUseChildren(I, Worklist, Visited);
8319 void ScalarEvolution::forgetLoop(const Loop *L) {
8320 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8321 SmallVector<Instruction *, 32> Worklist;
8322 SmallPtrSet<Instruction *, 16> Visited;
8323 SmallVector<const SCEV *, 16> ToForget;
8325 // Iterate over all the loops and sub-loops to drop SCEV information.
8326 while (!LoopWorklist.empty()) {
8327 auto *CurrL = LoopWorklist.pop_back_val();
8329 // Drop any stored trip count value.
8330 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8331 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8333 // Drop information about predicated SCEV rewrites for this loop.
8334 for (auto I = PredicatedSCEVRewrites.begin();
8335 I != PredicatedSCEVRewrites.end();) {
8336 std::pair<const SCEV *, const Loop *> Entry = I->first;
8337 if (Entry.second == CurrL)
8338 PredicatedSCEVRewrites.erase(I++);
8339 else
8340 ++I;
8341 }
8343 auto LoopUsersItr = LoopUsers.find(CurrL);
8344 if (LoopUsersItr != LoopUsers.end()) {
8345 ToForget.insert(ToForget.end(), LoopUsersItr->second.begin(),
8346 LoopUsersItr->second.end());
8347 }
8349 // Drop information about expressions based on loop-header PHIs.
8350 PushLoopPHIs(CurrL, Worklist, Visited);
8351 visitAndClearUsers(Worklist, Visited, ToForget);
8353 LoopPropertiesCache.erase(CurrL);
8354 // Forget all contained loops too, to avoid dangling entries in the
8355 // ValuesAtScopes map.
8356 LoopWorklist.append(CurrL->begin(), CurrL->end());
8357 }
8358 forgetMemoizedResults(ToForget);
8359 }
8361 void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8362 forgetLoop(L->getOutermostLoop());
8365 void ScalarEvolution::forgetValue(Value *V) {
8366 Instruction *I = dyn_cast<Instruction>(V);
8367 if (!I)
8368 return;
8369 // Drop information about expressions based on loop-header PHIs.
8370 SmallVector<Instruction *, 16> Worklist;
8371 SmallPtrSet<Instruction *, 8> Visited;
8372 SmallVector<const SCEV *, 8> ToForget;
8373 Worklist.push_back(I);
8375 visitAndClearUsers(Worklist, Visited, ToForget);
8377 forgetMemoizedResults(ToForget);
8380 void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8382 void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
8383 // Unless a specific value is passed to invalidation, completely clear both
8384 // maps.
8385 if (!V) {
8386 BlockDispositions.clear();
8387 LoopDispositions.clear();
8388 return;
8389 }
8391 if (!isSCEVable(V->getType()))
8392 return;
8394 const SCEV *S = getExistingSCEV(V);
8395 if (!S)
8396 return;
8398 // Invalidate the block and loop dispositions cached for S. Dispositions of
8399 // S's users may change if S's disposition changes (i.e. a user may change to
8400 // loop-invariant, if S changes to loop invariant), so also invalidate
8401 // dispositions of S's users recursively.
8402 SmallVector<const SCEV *, 8> Worklist = {S};
8403 SmallPtrSet<const SCEV *, 8> Seen = {S};
8404 while (!Worklist.empty()) {
8405 const SCEV *Curr = Worklist.pop_back_val();
8406 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8407 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8408 if (!LoopDispoRemoved && !BlockDispoRemoved)
8409 continue;
8410 auto Users = SCEVUsers.find(Curr);
8411 if (Users != SCEVUsers.end())
8412 for (const auto *User : Users->second)
8413 if (Seen.insert(User).second)
8414 Worklist.push_back(User);
8418 /// Get the exact loop backedge taken count considering all loop exits. A
8419 /// computable result can only be returned for loops with all exiting blocks
8420 /// dominating the latch. howFarToZero assumes that the limit of each loop test
8421 /// is never skipped. This is a valid assumption as long as the loop exits via
8422 /// that test. For precise results, it is the caller's responsibility to specify
8423 /// the relevant loop exiting block using getExact(ExitingBlock, SE).
8424 const SCEV *
8425 ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
8426 SmallVector<const SCEVPredicate *, 4> *Preds) const {
8427 // If any exits were not computable, the loop is not computable.
8428 if (!isComplete() || ExitNotTaken.empty())
8429 return SE->getCouldNotCompute();
8431 const BasicBlock *Latch = L->getLoopLatch();
8432 // All exiting blocks we have collected must dominate the only backedge.
8433 if (!Latch)
8434 return SE->getCouldNotCompute();
8436 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8437 // count is simply a minimum out of all these calculated exit counts.
8438 SmallVector<const SCEV *, 2> Ops;
8439 for (const auto &ENT : ExitNotTaken) {
8440 const SCEV *BECount = ENT.ExactNotTaken;
8441 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8442 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8443 "We should only have known counts for exiting blocks that dominate "
8446 Ops.push_back(BECount);
8448 if (Preds)
8449 for (const auto *P : ENT.Predicates)
8450 Preds->push_back(P);
8452 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8453 "Predicate should be always true!");
8456 // If an earlier exit exits on the first iteration (exit count zero), then
8457 // a later poison exit count should not propagate into the result. These are
8458 // exactly the semantics provided by umin_seq.
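// Illustrative example (assumed counts): for exit counts {0, poison},
// umin_seq(0, poison) is 0, whereas a plain umin would be poison; the
// sequential form never evaluates operands past the first one known to
// be taken on the first iteration.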
8459 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8462 /// Get the exact not taken count for this loop exit.
8463 const SCEV *
8464 ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
8465 ScalarEvolution *SE) const {
8466 for (const auto &ENT : ExitNotTaken)
8467 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8468 return ENT.ExactNotTaken;
8470 return SE->getCouldNotCompute();
8473 const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8474 const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8475 for (const auto &ENT : ExitNotTaken)
8476 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8477 return ENT.ConstantMaxNotTaken;
8479 return SE->getCouldNotCompute();
8482 const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8483 const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8484 for (const auto &ENT : ExitNotTaken)
8485 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8486 return ENT.SymbolicMaxNotTaken;
8488 return SE->getCouldNotCompute();
8491 /// getConstantMax - Get the constant max backedge taken count for the loop.
8492 const SCEV *
8493 ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const {
8494 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8495 return !ENT.hasAlwaysTruePredicate();
8496 };
8498 if (!getConstantMax() || any_of(ExitNotTaken, PredicateNotAlwaysTrue))
8499 return SE->getCouldNotCompute();
8501 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8502 isa<SCEVConstant>(getConstantMax())) &&
8503 "No point in having a non-constant max backedge taken count!");
8504 return getConstantMax();
8507 const SCEV *
8508 ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L,
8509 ScalarEvolution *SE) {
8510 if (!SymbolicMax)
8511 SymbolicMax = SE->computeSymbolicMaxBackedgeTakenCount(L);
8512 return SymbolicMax;
8513 }
8515 bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8516 ScalarEvolution *SE) const {
8517 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8518 return !ENT.hasAlwaysTruePredicate();
8519 };
8520 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8521 }
8523 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
8524 : ExitLimit(E, E, E, false, std::nullopt) {}
8526 ScalarEvolution::ExitLimit::ExitLimit(
8527 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8528 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8529 ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
8530 : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken),
8531 SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) {
8532 // If we prove the max count is zero, so is the symbolic bound. This happens
8533 // in practice due to differences in a) how context sensitive we've chosen
8534 // to be and b) how we reason about bounds implied by UB.
8535 if (ConstantMaxNotTaken->isZero()) {
8536 this->ExactNotTaken = E = ConstantMaxNotTaken;
8537 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8540 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8541 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8542 "Exact is not allowed to be less precise than Constant Max");
8543 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8544 !isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken)) &&
8545 "Exact is not allowed to be less precise than Symbolic Max");
8546 assert((isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken) ||
8547 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8548 "Symbolic Max is not allowed to be less precise than Constant Max");
8549 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8550 isa<SCEVConstant>(ConstantMaxNotTaken)) &&
8551 "No point in having a non-constant max backedge taken count!");
8552 for (const auto *PredSet : PredSetList)
8553 for (const auto *P : *PredSet)
8554 Predicates.insert(P);
8555 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8556 "Backedge count should be int");
8557 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8558 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8559 "Max backedge count should be int");
8562 ScalarEvolution::ExitLimit::ExitLimit(
8563 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8564 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8565 const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
8566 : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero,
8567 { &PredSet }) {}
8569 /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8570 /// computable exit into a persistent ExitNotTakenInfo array.
8571 ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8572 ArrayRef<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> ExitCounts,
8573 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8574 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8575 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8577 ExitNotTaken.reserve(ExitCounts.size());
8578 std::transform(ExitCounts.begin(), ExitCounts.end(),
8579 std::back_inserter(ExitNotTaken),
8580 [&](const EdgeExitInfo &EEI) {
8581 BasicBlock *ExitBB = EEI.first;
8582 const ExitLimit &EL = EEI.second;
8583 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8584 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8585 EL.Predicates);
8586 });
8587 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8588 isa<SCEVConstant>(ConstantMax)) &&
8589 "No point in having a non-constant max backedge taken count!");
8592 /// Compute the number of times the backedge of the specified loop will execute.
8593 ScalarEvolution::BackedgeTakenInfo
8594 ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8595 bool AllowPredicates) {
8596 SmallVector<BasicBlock *, 8> ExitingBlocks;
8597 L->getExitingBlocks(ExitingBlocks);
8599 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8601 SmallVector<EdgeExitInfo, 4> ExitCounts;
8602 bool CouldComputeBECount = true;
8603 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8604 const SCEV *MustExitMaxBECount = nullptr;
8605 const SCEV *MayExitMaxBECount = nullptr;
8606 bool MustExitMaxOrZero = false;
8608 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8609 // and compute maxBECount.
8610 // Do a union of all the predicates here.
8611 for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
8612 BasicBlock *ExitBB = ExitingBlocks[i];
8614 // We canonicalize untaken exits to br (constant), ignore them so that
8615 // proving an exit untaken doesn't negatively impact our ability to reason
8616 // about the loop as a whole.
8617 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8618 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8619 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8620 if (ExitIfTrue == CI->isZero())
8621 continue;
8622 }
8624 ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
8626 assert((AllowPredicates || EL.Predicates.empty()) &&
8627 "Predicated exit limit when predicates are not allowed!");
8629 // 1. For each exit that can be computed, add an entry to ExitCounts.
8630 // CouldComputeBECount is true only if all exits can be computed.
8631 if (EL.ExactNotTaken != getCouldNotCompute())
8632 ++NumExitCountsComputed;
8633 else
8634 // We couldn't compute an exact value for this exit, so
8635 // we won't be able to compute an exact value for the loop.
8636 CouldComputeBECount = false;
8637 // Remember exit count if either exact or symbolic is known. Because
8638 // Exact always implies symbolic, only check symbolic.
8639 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8640 ExitCounts.emplace_back(ExitBB, EL);
8641 else {
8642 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8643 "Exact is known but symbolic isn't?");
8644 ++NumExitCountsNotComputed;
8645 }
8647 // 2. Derive the loop's MaxBECount from each exit's max number of
8648 // non-exiting iterations. Partition the loop exits into two kinds:
8649 // LoopMustExits and LoopMayExits.
8651 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8652 // is a LoopMayExit. If any computable LoopMustExit is found, then
8653 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8654 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8655 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8656 // any computable EL.ConstantMaxNotTaken.
8658 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8659 DT.dominates(ExitBB, Latch)) {
8660 if (!MustExitMaxBECount) {
8661 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8662 MustExitMaxOrZero = EL.MaxOrZero;
8663 } else {
8664 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8665 EL.ConstantMaxNotTaken);
8666 }
8667 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8668 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8669 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8670 else
8671 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8672 EL.ConstantMaxNotTaken);
8673 }
8674 }
8676 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8677 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8678 // The loop backedge will be taken the maximum or zero times if there's
8679 // a single exit that must be taken the maximum or zero times.
8680 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8682 // Remember which SCEVs are used in exit limits for invalidation purposes.
8683 // We only care about non-constant SCEVs here, so we can ignore
8684 // EL.ConstantMaxNotTaken and MaxBECount, which must be SCEVConstant.
8686 for (const auto &Pair : ExitCounts) {
8687 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8688 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8689 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8690 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8691 {L, AllowPredicates});
8693 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8694 MaxBECount, MaxOrZero);
8697 ScalarEvolution::ExitLimit
8698 ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8699 bool AllowPredicates) {
8700 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8701 // If our exiting block does not dominate the latch, then its connection with
8702 // the loop's exit limit may be far from trivial.
8703 const BasicBlock *Latch = L->getLoopLatch();
8704 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8705 return getCouldNotCompute();
8707 bool IsOnlyExit = (L->getExitingBlock() != nullptr);
8708 Instruction *Term = ExitingBlock->getTerminator();
8709 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8710 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8711 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8712 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8713 "It should have one successor in loop and one exit block!");
8714 // Proceed to the next level to examine the exit condition expression.
8715 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8716 /*ControlsOnlyExit=*/IsOnlyExit,
8717 AllowPredicates);
8718 }
8720 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8721 // For switch, make sure that there is a single exit from the loop.
8722 BasicBlock *Exit = nullptr;
8723 for (auto *SBB : successors(ExitingBlock))
8724 if (!L->contains(SBB)) {
8725 if (Exit) // Multiple exit successors.
8726 return getCouldNotCompute();
8727 Exit = SBB;
8728 }
8729 assert(Exit && "Exiting block must have at least one exit");
8730 return computeExitLimitFromSingleExitSwitch(
8731 L, SI, Exit,
8732 /*ControlsOnlyExit=*/IsOnlyExit);
8733 }
8735 return getCouldNotCompute();
8738 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
8739 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
8740 bool AllowPredicates) {
8741 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
8742 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
8743 ControlsOnlyExit, AllowPredicates);
8746 std::optional<ScalarEvolution::ExitLimit>
8747 ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
8748 bool ExitIfTrue, bool ControlsOnlyExit,
8749 bool AllowPredicates) {
8750 (void)this->L;
8751 (void)this->ExitIfTrue;
8752 (void)this->AllowPredicates;
8754 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8755 this->AllowPredicates == AllowPredicates &&
8756 "Variance in assumed invariant key components!");
8757 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
8758 if (Itr == TripCountMap.end())
8759 return std::nullopt;
8761 return Itr->second;
8762 }
8763 void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
8764 bool ExitIfTrue,
8765 bool ControlsOnlyExit,
8766 bool AllowPredicates,
8767 const ExitLimit &EL) {
8768 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8769 this->AllowPredicates == AllowPredicates &&
8770 "Variance in assumed invariant key components!");
8772 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
8773 assert(InsertResult.second && "Expected successful insertion!");
8774 (void)InsertResult;
8775 (void)ExitIfTrue;
8776 }
8778 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
8779 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8780 bool ControlsOnlyExit, bool AllowPredicates) {
8782 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
8783 AllowPredicates))
8784 return *MaybeEL;
8786 ExitLimit EL = computeExitLimitFromCondImpl(
8787 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
8788 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
8789 return EL;
8790 }
8792 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
8793 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8794 bool ControlsOnlyExit, bool AllowPredicates) {
8795 // Handle BinOp conditions (And, Or).
8796 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
8797 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
8798 return *LimitFromBinOp;
8800 // With an icmp, it may be feasible to compute an exact backedge-taken count.
8801 // Proceed to the next level to examine the icmp.
8802 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
8803 ExitLimit EL =
8804 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
8805 if (EL.hasFullInfo() || !AllowPredicates)
8806 return EL;
8808 // Try again, but use SCEV predicates this time.
8809 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
8810 ControlsOnlyExit,
8811 /*AllowPredicates=*/true);
8812 }
8814 // Check for a constant condition. These are normally stripped out by
8815 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
8816 // preserve the CFG and is temporarily leaving constant conditions
8817 // in place.
8818 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
8819 if (ExitIfTrue == !CI->getZExtValue())
8820 // The backedge is always taken.
8821 return getCouldNotCompute();
8822 // The backedge is never taken.
8823 return getZero(CI->getType());
8826 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
8827 // with a constant step, we can form an equivalent icmp predicate and figure
8828 // out how many iterations will be taken before we exit.
8829 const WithOverflowInst *WO;
8830 const APInt *C;
8831 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
8832 match(WO->getRHS(), m_APInt(C))) {
8833 ConstantRange NWR =
8834 ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
8835 WO->getNoWrapKind());
8836 CmpInst::Predicate Pred;
8837 APInt NewRHSC, Offset;
8838 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
8839 if (!ExitIfTrue)
8840 Pred = ICmpInst::getInversePredicate(Pred);
8841 auto *LHS = getSCEV(WO->getLHS());
8842 if (Offset != 0)
8843 LHS = getAddExpr(LHS, getConstant(Offset));
8844 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
8845 ControlsOnlyExit, AllowPredicates);
8846 if (EL.hasAnyInfo())
8847 return EL;
8848 }
8850 // If it's not an integer or pointer comparison then compute it the hard way.
8851 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
8854 std::optional<ScalarEvolution::ExitLimit>
8855 ScalarEvolution::computeExitLimitFromCondFromBinOp(
8856 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8857 bool ControlsOnlyExit, bool AllowPredicates) {
8858 // Check if the controlling expression for this loop is an And or Or.
8859 Value *Op0, *Op1;
8860 bool IsAnd = false;
8861 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
8862 IsAnd = true;
8863 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
8864 IsAnd = false;
8865 else
8866 return std::nullopt;
8868 // EitherMayExit is true in these two cases:
8869 // br (and Op0 Op1), loop, exit
8870 // br (or Op0 Op1), exit, loop
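// Illustrative example: "br (and Op0 Op1), loop, exit" has IsAnd == true and
// ExitIfTrue == false, so EitherMayExit == true: the loop is left as soon as
// either operand becomes false, and the exit count computed below is the
// (possibly sequential) unsigned minimum of the two operand exit counts.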
8871 bool EitherMayExit = IsAnd ^ ExitIfTrue;
8872 ExitLimit EL0 = computeExitLimitFromCondCached(
8873 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
8874 AllowPredicates);
8875 ExitLimit EL1 = computeExitLimitFromCondCached(
8876 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
8877 AllowPredicates);
8879 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
8880 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
8881 if (isa<ConstantInt>(Op1))
8882 return Op1 == NeutralElement ? EL0 : EL1;
8883 if (isa<ConstantInt>(Op0))
8884 return Op0 == NeutralElement ? EL1 : EL0;
8886 const SCEV *BECount = getCouldNotCompute();
8887 const SCEV *ConstantMaxBECount = getCouldNotCompute();
8888 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
8889 if (EitherMayExit) {
8890 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
8891 // Both conditions must be the same for the loop to continue executing.
8892 // Choose the less conservative count.
8893 if (EL0.ExactNotTaken != getCouldNotCompute() &&
8894 EL1.ExactNotTaken != getCouldNotCompute()) {
8895 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
8896 UseSequentialUMin);
8897 }
8898 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
8899 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
8900 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
8901 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
8902 else
8903 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
8904 EL1.ConstantMaxNotTaken);
8905 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
8906 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
8907 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
8908 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
8909 else
8910 SymbolicMaxBECount = getUMinFromMismatchedTypes(
8911 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
8912 } else {
8913 // Both conditions must be the same at the same time for the loop to exit.
8914 // For now, be conservative.
8915 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
8916 BECount = EL0.ExactNotTaken;
8917 }
8919 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
8920 // to be more aggressive when computing BECount than when computing
8921 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
8922 // and EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
8924 // EL1.ConstantMaxNotTaken to not.
8925 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
8926 !isa<SCEVCouldNotCompute>(BECount))
8927 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
8928 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
8929 SymbolicMaxBECount =
8930 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
8931 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
8932 { &EL0.Predicates, &EL1.Predicates });
8935 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
8936 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
8937 bool AllowPredicates) {
8938 // If the condition was exit on true, convert the condition to exit on false
8939 ICmpInst::Predicate Pred;
8940 if (!ExitIfTrue)
8941 Pred = ExitCond->getPredicate();
8942 else
8943 Pred = ExitCond->getInversePredicate();
8944 const ICmpInst::Predicate OriginalPred = Pred;
8946 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
8947 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
8949 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
8950 AllowPredicates);
8951 if (EL.hasAnyInfo())
8952 return EL;
8954 auto *ExhaustiveCount =
8955 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
8957 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
8958 return ExhaustiveCount;
8960 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
8961 ExitCond->getOperand(1), L, OriginalPred);
8962 }
8963 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
8964 const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
8965 bool ControlsOnlyExit, bool AllowPredicates) {
8967 // Try to evaluate any dependencies out of the loop.
8968 LHS = getSCEVAtScope(LHS, L);
8969 RHS = getSCEVAtScope(RHS, L);
8971 // At this point, we would like to compute how many iterations of the
8972 // loop the predicate will return true for these inputs.
8973 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
8974 // If there is a loop-invariant, force it into the RHS.
8975 std::swap(LHS, RHS);
8976 Pred = ICmpInst::getSwappedPredicate(Pred);
8979 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
8980 loopIsFiniteByAssumption(L);
8981 // Simplify the operands before analyzing them.
8982 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
8984 // If we have a comparison of a chrec against a constant, try to use value
8985 // ranges to answer this query.
8986 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
8987 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
8988 if (AddRec->getLoop() == L) {
8989 // Form the constant range.
8990 ConstantRange CompRange =
8991 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
8993 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
8994 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
8995 }
8997 // If this loop must exit based on this condition (or execute undefined
8998 // behaviour), and we can prove the test sequence produced must repeat
8999 // the same values on self-wrap of the IV, then we can infer that IV
9000 // doesn't self wrap because if it did, we'd have an infinite (undefined)
9001 // loop.
9002 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9003 // TODO: We can peel off any functions which are invertible *in L*. Loop
9004 // invariant terms are effectively constants for our purposes here.
9005 auto *InnerLHS = LHS;
9006 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9007 InnerLHS = ZExt->getOperand();
9008 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS)) {
9009 auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
9010 if (!AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9011 StrideC && StrideC->getAPInt().isPowerOf2()) {
9012 auto Flags = AR->getNoWrapFlags();
9013 Flags = setFlags(Flags, SCEV::FlagNW);
9014 SmallVector<const SCEV*> Operands{AR->operands()};
9015 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9016 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9017 }
9018 }
9019 }
9021 switch (Pred) {
9022 case ICmpInst::ICMP_NE: { // while (X != Y)
9023 // Convert to: while (X-Y != 0)
9024 if (LHS->getType()->isPointerTy()) {
9025 LHS = getLosslessPtrToIntExpr(LHS);
9026 if (isa<SCEVCouldNotCompute>(LHS))
9027 return LHS;
9028 }
9029 if (RHS->getType()->isPointerTy()) {
9030 RHS = getLosslessPtrToIntExpr(RHS);
9031 if (isa<SCEVCouldNotCompute>(RHS))
9032 return RHS;
9033 }
9034 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9035 AllowPredicates);
9036 if (EL.hasAnyInfo())
9037 return EL;
9038 break;
9039 }
9040 case ICmpInst::ICMP_EQ: { // while (X == Y)
9041 // Convert to: while (X-Y == 0)
9042 if (LHS->getType()->isPointerTy()) {
9043 LHS = getLosslessPtrToIntExpr(LHS);
9044 if (isa<SCEVCouldNotCompute>(LHS))
9045 return LHS;
9046 }
9047 if (RHS->getType()->isPointerTy()) {
9048 RHS = getLosslessPtrToIntExpr(RHS);
9049 if (isa<SCEVCouldNotCompute>(RHS))
9050 return RHS;
9051 }
9052 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9053 if (EL.hasAnyInfo()) return EL;
9054 break;
9055 }
9056 case ICmpInst::ICMP_SLE:
9057 case ICmpInst::ICMP_ULE:
9058 // Since the loop is finite, an invariant RHS cannot include the boundary
9059 // value, otherwise it would loop forever.
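// Illustrative example (assumed loop): in a finite loop "while (X <= Y)"
// with loop-invariant Y, Y cannot be the maximum value of its type, so the
// bound is safely rewritten to "while (X < Y + 1)" and handled by the
// SLT/ULT logic below.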
9060 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9061 !isLoopInvariant(RHS, L))
9062 break;
9063 RHS = getAddExpr(getOne(RHS->getType()), RHS);
9064 [[fallthrough]];
9065 case ICmpInst::ICMP_SLT:
9066 case ICmpInst::ICMP_ULT: { // while (X < Y)
9067 bool IsSigned = ICmpInst::isSigned(Pred);
9068 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9069 AllowPredicates);
9070 if (EL.hasAnyInfo())
9071 return EL;
9072 break;
9073 }
9074 case ICmpInst::ICMP_SGE:
9075 case ICmpInst::ICMP_UGE:
9076 // Since the loop is finite, an invariant RHS cannot include the boundary
9077 // value, otherwise it would loop forever.
9078 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9079 !isLoopInvariant(RHS, L))
9080 break;
9081 RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9082 [[fallthrough]];
9083 case ICmpInst::ICMP_SGT:
9084 case ICmpInst::ICMP_UGT: { // while (X > Y)
9085 bool IsSigned = ICmpInst::isSigned(Pred);
9086 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9087 AllowPredicates);
9088 if (EL.hasAnyInfo())
9089 return EL;
9090 break;
9091 }
9092 default:
9093 break;
9094 }
9096 return getCouldNotCompute();
9097 }
9099 ScalarEvolution::ExitLimit
9100 ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9101 SwitchInst *Switch,
9102 BasicBlock *ExitingBlock,
9103 bool ControlsOnlyExit) {
9104 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9106 // Give up if the exit is the default dest of a switch.
9107 if (Switch->getDefaultDest() == ExitingBlock)
9108 return getCouldNotCompute();
9110 assert(L->contains(Switch->getDefaultDest()) &&
9111 "Default case must not exit the loop!");
9112 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9113 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9115 // while (X != Y) --> while (X-Y != 0)
9116 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9117 if (EL.hasAnyInfo())
9118 return EL;
9120 return getCouldNotCompute();
9121 }
9123 static ConstantInt *
9124 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
9125 ScalarEvolution &SE) {
9126 const SCEV *InVal = SE.getConstant(C);
9127 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9128 assert(isa<SCEVConstant>(Val) &&
9129 "Evaluation of SCEV at constant didn't fold correctly?");
9130 return cast<SCEVConstant>(Val)->getValue();
9133 ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9134 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9135 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9136 if (!RHS)
9137 return getCouldNotCompute();
9139 const BasicBlock *Latch = L->getLoopLatch();
9140 if (!Latch)
9141 return getCouldNotCompute();
9143 const BasicBlock *Predecessor = L->getLoopPredecessor();
9144 if (!Predecessor)
9145 return getCouldNotCompute();
9147 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9148 // Return LHS in OutLHS and shift_op in OutOpCode.
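// Illustrative example: "%s = lshr i32 %v, 3" matches with OutLHS = %v and
// OutOpCode = LShr; a shift by zero is rejected because it would make no
// progress towards a stable value.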
9149 auto MatchPositiveShift =
9150 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9152 using namespace PatternMatch;
9154 ConstantInt *ShiftAmt;
9155 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9156 OutOpCode = Instruction::LShr;
9157 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9158 OutOpCode = Instruction::AShr;
9159 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9160 OutOpCode = Instruction::Shl;
9161 else
9162 return false;
9164 return ShiftAmt->getValue().isStrictlyPositive();
9165 };
9167 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9168 // the loop below:
9169 //
9170 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9171 // %iv.shifted = lshr i32 %iv, <positive constant>
9173 // Return true on a successful match. Return the corresponding PHI node (%iv
9174 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9175 auto MatchShiftRecurrence =
9176 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9177 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9179 {
9180 Instruction::BinaryOps OpC;
9181 Value *V;
9183 // If we encounter a shift instruction, "peel off" the shift operation,
9184 // and remember that we did so. Later when we inspect %iv's backedge
9185 // value, we will make sure that the backedge value uses the same
9186 // operation.
9187 //
9188 // Note: the peeled shift operation does not have to be the same
9189 // instruction as the one feeding into the PHI's backedge value. We only
9190 // really care about it being the same *kind* of shift instruction --
9191 // that's all that is required for our later inferences to hold.
9192 if (MatchPositiveShift(LHS, V, OpC)) {
9193 PostShiftOpCode = OpC;
9194 LHS = V;
9195 }
9196 }
9198 PNOut = dyn_cast<PHINode>(LHS);
9199 if (!PNOut || PNOut->getParent() != L->getHeader())
9200 return false;
9202 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9203 Value *OpLHS;
9205 return
9206 // The backedge value for the PHI node must be a shift by a positive
9207 // amount
9208 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9210 // of the PHI node itself
9211 OpLHS == PNOut &&
9213 // and the kind of shift should match the kind of shift we peeled
9214 // off, if any.
9215 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9216 };
9218 PHINode *PN;
9219 Instruction::BinaryOps OpCode;
9220 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9221 return getCouldNotCompute();
9223 const DataLayout &DL = getDataLayout();
9225 // The key rationale for this optimization is that for some kinds of shift
9226 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9227 // within a finite number of iterations. If the condition guarding the
9228 // backedge (in the sense that the backedge is taken if the condition is true)
9229 // is false for the value the shift recurrence stabilizes to, then we know
9230 // that the backedge is taken only a finite number of times.
9232 ConstantInt *StableValue = nullptr;
9233 switch (OpCode) {
9234 default:
9235 llvm_unreachable("Impossible case!");
9237 case Instruction::AShr: {
9238 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9239 // bitwidth(K) iterations.
9240 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9241 KnownBits Known = computeKnownBits(FirstValue, DL, 0, &AC,
9242 Predecessor->getTerminator(), &DT);
9243 auto *Ty = cast<IntegerType>(RHS->getType());
9244 if (Known.isNonNegative())
9245 StableValue = ConstantInt::get(Ty, 0);
9246 else if (Known.isNegative())
9247 StableValue = ConstantInt::get(Ty, -1, true);
9248 else
9249 return getCouldNotCompute();
9250 break;
9251 }
9253 case Instruction::LShr:
9254 case Instruction::Shl:
9255 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9256 // stabilize to 0 in at most bitwidth(K) iterations.
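// Illustrative example (assumed start value): {15,lshr,1} over i8 takes the
// values 15, 7, 3, 1, 0, 0, ... and reaches 0 after four iterations, well
// within bitwidth(i8) == 8 steps.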
9257 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9258 break;
9259 }
9261 Constant *Result =
9262 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9263 assert(Result->getType()->isIntegerTy(1) &&
9264 "Otherwise cannot be an operand to a branch instruction");
9266 if (Result->isZeroValue()) {
9267 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9268 const SCEV *UpperBound =
9269 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9270 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9271 }
9273 return getCouldNotCompute();
9276 /// Return true if we can constant fold an instruction of the specified type,
9277 /// assuming that all operands were constants.
9278 static bool CanConstantFold(const Instruction *I) {
9279 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9280 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9281 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9282 return true;
9284 if (const CallInst *CI = dyn_cast<CallInst>(I))
9285 if (const Function *F = CI->getCalledFunction())
9286 return canConstantFoldCallTo(CI, F);
9288 return false;
9289 }
9290 /// Determine whether this instruction can constant evolve within this loop
9291 /// assuming its operands can all constant evolve.
9292 static bool canConstantEvolve(Instruction *I, const Loop *L) {
9293 // An instruction outside of the loop can't be derived from a loop PHI.
9294 if (!L->contains(I)) return false;
9296 if (isa<PHINode>(I)) {
9297 // We don't currently keep track of the control flow needed to evaluate
9298 // PHIs, so we cannot handle PHIs inside of loops.
9299 return L->getHeader() == I->getParent();
9300 }
9302 // If we won't be able to constant fold this expression even if the operands
9303 // are constants, bail early.
9304 return CanConstantFold(I);
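// Illustrative example: in a loop with header PHI %iv, an instruction such
// as "%t = mul i32 %iv, 3" can constant evolve, while a PHI in an inner
// block or a call to an unknown external function cannot.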
9307 /// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9308 /// recursing through each instruction operand until reaching a loop header phi.
9309 static PHINode *
9310 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9311 DenseMap<Instruction *, PHINode *> &PHIMap,
9312 unsigned Depth) {
9313 if (Depth > MaxConstantEvolvingDepth)
9314 return nullptr;
9316 // Otherwise, we can evaluate this instruction if all of its operands are
9317 // constant or derived from a PHI node themselves.
9318 PHINode *PHI = nullptr;
9319 for (Value *Op : UseInst->operands()) {
9320 if (isa<Constant>(Op)) continue;
9322 Instruction *OpInst = dyn_cast<Instruction>(Op);
9323 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9325 PHINode *P = dyn_cast<PHINode>(OpInst);
9326 if (!P)
9327 // If this operand is already visited, reuse the prior result.
9328 // We may have P != PHI if this is the deepest point at which the
9329 // inconsistent paths meet.
9330 P = PHIMap.lookup(OpInst);
9331 if (!P) {
9332 // Recurse and memoize the results, whether a phi is found or not.
9333 // This recursive call invalidates pointers into PHIMap.
9334 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9335 PHIMap[OpInst] = P;
9336 }
9337 if (!P)
9338 return nullptr; // Not evolving from PHI
9339 if (PHI && PHI != P)
9340 return nullptr; // Evolving from multiple different PHIs.
9341 PHI = P;
9342 }
9343 // This is an expression evolving from a constant PHI!
9344 return PHI;
9345 }
9347 /// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9348 /// in the loop that V is derived from. We allow arbitrary operations along the
9349 /// way, but the operands of an operation must either be constants or a value
9350 /// derived from a constant PHI. If this expression does not fit with these
9351 /// constraints, return null.
9352 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9353 Instruction *I = dyn_cast<Instruction>(V);
9354 if (!I || !canConstantEvolve(I, L)) return nullptr;
9356 if (PHINode *PN = dyn_cast<PHINode>(I))
9357 return PN;
9359 // Record non-constant instructions contained by the loop.
9360 DenseMap<Instruction *, PHINode *> PHIMap;
9361 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9364 /// EvaluateExpression - Given an expression that passes the
9365 /// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9366 /// in the loop has the value PHIVal. If we can't fold this expression for some
9367 /// reason, return null.
9368 static Constant *EvaluateExpression(Value *V, const Loop *L,
9369 DenseMap<Instruction *, Constant *> &Vals,
9370 const DataLayout &DL,
9371 const TargetLibraryInfo *TLI) {
9372 // Convenient constant check, but redundant for recursive calls.
9373 if (Constant *C = dyn_cast<Constant>(V)) return C;
9374 Instruction *I = dyn_cast<Instruction>(V);
9375 if (!I) return nullptr;
9377 if (Constant *C = Vals.lookup(I)) return C;
9379 // An instruction inside the loop depends on a value outside the loop that we
9380 // weren't given a mapping for, or a value such as a call inside the loop.
9381 if (!canConstantEvolve(I, L)) return nullptr;
9383 // An unmapped PHI can be due to a branch or another loop inside this loop,
9384 // or due to this not being the initial iteration through a loop where we
9385 // couldn't compute the evolution of this particular PHI last time.
9386 if (isa<PHINode>(I)) return nullptr;
9388 std::vector<Constant*> Operands(I->getNumOperands());
9390 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9391 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9392 if (!Operand) {
9393 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9394 if (!Operands[i]) return nullptr;
9395 continue;
9396 }
9397 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9398 Vals[Operand] = C;
9399 if (!C) return nullptr;
9400 Operands[i] = C;
9401 }
9403 return ConstantFoldInstOperands(I, Operands, DL, TLI);
9404 }
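// Illustrative example (assumed mapping): with Vals = { %iv -> i32 5 },
// evaluating "%t = mul i32 %iv, 3" folds to i32 15, while an unmapped
// non-constant operand makes the whole evaluation return null.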
9407 // If every incoming value to PN except the one for BB is a specific Constant,
9408 // return that, else return nullptr.
9409 static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9410 Constant *IncomingVal = nullptr;
9412 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9413 if (PN->getIncomingBlock(i) == BB)
9414 continue;
9416 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9417 if (!CurrentVal)
9418 return nullptr;
9420 if (IncomingVal != CurrentVal) {
9421 if (IncomingVal)
9422 return nullptr;
9423 IncomingVal = CurrentVal;
9424 }
9425 }
9427 return IncomingVal;
9428 }
9430 /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9431 /// in the header of its containing loop, we know the loop executes a
9432 /// constant number of times, and the PHI node is just a recurrence
9433 /// involving constants, fold it.
9434 Constant *
9435 ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9436 const APInt &BEs,
9437 const Loop *L) {
9438 auto I = ConstantEvolutionLoopExitValue.find(PN);
9439 if (I != ConstantEvolutionLoopExitValue.end())
9440 return I->second;
9442 if (BEs.ugt(MaxBruteForceIterations))
9443 return ConstantEvolutionLoopExitValue[PN] = nullptr; // Not going to evaluate it.
9445 Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
9447 DenseMap<Instruction *, Constant *> CurrentIterVals;
9448 BasicBlock *Header = L->getHeader();
9449 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9451 BasicBlock *Latch = L->getLoopLatch();
9452 if (!Latch)
9453 return nullptr;
9455 for (PHINode &PHI : Header->phis()) {
9456 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9457 CurrentIterVals[&PHI] = StartCST;
9458 }
9459 if (!CurrentIterVals.count(PN))
9460 return RetVal = nullptr;
9462 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9464 // Execute the loop symbolically to determine the exit value.
9465 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9466 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9468 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9469 unsigned IterationNum = 0;
9470 const DataLayout &DL = getDataLayout();
9471 for (; ; ++IterationNum) {
9472 if (IterationNum == NumIterations)
9473 return RetVal = CurrentIterVals[PN]; // Got exit value!
9475 // Compute the value of the PHIs for the next iteration.
9476 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9477 DenseMap<Instruction *, Constant *> NextIterVals;
9478 Constant *NextPHI =
9479 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9480 if (!NextPHI)
9481 return nullptr; // Couldn't evaluate!
9482 NextIterVals[PN] = NextPHI;
9484 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9486 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9487 // cease to be able to evaluate one of them or if they stop evolving,
9488 // because that doesn't necessarily prevent us from computing PN.
9489 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9490 for (const auto &I : CurrentIterVals) {
9491 PHINode *PHI = dyn_cast<PHINode>(I.first);
9492 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9493 PHIsToCompute.emplace_back(PHI, I.second);
9494 }
9495 // We use two distinct loops because EvaluateExpression may invalidate any
9496 // iterators into CurrentIterVals.
9497 for (const auto &I : PHIsToCompute) {
9498 PHINode *PHI = I.first;
9499 Constant *&NextPHI = NextIterVals[PHI];
9500 if (!NextPHI) { // Not already computed.
9501 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9502 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9503 }
9504 if (NextPHI != I.second)
9505 StoppedEvolving = false;
9506 }
9508 // If all entries in CurrentIterVals == NextIterVals then we can stop
9509 // iterating, the loop can't continue to change.
9510 if (StoppedEvolving)
9511 return RetVal = CurrentIterVals[PN];
9513 CurrentIterVals.swap(NextIterVals);
9517 const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9518 Value *Cond,
9519 bool ExitWhen) {
9520 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9521 if (!PN) return getCouldNotCompute();
9523 // If the loop is canonicalized, the PHI will have exactly two entries.
9524 // That's the only form we support here.
9525 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9527 DenseMap<Instruction *, Constant *> CurrentIterVals;
9528 BasicBlock *Header = L->getHeader();
9529 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9531 BasicBlock *Latch = L->getLoopLatch();
9532 assert(Latch && "Should follow from NumIncomingValues == 2!");
9534 for (PHINode &PHI : Header->phis()) {
9535 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9536 CurrentIterVals[&PHI] = StartCST;
9537 }
9538 if (!CurrentIterVals.count(PN))
9539 return getCouldNotCompute();
9541 // Okay, we found a PHI node that defines the trip count of this loop. Execute
9542 // the loop symbolically to determine when the condition gets a value of
9543 // ExitWhen.
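// Illustrative example (assumed IR): with a header PHI %i stepping
// 0, 1, 2, ... and an exit condition like "icmp eq (and i32 %i, 7), 5",
// the algebraic paths give up, but iterating the PHI values shows the
// condition first becomes true at iteration 5.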
9544 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9545 const DataLayout &DL = getDataLayout();
9546 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9547 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9548 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9550 // Couldn't symbolically evaluate.
9551 if (!CondVal) return getCouldNotCompute();
9553 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9554 ++NumBruteForceTripCountsComputed;
9555 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9556 }
9558 // Update all the PHI nodes for the next iteration.
9559 DenseMap<Instruction *, Constant *> NextIterVals;
9561 // Create a list of which PHIs we need to compute. We want to do this before
9562 // calling EvaluateExpression on them because that may invalidate iterators
9563 // into CurrentIterVals.
9564 SmallVector<PHINode *, 8> PHIsToCompute;
9565 for (const auto &I : CurrentIterVals) {
9566 PHINode *PHI = dyn_cast<PHINode>(I.first);
9567 if (!PHI || PHI->getParent() != Header) continue;
9568 PHIsToCompute.push_back(PHI);
9569 }
9570 for (PHINode *PHI : PHIsToCompute) {
9571 Constant *&NextPHI = NextIterVals[PHI];
9572 if (NextPHI) continue; // Already computed!
9574 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9575 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9576 }
9577 CurrentIterVals.swap(NextIterVals);
9578 }
9580 // Too many iterations were needed to evaluate.
9581 return getCouldNotCompute();
9584 const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9585 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9586 ValuesAtScopes[V];
9587 // Check to see if we've folded this expression at this loop before.
9588 for (auto &LS : Values)
9589 if (LS.first == L)
9590 return LS.second ? LS.second : V;
9592 Values.emplace_back(L, nullptr);
9594 // Otherwise compute it.
9595 const SCEV *C = computeSCEVAtScope(V, L);
9596 for (auto &LS : reverse(ValuesAtScopes[V]))
9597 if (LS.first == L) {
9598 LS.second = C;
9599 if (!isa<SCEVConstant>(C))
9600 ValuesAtScopesUsers[C].push_back({L, V});
9601 break;
9602 }
9604 return C;
9605 }
9606 /// This builds up a Constant using the ConstantExpr interface. That way, we
9607 /// will return Constants for objects which aren't represented by a
9608 /// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9609 /// Returns NULL if the SCEV isn't representable as a Constant.
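// Illustrative example (assumed operands): the SCEV (16 + @g), where the
// global pointer @g is the trailing operand, folds to the constant
// "getelementptr (i8, ptr @g, i64 16)"; min/max expressions return null
// since no ConstantExpr form exists for them (see the TODO below).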
9610 static Constant *BuildConstantFromSCEV(const SCEV *V) {
9611 switch (V->getSCEVType()) {
9612 case scCouldNotCompute:
9613 case scAddRecExpr:
9614 case scVScale:
9615 return nullptr;
9616 case scConstant:
9617 return cast<SCEVConstant>(V)->getValue();
9618 case scUnknown:
9619 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9620 case scSignExtend: {
9621 const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
9622 if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand()))
9623 return ConstantExpr::getSExt(CastOp, SS->getType());
9624 return nullptr;
9625 }
9626 case scZeroExtend: {
9627 const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
9628 if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand()))
9629 return ConstantExpr::getZExt(CastOp, SZ->getType());
9630 return nullptr;
9631 }
9632 case scPtrToInt: {
9633 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9634 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9635 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9636 return nullptr;
9637 }
9639 case scTruncate: {
9640 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9641 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9642 return ConstantExpr::getTrunc(CastOp, ST->getType());
9643 return nullptr;
9644 }
9645 case scAddExpr: {
9646 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9647 Constant *C = nullptr;
9648 for (const SCEV *Op : SA->operands()) {
9649 Constant *OpC = BuildConstantFromSCEV(Op);
9650 if (!OpC)
9651 return nullptr;
9652 if (!C) {
9653 C = OpC;
9654 continue;
9655 }
9656 assert(!C->getType()->isPointerTy() &&
9657 "Can only have one pointer, and it must be last");
9658 if (auto *PT = dyn_cast<PointerType>(OpC->getType())) {
9659 // The offsets have been converted to bytes. We can add bytes to an
9660 // i8* by GEP with the byte count in the first index.
9661 Type *DestPtrTy =
9662 Type::getInt8PtrTy(PT->getContext(), PT->getAddressSpace());
9663 OpC = ConstantExpr::getBitCast(OpC, DestPtrTy);
9664 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9665 OpC, C);
9666 } else {
9667 C = ConstantExpr::getAdd(C, OpC);
9668 }
9669 }
9670 return C;
9671 }
9672 case scMulExpr: {
9673 const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
9674 Constant *C = nullptr;
9675 for (const SCEV *Op : SM->operands()) {
9676 assert(!Op->getType()->isPointerTy() && "Can't multiply pointers");
9677 Constant *OpC = BuildConstantFromSCEV(Op);
9678 if (!OpC)
9679 return nullptr;
9680 C = C ? ConstantExpr::getMul(C, OpC) : OpC;
9681 }
9682 return C;
9683 }
9684 case scUDivExpr:
9685 case scSMaxExpr:
9686 case scUMaxExpr:
9687 case scSMinExpr:
9688 case scUMinExpr:
9689 case scSequentialUMinExpr:
9690 return nullptr; // TODO: smax, umax, smin, umin, umin_seq.
9692 llvm_unreachable("Unknown SCEV kind!");
9695 const SCEV *
9696 ScalarEvolution::getWithOperands(const SCEV *S,
9697 SmallVectorImpl<const SCEV *> &NewOps) {
9698 switch (S->getSCEVType()) {
9699 case scTruncate:
9700 case scZeroExtend:
9701 case scSignExtend:
9702 case scPtrToInt:
9703 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
9704 case scAddRecExpr: {
9705 auto *AddRec = cast<SCEVAddRecExpr>(S);
9706 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
9707 }
9708 case scAddExpr:
9709 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
9710 case scMulExpr:
9711 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
9712 case scUDivExpr:
9713 return getUDivExpr(NewOps[0], NewOps[1]);
9714 case scSMaxExpr:
9715 case scUMaxExpr:
9716 case scSMinExpr:
9717 case scUMinExpr:
9718 return getMinMaxExpr(S->getSCEVType(), NewOps);
9719 case scSequentialUMinExpr:
9720 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
9721 case scConstant:
9722 case scVScale:
9723 case scUnknown:
9724 return S;
9725 case scCouldNotCompute:
9726 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
9728 llvm_unreachable("Unknown SCEV kind!");
const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
  switch (V->getSCEVType()) {
  case scConstant:
    return V;
  case scAddRecExpr: {
    // If this is a loop recurrence for a loop that does not contain L, then we
    // are dealing with the final value computed by the loop.
    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
    // First, attempt to evaluate each operand.
    // Avoid performing the look-up in the common case where the specified
    // expression has no loop-variant portions.
    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
      const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
      if (OpAtScope == AddRec->getOperand(i))
        continue;

      // Okay, at least one of these operands is loop variant but might be
      // foldable. Build a new instance of the folded commutative expression.
      SmallVector<const SCEV *, 8> NewOps;
      NewOps.reserve(AddRec->getNumOperands());
      append_range(NewOps, AddRec->operands().take_front(i));
      NewOps.push_back(OpAtScope);
      for (++i; i != e; ++i)
        NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));

      const SCEV *FoldedRec = getAddRecExpr(
          NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
      AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
      // The addrec may be folded to a nonrecurrence, for example, if the
      // induction variable is multiplied by zero after constant folding. Go
      // ahead and return the folded value.
      if (!AddRec)
        return FoldedRec;
      break;
    }

    // If the scope is outside the addrec's loop, evaluate it by using the
    // loop exit value of the addrec.
    if (!AddRec->getLoop()->contains(L)) {
      // To evaluate this recurrence, we need to know how many times the AddRec
      // loop iterates. Compute this now.
      const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
      if (BackedgeTakenCount == getCouldNotCompute())
        return AddRec;

      // Then, evaluate the AddRec.
      return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
    }

    return AddRec;
  }
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
  case scPtrToInt:
  case scAddExpr:
  case scMulExpr:
  case scUDivExpr:
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
  case scSequentialUMinExpr: {
    ArrayRef<const SCEV *> Ops = V->operands();
    // Avoid performing the look-up in the common case where the specified
    // expression has no loop-variant portions.
    for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
      const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
      if (OpAtScope != Ops[i]) {
        // Okay, at least one of these operands is loop variant but might be
        // foldable. Build a new instance of the folded commutative expression.
        SmallVector<const SCEV *, 8> NewOps;
        NewOps.reserve(Ops.size());
        append_range(NewOps, Ops.take_front(i));
        NewOps.push_back(OpAtScope);
        for (++i; i != e; ++i) {
          OpAtScope = getSCEVAtScope(Ops[i], L);
          NewOps.push_back(OpAtScope);
        }

        return getWithOperands(V, NewOps);
      }
    }
    // If we got here, all operands are loop invariant.
    return V;
  }
  case scUnknown: {
    // If this instruction is evolved from a constant-evolving PHI, compute the
    // exit value from the loop without using SCEVs.
    const SCEVUnknown *SU = cast<SCEVUnknown>(V);
    Instruction *I = dyn_cast<Instruction>(SU->getValue());
    if (!I)
      return V; // This is some other type of SCEVUnknown, just return it.

    if (PHINode *PN = dyn_cast<PHINode>(I)) {
      const Loop *CurrLoop = this->LI[I->getParent()];
      // Looking for loop exit value.
      if (CurrLoop && CurrLoop->getParentLoop() == L &&
          PN->getParent() == CurrLoop->getHeader()) {
        // Okay, there is no closed form solution for the PHI node. Check
        // to see if the loop that contains it has a known backedge-taken
        // count. If so, we may be able to force computation of the exit
        // value.
        const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
        // This trivial case can show up in some degenerate cases where
        // the incoming IR has not yet been fully simplified.
        if (BackedgeTakenCount->isZero()) {
          Value *InitValue = nullptr;
          bool MultipleInitValues = false;
          for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
            if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
              if (!InitValue)
                InitValue = PN->getIncomingValue(i);
              else if (InitValue != PN->getIncomingValue(i)) {
                MultipleInitValues = true;
                break;
              }
            }
          }
          if (!MultipleInitValues && InitValue)
            return getSCEV(InitValue);
        }
        // Do we have a loop invariant value flowing around the backedge
        // for a loop which must execute the backedge?
        if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
            isKnownPositive(BackedgeTakenCount) &&
            PN->getNumIncomingValues() == 2) {

          unsigned InLoopPred =
              CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
          Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
          if (CurrLoop->isLoopInvariant(BackedgeVal))
            return getSCEV(BackedgeVal);
        }
        if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
          // Okay, we know how many times the containing loop executes. If
          // this is a constant evolving PHI node, get the final value at
          // the specified iteration number.
          Constant *RV =
              getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
          if (RV)
            return getSCEV(RV);
        }
      }
    }

    // Okay, this is an expression that we cannot symbolically evaluate
    // into a SCEV. Check to see if it's possible to symbolically evaluate
    // the arguments into constants, and if so, try to constant propagate the
    // result. This is particularly useful for computing loop exit values.
    if (!CanConstantFold(I))
      return V; // This is some other type of SCEVUnknown, just return it.

    SmallVector<Constant *, 4> Operands;
    Operands.reserve(I->getNumOperands());
    bool MadeImprovement = false;
    for (Value *Op : I->operands()) {
      if (Constant *C = dyn_cast<Constant>(Op)) {
        Operands.push_back(C);
        continue;
      }

      // If any of the operands is non-constant and if they are
      // non-integer and non-pointer, don't even try to analyze them
      // with scev techniques.
      if (!isSCEVable(Op->getType()))
        return V;

      const SCEV *OrigV = getSCEV(Op);
      const SCEV *OpV = getSCEVAtScope(OrigV, L);
      MadeImprovement |= OrigV != OpV;

      Constant *C = BuildConstantFromSCEV(OpV);
      if (!C)
        return V;
      if (C->getType() != Op->getType())
        C = ConstantExpr::getCast(
            CastInst::getCastOpcode(C, false, Op->getType(), false), C,
            Op->getType());
      Operands.push_back(C);
    }

    // Check to see if getSCEVAtScope actually made an improvement.
    if (!MadeImprovement)
      return V; // This is some other type of SCEVUnknown, just return it.

    Constant *C = nullptr;
    const DataLayout &DL = getDataLayout();
    C = ConstantFoldInstOperands(I, Operands, DL, &TLI);
    if (!C)
      return V;
    return getSCEV(C);
  }
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV type!");
}

const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
  return getSCEVAtScope(getSCEV(V), L);
}

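/// Strip zero-extend and sign-extend wrappers off of \p S. Both extensions
/// are injective and map zero to zero, so an expression is zero exactly when
/// the stripped operand is zero.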
const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
    return stripInjectiveFunctions(ZExt->getOperand());
  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
    return stripInjectiveFunctions(SExt->getOperand());
  return S;
}

/// Finds the minimum unsigned root of the following equation:
///
///     A * X = B (mod N)
///
/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
/// A and B isn't important.
///
/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
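///
/// For example, with BW = 8, solving 4 * X = 8 (mod 256): D = gcd(4, 256) = 4,
/// the inverse of (A/D) = 1 modulo (N/D) = 64 is 1, so the minimum root is
/// (1 * 8 mod 256) / 4 = 2.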
static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
                                                ScalarEvolution &SE) {
  uint32_t BW = A.getBitWidth();
  assert(BW == SE.getTypeSizeInBits(B->getType()));
  assert(A != 0 && "A must be non-zero.");

  // 1. D = gcd(A, N)
  //
  // The gcd of A and N may have only one prime factor: 2. The number of
  // trailing zeros in A is its multiplicity.
  uint32_t Mult2 = A.countr_zero();
  // D = 2^Mult2

  // 2. Check if B is divisible by D.
  //
  // B is divisible by D if and only if the multiplicity of prime factor 2 for B
  // is not less than multiplicity of this prime factor for D.
  if (SE.getMinTrailingZeros(B) < Mult2)
    return SE.getCouldNotCompute();

  // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
  // modulo (N / D).
  //
  // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
  // (N / D) in general. The inverse itself always fits into BW bits, though,
  // so we immediately truncate it.
  APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D
  APInt Mod(BW + 1, 0);
  Mod.setBit(BW - Mult2); // Mod = N / D
  APInt I = AD.multiplicativeInverse(Mod).trunc(BW);

  // 4. Compute the minimum unsigned root of the equation:
  //
  //   I * (B / D) mod (N / D)
  //
  // To simplify the computation, we factor out the divide by D:
  //   (I * B mod N) / D
  const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
  return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
}

/// For a given quadratic addrec, generate coefficients of the corresponding
/// quadratic equation, multiplied by a common value to ensure that they are
/// integers.
/// The returned value is a tuple { A, B, C, M, BitWidth }, where
/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
/// were multiplied by, and BitWidth is the bit width of the original addrec
/// coefficients.
/// This function returns std::nullopt if the addrec coefficients are not
/// compile-time constants.
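///
/// For example, for the addrec {1,+,2,+,3} the accumulated value after n
/// iterations is 1 + 2n + 3n(n-1)/2, so the doubled equation is
/// 3n^2 + n + 2 = 0 and the result is { 3, 1, 2, 2, BitWidth }.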
static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
  assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
  const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
  const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
  const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
  LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
                    << *AddRec << '\n');

  // We currently can only solve this if the coefficients are constants.
  if (!LC || !MC || !NC) {
    LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
    return std::nullopt;
  }

  APInt L = LC->getAPInt();
  APInt M = MC->getAPInt();
  APInt N = NC->getAPInt();
  assert(!N.isZero() && "This is not a quadratic addrec");

  unsigned BitWidth = LC->getAPInt().getBitWidth();
  unsigned NewWidth = BitWidth + 1;
  LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
                    << BitWidth << '\n');
  // The sign-extension (as opposed to a zero-extension) here matches the
  // extension used in SolveQuadraticEquationWrap (with the same motivation).
  N = N.sext(NewWidth);
  M = M.sext(NewWidth);
  L = L.sext(NewWidth);

  // The increments are M, M+N, M+2N, ..., so the accumulated values are
  //   L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
  //   L+M, L+2M+N, L+3M+3N, ...
  // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
  //
  // The equation Acc = 0 is then
  //   L + nM + n(n-1)/2 N = 0,  or  2L + 2M n + n(n-1) N = 0.
  // In a quadratic form it becomes:
  //   N n^2 + (2M-N) n + 2L = 0.

  APInt A = N;
  APInt B = 2 * M - A;
  APInt C = 2 * L;
  APInt T = APInt(NewWidth, 2);
  LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
                    << "x + " << C << ", coeff bw: " << NewWidth
                    << ", multiplied by " << T << '\n');
  return std::make_tuple(A, B, C, T, BitWidth);
}

/// Helper function to compare optional APInts:
/// (a) if X and Y both exist, return min(X, Y),
/// (b) if neither X nor Y exist, return std::nullopt,
/// (c) if exactly one of X and Y exists, return that value.
static std::optional<APInt> MinOptional(std::optional<APInt> X,
                                        std::optional<APInt> Y) {
  if (X && Y) {
    unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
    APInt XW = X->sext(W);
    APInt YW = Y->sext(W);
    return XW.slt(YW) ? *X : *Y;
  }
  if (!X && !Y)
    return std::nullopt;
  return X ? *X : *Y;
}

/// Helper function to truncate an optional APInt to a given BitWidth.
/// When solving addrec-related equations, it is preferable to return a value
/// that has the same bit width as the original addrec's coefficients. If the
/// solution fits in the original bit width, truncate it (except for i1).
/// Returning a value of a different bit width may inhibit some optimizations.
///
/// In general, a solution to a quadratic equation generated from an addrec
/// may require BW+1 bits, where BW is the bit width of the addrec's
/// coefficients. The reason is that the coefficients of the quadratic
/// equation are BW+1 bits wide (to avoid truncation when converting from
/// the addrec to the equation).
static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
                                            unsigned BitWidth) {
  if (!X)
    return std::nullopt;
  unsigned W = X->getBitWidth();
  if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
    return X->trunc(BitWidth);
  return X;
}

/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
/// iterations. The values L, M, N are assumed to be signed, and they
/// should all have the same bit widths.
/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
/// where BW is the bit width of the addrec's coefficients.
/// If the calculated value is a BW-bit integer (for BW > 1), it will be
/// returned as such, otherwise the bit width of the returned value may
/// be greater than BW.
///
/// This function returns std::nullopt if
/// (a) the addrec coefficients are not constant, or
/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
///     like x^2 = 5, no integer solutions exist, in other cases an integer
///     solution may exist, but SolveQuadraticEquationWrap may fail to find it.
static std::optional<APInt>
SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
  APInt A, B, C, M;
  unsigned BitWidth;
  auto T = GetQuadraticEquation(AddRec);
  if (!T)
    return std::nullopt;

  std::tie(A, B, C, M, BitWidth) = *T;
  LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
  std::optional<APInt> X =
      APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
  if (!X)
    return std::nullopt;

  ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
  ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
  if (!V->isZero())
    return std::nullopt;

  return TruncIfPossible(X, BitWidth);
}

/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
/// iterations. The values M, N are assumed to be signed, and they
/// should all have the same bit widths.
/// Find the least n such that c(n) does not belong to the given range,
/// while c(n-1) does.
///
/// This function returns std::nullopt if
/// (a) the addrec coefficients are not constant, or
/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
///     bounds of the range.
static std::optional<APInt>
SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
                          const ConstantRange &Range, ScalarEvolution &SE) {
  assert(AddRec->getOperand(0)->isZero() &&
         "Starting value of addrec should be 0");
  LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
                    << Range << ", addrec " << *AddRec << '\n');
  // This case is handled in getNumIterationsInRange. Here we can assume that
  // we start in the range.
  assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
         "Addrec's initial value should be in range");

  APInt A, B, C, M;
  unsigned BitWidth;
  auto T = GetQuadraticEquation(AddRec);
  if (!T)
    return std::nullopt;

  // Be careful about the return value: there can be two reasons for not
  // returning an actual number. First, if no solutions to the equations
  // were found, and second, if the solutions don't leave the given range.
  // The first case means that the actual solution is "unknown", the second
  // means that it's known, but not valid. If the solution is unknown, we
  // cannot make any conclusions.
  // Return a pair: the optional solution and a flag indicating if the
  // solution was found.
  auto SolveForBoundary =
      [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
    // Solve for signed overflow and unsigned overflow, pick the lower
    // solution.
    LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
                      << Bound << " (before multiplying by " << M << ")\n");
    Bound *= M; // The quadratic equation multiplier.

    std::optional<APInt> SO;
    if (BitWidth > 1) {
      LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
                           "signed overflow\n");
      SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
    }
    LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
                         "unsigned overflow\n");
    std::optional<APInt> UO =
        APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);

    auto LeavesRange = [&] (const APInt &X) {
      ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
      ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
      if (Range.contains(V0->getValue()))
        return false;
      // X should be at least 1, so X-1 is non-negative.
      ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
      ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
      if (Range.contains(V1->getValue()))
        return true;
      return false;
    };

    // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
    // can be a solution, but the function failed to find it. We cannot treat
    // it as "no solution".
    if (!SO || !UO)
      return {std::nullopt, false};

    // Check the smaller value first to see if it leaves the range.
    // At this point, both SO and UO must have values.
    std::optional<APInt> Min = MinOptional(SO, UO);
    if (LeavesRange(*Min))
      return { Min, true };
    std::optional<APInt> Max = Min == SO ? UO : SO;
    if (LeavesRange(*Max))
      return { Max, true };

    // Solutions were found, but were eliminated, hence the "true".
    return {std::nullopt, true};
  };

  std::tie(A, B, C, M, BitWidth) = *T;
  // Lower bound is inclusive, subtract 1 to represent the exiting value.
  APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
  APInt Upper = Range.getUpper().sext(A.getBitWidth());
  auto SL = SolveForBoundary(Lower);
  auto SU = SolveForBoundary(Upper);
  // If any of the solutions was unknown, no meaningful conclusions can
  // be made.
  if (!SL.second || !SU.second)
    return std::nullopt;

  // Claim: The correct solution is not some value between Min and Max.
  //
  // Justification: Assuming that Min and Max are different values, one of
  // them is when the first signed overflow happens, the other is when the
  // first unsigned overflow happens. Crossing the range boundary is only
  // possible via an overflow (treating 0 as a special case of it, modeling
  // an overflow as crossing k*2^W for some k).
  //
  // The interesting case here is when Min was eliminated as an invalid
  // solution, but Max was not. The argument is that if there was another
  // overflow between Min and Max, it would also have been eliminated if
  // it was considered.
  //
  // For a given boundary, it is possible to have two overflows of the same
  // type (signed/unsigned) without having the other type in between: this
  // can happen when the vertex of the parabola is between the iterations
  // corresponding to the overflows. This is only possible when the two
  // overflows cross k*2^W for the same k. In such case, if the second one
  // left the range (and was the first one to do so), the first overflow
  // would have to enter the range, which would mean that either we had left
  // the range before or that we started outside of it. Both of these cases
  // are contradictions.
  //
  // Claim: In the case where SolveForBoundary returns std::nullopt, the
  // correct solution is not some value between the Max for this boundary
  // and the Min of the other boundary.
  //
  // Justification: Assume that we had such Max_A and Min_B corresponding
  // to range boundaries A and B and such that Max_A < Min_B. If there was
  // a solution between Max_A and Min_B, it would have to be caused by an
  // overflow corresponding to either A or B. It cannot correspond to B,
  // since Min_B is the first occurrence of such an overflow. If it
  // corresponded to A, it would have to be either a signed or an unsigned
  // overflow that is larger than both eliminated overflows for A. But
  // between the eliminated overflows and this overflow, the values would
  // cover the entire value space, thus crossing the other boundary, which
  // is a contradiction.

  return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
}

ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
                                                         const Loop *L,
                                                         bool ControlsOnlyExit,
                                                         bool AllowPredicates) {
  // This is only used for loops with a "x != y" exit test. The exit condition
  // is now expressed as a single expression, V = x-y. So the exit test is
  // effectively V != 0. We know and take advantage of the fact that this
  // expression is only used in a comparison-with-zero context.

  SmallPtrSet<const SCEVPredicate *, 4> Predicates;
  // If the value is a constant:
  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
    // If the value is already zero, the branch will execute zero times.
    if (C->getValue()->isZero()) return C;
    return getCouldNotCompute(); // Otherwise it will loop infinitely.
  }

  const SCEVAddRecExpr *AddRec =
      dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));

  if (!AddRec && AllowPredicates)
    // Try to make this an AddRec using runtime tests, in the first X
    // iterations of this loop, where X is the SCEV expression found by the
    // algorithm below.
    AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);

  if (!AddRec || AddRec->getLoop() != L)
    return getCouldNotCompute();

  // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
  // the quadratic equation to solve it.
  if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
    // We can only use this value if the chrec ends up with an exact zero
    // value at this index. When solving for "X*X != 5", for example, we
    // should not accept a root of 2.
    if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
      const auto *R = cast<SCEVConstant>(getConstant(*S));
      return ExitLimit(R, R, R, false, Predicates);
    }
    return getCouldNotCompute();
  }

  // Otherwise we can only handle this if it is affine.
  if (!AddRec->isAffine())
    return getCouldNotCompute();

  // If this is an affine expression, the execution count of this branch is
  // the minimum unsigned root of the following equation:
  //
  //     Start + Step*N = 0 (mod 2^BW)
  //
  // equivalent to:
  //
  //             Step*N = -Start (mod 2^BW)
  //
  // where BW is the common bit width of Start and Step.
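  //
  // For example, with 8-bit arithmetic, Start = 10 and Step = -2 give
  // 254*N = 246 (mod 256), whose minimum unsigned root is N = 5.
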
  // Get the initial value for the loop.
  const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
  const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());

  // For now we handle only constant steps.
  //
  // TODO: Handle a nonconstant Step given AddRec<NUW>. If the
  // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
  // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step.
  // We have not yet seen any such cases.
  const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
  if (!StepC || StepC->getValue()->isZero())
    return getCouldNotCompute();

  // For positive steps (counting up until unsigned overflow):
  //   N = -Start/Step (as unsigned)
  // For negative steps (counting down to zero):
  //   N = Start/-Step (as unsigned)
  // First compute the unsigned distance from zero in the direction of Step.
  bool CountDown = StepC->getAPInt().isNegative();
  const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);

  // Handle unitary steps, which cannot wraparound.
  // 1*N = -Start; -1*N = Start (mod 2^BW), so:
  //   N = Distance (as unsigned)
  if (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne()) {
    APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
    MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));

    // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is
    // rotated, we end up with a loop whose backedge-taken count is n - 1.
    // Detect this case, and see if we can improve the bound.
    //
    // Explicitly handling this here is necessary because getUnsignedRange
    // isn't context-sensitive; it doesn't know that we only care about the
    // range inside the loop.
    const SCEV *Zero = getZero(Distance->getType());
    const SCEV *One = getOne(Distance->getType());
    const SCEV *DistancePlusOne = getAddExpr(Distance, One);
    if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
      // If Distance + 1 doesn't overflow, we can compute the maximum distance
      // as "unsigned_max(Distance + 1) - 1".
      ConstantRange CR = getUnsignedRange(DistancePlusOne);
      MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
    }

    return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
                     Predicates);
  }

  // If the condition controls the only loop exit (the loop exits only if the
  // expression is true) and the addition is no-wrap, we can use unsigned
  // divide to compute the backedge count. In this case, the step may not
  // divide the distance, but we don't care because if the condition is
  // "missed" the loop will have undefined behavior due to wrapping.
  if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
      loopHasNoAbnormalExits(AddRec->getLoop())) {
    const SCEV *Exact =
        getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
    const SCEV *ConstantMax = getCouldNotCompute();
    if (Exact != getCouldNotCompute()) {
      APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
      ConstantMax =
          getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
    }
    const SCEV *SymbolicMax =
        isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
    return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
  }

  // Solve the general equation.
  const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(),
                                               getNegativeSCEV(Start), *this);

  const SCEV *M = getCouldNotCompute();
  if (E != getCouldNotCompute()) {
    APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
    M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
  }
  auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
  return ExitLimit(E, M, S, false, Predicates);
}

ScalarEvolution::ExitLimit
ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
  // Loops that look like: while (X == 0) are very strange indeed. We don't
  // handle them yet except for the trivial case. This could be expanded in
  // the future as needed.

  // If the value is a constant, check to see if it is known to be non-zero
  // already. If so, the backedge will execute zero times.
  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
    if (!C->getValue()->isZero())
      return getZero(C->getType());
    return getCouldNotCompute(); // Otherwise it will loop infinitely.
  }

  // We could implement others, but I really doubt anyone writes loops like
  // this, and if they did, they would already be constant folded.
  return getCouldNotCompute();
}

std::pair<const BasicBlock *, const BasicBlock *>
ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
    const {
  // If the block has a unique predecessor, then there is no path from the
  // predecessor to the block that does not go through the direct edge
  // from the predecessor to the block.
  if (const BasicBlock *Pred = BB->getSinglePredecessor())
    return {Pred, BB};

  // A loop's header is defined to be a block that dominates the loop.
  // If the header has a unique predecessor outside the loop, it must be
  // a block that has exactly one successor that can reach the loop.
  if (const Loop *L = LI.getLoopFor(BB))
    return {L->getLoopPredecessor(), L->getHeader()};

  return {nullptr, nullptr};
}

/// SCEV structural equivalence is usually sufficient for testing whether two
/// expressions are equal, however for the purposes of looking for a condition
/// guarding a loop, it can be useful to be a little more general, since a
/// front-end may have replicated the controlling expression.
static bool HasSameValue(const SCEV *A, const SCEV *B) {
  // Quick check to see if they are the same SCEV.
  if (A == B) return true;

  auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
    // Not all instructions that are "identical" compute the same value. For
    // instance, two distinct alloca instructions allocating the same type are
    // identical and do not read memory; but compute distinct values.
    return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) ||
                                   isa<GetElementPtrInst>(A));
  };

  // Otherwise, if they're both SCEVUnknown, it's possible that they hold
  // two different instructions with the same value. Check for this case.
  if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
    if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
      if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
        if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
          if (ComputesEqualValues(AI, BI))
            return true;

  // Otherwise assume they may have a different value.
  return false;
}

bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
                                           const SCEV *&LHS, const SCEV *&RHS,
                                           unsigned Depth) {
  bool Changed = false;
  // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
  // '0 != 0'.
  auto TrivialCase = [&](bool TriviallyTrue) {
    LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
    Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
    return true;
  };
  // If we hit the max recursion limit bail out.
  if (Depth >= 3)
    return false;

  // Canonicalize a constant to the right side.
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
    // Check for both operands constant.
    if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
      if (ConstantExpr::getICmp(Pred,
                                LHSC->getValue(),
                                RHSC->getValue())->isNullValue())
        return TrivialCase(false);
      return TrivialCase(true);
    }
    // Otherwise swap the operands to put the constant on the right.
    std::swap(LHS, RHS);
    Pred = ICmpInst::getSwappedPredicate(Pred);
    Changed = true;
  }

  // If we're comparing an addrec with a value which is loop-invariant in the
  // addrec's loop, put the addrec on the left. Also make a dominance check,
  // as both operands could be addrecs loop-invariant in each other's loop.
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
    const Loop *L = AR->getLoop();
    if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
      std::swap(LHS, RHS);
      Pred = ICmpInst::getSwappedPredicate(Pred);
      Changed = true;
    }
  }

  // If there's a constant operand, canonicalize comparisons with boundary
  // cases, and canonicalize *-or-equal comparisons to regular comparisons.
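  // For example, "X u>= 5" becomes "X u> 4", and "X s<= 7" becomes "X s< 8".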
  if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
    const APInt &RA = RC->getAPInt();

    bool SimplifiedByConstantRange = false;

    if (!ICmpInst::isEquality(Pred)) {
      ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
      if (ExactCR.isFullSet())
        return TrivialCase(true);
      if (ExactCR.isEmptySet())
        return TrivialCase(false);

      APInt NewRHS;
      CmpInst::Predicate NewPred;
      if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
          ICmpInst::isEquality(NewPred)) {
        // We were able to convert an inequality to an equality.
        Pred = NewPred;
        RHS = getConstant(NewRHS);
        Changed = SimplifiedByConstantRange = true;
      }
    }

    if (!SimplifiedByConstantRange) {
      switch (Pred) {
      default:
        break;
      case ICmpInst::ICMP_EQ:
      case ICmpInst::ICMP_NE:
        // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
        if (!RA)
          if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS))
            if (const SCEVMulExpr *ME =
                    dyn_cast<SCEVMulExpr>(AE->getOperand(0)))
              if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 &&
                  ME->getOperand(0)->isAllOnesValue()) {
                RHS = AE->getOperand(1);
                LHS = ME->getOperand(1);
                Changed = true;
              }
        break;

        // The "Should have been caught earlier!" messages refer to the fact
        // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
        // should have fired on the corresponding cases, and canonicalized the
        // check to trivial case.
      case ICmpInst::ICMP_UGE:
        assert(!RA.isMinValue() && "Should have been caught earlier!");
        Pred = ICmpInst::ICMP_UGT;
        RHS = getConstant(RA - 1);
        Changed = true;
        break;
      case ICmpInst::ICMP_ULE:
        assert(!RA.isMaxValue() && "Should have been caught earlier!");
        Pred = ICmpInst::ICMP_ULT;
        RHS = getConstant(RA + 1);
        Changed = true;
        break;
      case ICmpInst::ICMP_SGE:
        assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
        Pred = ICmpInst::ICMP_SGT;
        RHS = getConstant(RA - 1);
        Changed = true;
        break;
      case ICmpInst::ICMP_SLE:
        assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
        Pred = ICmpInst::ICMP_SLT;
        RHS = getConstant(RA + 1);
        Changed = true;
        break;
      }
    }
  }

  // Check for obvious equality.
  if (HasSameValue(LHS, RHS)) {
    if (ICmpInst::isTrueWhenEqual(Pred))
      return TrivialCase(true);
    if (ICmpInst::isFalseWhenEqual(Pred))
      return TrivialCase(false);
  }

  // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
  // adding or subtracting 1 from one of the operands.
  switch (Pred) {
  case ICmpInst::ICMP_SLE:
    if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
      RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
                       SCEV::FlagNSW);
      Pred = ICmpInst::ICMP_SLT;
      Changed = true;
    } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
                       SCEV::FlagNSW);
      Pred = ICmpInst::ICMP_SLT;
      Changed = true;
    }
    break;
  case ICmpInst::ICMP_SGE:
    if (!getSignedRangeMin(RHS).isMinSignedValue()) {
      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
                       SCEV::FlagNSW);
      Pred = ICmpInst::ICMP_SGT;
      Changed = true;
    } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
      LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
                       SCEV::FlagNSW);
      Pred = ICmpInst::ICMP_SGT;
      Changed = true;
    }
    break;
  case ICmpInst::ICMP_ULE:
    if (!getUnsignedRangeMax(RHS).isMaxValue()) {
      RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
                       SCEV::FlagNUW);
      Pred = ICmpInst::ICMP_ULT;
      Changed = true;
    } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
      Pred = ICmpInst::ICMP_ULT;
      Changed = true;
    }
    break;
  case ICmpInst::ICMP_UGE:
    if (!getUnsignedRangeMin(RHS).isMinValue()) {
      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
      Pred = ICmpInst::ICMP_UGT;
      Changed = true;
    } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
      LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
                       SCEV::FlagNUW);
      Pred = ICmpInst::ICMP_UGT;
      Changed = true;
    }
    break;
  default:
    break;
  }

  // TODO: More simplifications are possible here.

  // Recursively simplify until we either hit a recursion limit or nothing
  // changes.
  if (Changed)
    return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);

  return Changed;
}

bool ScalarEvolution::isKnownNegative(const SCEV *S) {
  return getSignedRangeMax(S).isNegative();
}

bool ScalarEvolution::isKnownPositive(const SCEV *S) {
  return getSignedRangeMin(S).isStrictlyPositive();
}

bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
  return !getSignedRangeMin(S).isNegative();
}

bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
  return !getSignedRangeMax(S).isStrictlyPositive();
}

bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
  return getUnsignedRangeMin(S) != 0;
}

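// For example, splitting S = {0,+,1}<L> at loop L yields the pair
// (0, {1,+,1}<L>): the value on entry to L and the value after each
// backedge-taken increment.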
std::pair<const SCEV *, const SCEV *>
ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
  // Compute SCEV on entry of loop L.
  const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
  if (Start == getCouldNotCompute())
    return { Start, Start };
  // Compute post increment SCEV for loop L.
  const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
  assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
  return { Start, PostInc };
}

bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred,
                                          const SCEV *LHS, const SCEV *RHS) {
  // First collect all loops.
  SmallPtrSet<const Loop *, 8> LoopsUsed;
  getUsedLoops(LHS, LoopsUsed);
  getUsedLoops(RHS, LoopsUsed);

  if (LoopsUsed.empty())
    return false;

  // Domination relationship must be a linear order on collected loops.
#ifndef NDEBUG
  for (const auto *L1 : LoopsUsed)
    for (const auto *L2 : LoopsUsed)
      assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
              DT.dominates(L2->getHeader(), L1->getHeader())) &&
             "Domination relationship is not a linear order");
#endif

  const Loop *MDL =
      *std::max_element(LoopsUsed.begin(), LoopsUsed.end(),
                        [&](const Loop *L1, const Loop *L2) {
                          return DT.properlyDominates(L1->getHeader(),
                                                      L2->getHeader());
                        });

  // Get init and post increment value for LHS.
  auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
  // If LHS contains unknown non-invariant SCEV then bail out.
  if (SplitLHS.first == getCouldNotCompute())
    return false;
  assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
  // Get init and post increment value for RHS.
  auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
  // If RHS contains unknown non-invariant SCEV then bail out.
  if (SplitRHS.first == getCouldNotCompute())
    return false;
  assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
  // It is possible that init SCEV contains an invariant load but it does
  // not dominate MDL and is not available at MDL loop entry, so we should
  // check it here.
  if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
      !isAvailableAtLoopEntry(SplitRHS.first, MDL))
    return false;

  // It seems that the backedge guard check is faster than the entry one, so
  // in some cases it can speed up the whole estimation by short circuiting.
  return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
                                     SplitRHS.second) &&
         isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
}

bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
                                       const SCEV *LHS, const SCEV *RHS) {
  // Canonicalize the inputs first.
  (void)SimplifyICmpOperands(Pred, LHS, RHS);

  if (isKnownViaInduction(Pred, LHS, RHS))
    return true;

  if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
    return true;

  // Otherwise see what can be done with some simple reasoning.
  return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
}

std::optional<bool> ScalarEvolution::evaluatePredicate(ICmpInst::Predicate Pred,
                                                       const SCEV *LHS,
                                                       const SCEV *RHS) {
  if (isKnownPredicate(Pred, LHS, RHS))
    return true;
  if (isKnownPredicate(ICmpInst::getInversePredicate(Pred), LHS, RHS))
    return false;
  return std::nullopt;
}

bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred,
                                         const SCEV *LHS, const SCEV *RHS,
                                         const Instruction *CtxI) {
  // TODO: Analyze guards and assumes from Context's block.
  return isKnownPredicate(Pred, LHS, RHS) ||
         isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
}

std::optional<bool>
ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS,
                                     const SCEV *RHS, const Instruction *CtxI) {
  std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
  if (KnownWithoutContext)
    return KnownWithoutContext;

  if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
    return true;
  if (isBasicBlockEntryGuardedByCond(CtxI->getParent(),
                                     ICmpInst::getInversePredicate(Pred),
                                     LHS, RHS))
    return false;
  return std::nullopt;
}

bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
                                              const SCEVAddRecExpr *LHS,
                                              const SCEV *RHS) {
  const Loop *L = LHS->getLoop();
  return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
         isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
}

std::optional<ScalarEvolution::MonotonicPredicateType>
ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
                                           ICmpInst::Predicate Pred) {
  auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);

#ifndef NDEBUG
  // Verify an invariant: inverting the predicate should turn a monotonically
  // increasing change to a monotonically decreasing one, and vice versa.
  if (Result) {
    auto ResultSwapped =
        getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));

    assert(ResultSwapped && "should be able to analyze both!");
    assert(*ResultSwapped != *Result &&
           "monotonicity should flip as we flip the predicate");
  }
#endif

  return Result;
}

std::optional<ScalarEvolution::MonotonicPredicateType>
ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
                                               ICmpInst::Predicate Pred) {
  // A zero step value for LHS means the induction variable is essentially a
  // loop invariant value. We don't really depend on the predicate actually
  // flipping from false to true (for increasing predicates, and the other way
  // around for decreasing predicates), all we care about is that *if* the
  // predicate changes then it only changes from false to true.
  //
  // A zero step value in itself is not very useful, but there may be places
  // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
  // as general as possible.

  // Only handle LE/LT/GE/GT predicates.
  if (!ICmpInst::isRelational(Pred))
    return std::nullopt;

  bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
  assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
         "Should be greater or less!");

  // Check that AR does not wrap.
  if (ICmpInst::isUnsigned(Pred)) {
    if (!LHS->hasNoUnsignedWrap())
      return std::nullopt;
    return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
  }
  assert(ICmpInst::isSigned(Pred) &&
         "Relational predicate is either signed or unsigned!");
  if (!LHS->hasNoSignedWrap())
    return std::nullopt;

  const SCEV *Step = LHS->getStepRecurrence(*this);

  if (isKnownNonNegative(Step))
    return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;

  if (isKnownNonPositive(Step))
    return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;

  return std::nullopt;
}

std::optional<ScalarEvolution::LoopInvariantPredicate>
ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred,
                                           const SCEV *LHS, const SCEV *RHS,
                                           const Loop *L,
                                           const Instruction *CtxI) {
  // If there is a loop-invariant, force it into the RHS, otherwise bail out.
  if (!isLoopInvariant(RHS, L)) {
    if (!isLoopInvariant(LHS, L))
      return std::nullopt;

    std::swap(LHS, RHS);
    Pred = ICmpInst::getSwappedPredicate(Pred);
  }

  const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!ArLHS || ArLHS->getLoop() != L)
    return std::nullopt;

  auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
  if (!MonotonicType)
    return std::nullopt;
  // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
  // true as the loop iterates, and the backedge is control dependent on
  // "ArLHS `Pred` RHS" == true then we can reason as follows:
  //
  // * if the predicate was false in the first iteration then the predicate
  // is never evaluated again, since the loop exits without taking the
  // backedge.
  // * if the predicate was true in the first iteration then it will
  // continue to be true for all future iterations since it is
  // monotonically increasing.
  //
  // For both of the above possibilities, we can replace the loop varying
  // predicate with its value on the first iteration of the loop (which is
  // loop invariant).
  //
  // A similar reasoning applies for a monotonically decreasing predicate, by
  // replacing true with false and false with true in the above two bullets.
  bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
  auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);

  if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
    return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
                                                   RHS);

  if (!CtxI)
    return std::nullopt;
  // Try to prove via context.
  // TODO: Support other cases.
  switch (Pred) {
  default:
    break;
  case ICmpInst::ICMP_ULE:
  case ICmpInst::ICMP_ULT: {
    assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
    // Given preconditions
    // (1) ArLHS does not cross the border of positive and negative parts of
    //     range because of:
    //       - Positive step; (TODO: lift this limitation)
    //       - nuw - does not cross zero boundary;
    //       - nsw - does not cross SINT_MAX boundary;
    // (2) ArLHS <s RHS
    // (3) RHS >=s 0
    // we can replace the loop variant ArLHS <u RHS condition with loop
    // invariant Start(ArLHS) <u RHS.
    //
    // Because of (1) there are two options:
    // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
    // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
    //   It means that ArLHS <s RHS <=> ArLHS <u RHS.
    //   Because of (2) ArLHS <u RHS is trivially true.
    // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
    // We can strengthen this to Start(ArLHS) <u RHS.
    auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
    if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
        isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
        isKnownNonNegative(RHS) &&
        isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
      return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
                                                     RHS);
    break;
  }
  }

  return std::nullopt;
}

std::optional<ScalarEvolution::LoopInvariantPredicate>
ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
    const Instruction *CtxI, const SCEV *MaxIter) {
  if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
          Pred, LHS, RHS, L, CtxI, MaxIter))
    return LIP;
  if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
    // Number of iterations expressed as UMIN isn't always great for expressing
    // the value on the last iteration. If the straightforward approach didn't
    // work, try the following trick: if a predicate is invariant for X, it is
    // also invariant for umin(X, ...). So try to find something that works
    // among subexpressions of MaxIter expressed as umin.
    for (auto *Op : UMin->operands())
      if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
              Pred, LHS, RHS, L, CtxI, Op))
        return LIP;
  return std::nullopt;
}

std::optional<ScalarEvolution::LoopInvariantPredicate>
ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
    const Instruction *CtxI, const SCEV *MaxIter) {
  // Try to prove the following set of facts:
  // - The predicate is monotonic in the iteration space.
  // - If the check does not fail on the 1st iteration:
  //   - No overflow will happen during first MaxIter iterations;
  //   - It will not fail on the MaxIter'th iteration.
  // If the check does fail on the 1st iteration, we leave the loop and no
  // other checks matter.

  // If there is a loop-invariant, force it into the RHS, otherwise bail out.
  if (!isLoopInvariant(RHS, L)) {
    if (!isLoopInvariant(LHS, L))
      return std::nullopt;

    std::swap(LHS, RHS);
    Pred = ICmpInst::getSwappedPredicate(Pred);
  }

  auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!AR || AR->getLoop() != L)
    return std::nullopt;

  // The predicate must be relational (i.e. <, <=, >=, >).
  if (!ICmpInst::isRelational(Pred))
    return std::nullopt;

  // TODO: Support steps other than +/- 1.
  const SCEV *Step = AR->getStepRecurrence(*this);
  auto *One = getOne(Step->getType());
  auto *MinusOne = getNegativeSCEV(One);
  if (Step != One && Step != MinusOne)
    return std::nullopt;

  // Type mismatch here means that MaxIter is potentially larger than max
  // unsigned value in start type, which means we cannot prove no wrap for the
  // indvar.
  if (AR->getType() != MaxIter->getType())
    return std::nullopt;

  // Value of IV on suggested last iteration.
  const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
  // Does it still meet the requirement?
  if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
    return std::nullopt;
  // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
  // not exceed max unsigned value of this type), this effectively proves
  // that there is no wrap during the iteration. To prove that there is no
  // signed/unsigned wrap, we need to check that
  // Start <= Last for step = 1 or Start >= Last for step = -1.
  ICmpInst::Predicate NoOverflowPred =
      CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
  if (Step == MinusOne)
    NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred);
  const SCEV *Start = AR->getStart();
  if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
    return std::nullopt;

  // Everything is fine.
  return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
}

bool ScalarEvolution::isKnownPredicateViaConstantRanges(
    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
  if (HasSameValue(LHS, RHS))
    return ICmpInst::isTrueWhenEqual(Pred);

  // This code is split out from isKnownPredicate because it is called from
  // within isLoopEntryGuardedByCond.

  auto CheckRanges = [&](const ConstantRange &RangeLHS,
                         const ConstantRange &RangeRHS) {
    return RangeLHS.icmp(Pred, RangeRHS);
  };

  // The check at the top of the function catches the case where the values are
  // known to be equal.
  if (Pred == CmpInst::ICMP_EQ)
    return false;

  if (Pred == CmpInst::ICMP_NE) {
    auto SL = getSignedRange(LHS);
    auto SR = getSignedRange(RHS);
    if (CheckRanges(SL, SR))
      return true;
    auto UL = getUnsignedRange(LHS);
    auto UR = getUnsignedRange(RHS);
    if (CheckRanges(UL, UR))
      return true;
    auto *Diff = getMinusSCEV(LHS, RHS);
    return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
  }

  if (CmpInst::isSigned(Pred)) {
    auto SL = getSignedRange(LHS);
    auto SR = getSignedRange(RHS);
    return CheckRanges(SL, SR);
  }

  auto UL = getUnsignedRange(LHS);
  auto UR = getUnsignedRange(RHS);
  return CheckRanges(UL, UR);
}

bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
                                                    const SCEV *LHS,
                                                    const SCEV *RHS) {
  // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
  // C1 and C2 are constant integers. If either X or Y are not add expressions,
  // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
  // OutC1 and OutC2.
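  // For example, X = (A + 1)<nsw> and Y = (A + 3)<nsw> match with C1 = 1 and
  // C2 = 3; since 1 s< 3, the ICMP_SLT case below concludes X s< Y.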
  auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
                                      APInt &OutC1, APInt &OutC2,
                                      SCEV::NoWrapFlags ExpectedFlags) {
    const SCEV *XNonConstOp, *XConstOp;
    const SCEV *YNonConstOp, *YConstOp;
    SCEV::NoWrapFlags XFlagsPresent;
    SCEV::NoWrapFlags YFlagsPresent;

    if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
      XConstOp = getZero(X->getType());
      XNonConstOp = X;
      XFlagsPresent = ExpectedFlags;
    }
    if (!isa<SCEVConstant>(XConstOp) ||
        (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
      return false;

    if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
      YConstOp = getZero(Y->getType());
      YNonConstOp = Y;
      YFlagsPresent = ExpectedFlags;
    }

    if (!isa<SCEVConstant>(YConstOp) ||
        (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
      return false;

    if (YNonConstOp != XNonConstOp)
      return false;

    OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
    OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();

    return true;
  };

  APInt C1;
  APInt C2;

  switch (Pred) {
  default:
    break;

  case ICmpInst::ICMP_SGE:
    std::swap(LHS, RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_SLE:
    // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
    if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
      return true;

    break;

  case ICmpInst::ICMP_SGT:
    std::swap(LHS, RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_SLT:
    // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
    if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
      return true;

    break;

  case ICmpInst::ICMP_UGE:
    std::swap(LHS, RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_ULE:
    // (X + C1)<nuw> u<= (X + C2)<nuw> if C1 u<= C2.
    if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ule(C2))
      return true;

    break;

  case ICmpInst::ICMP_UGT:
    std::swap(LHS, RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_ULT:
    // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
    if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ult(C2))
      return true;
    break;
  }

  return false;
}

bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
                                                   const SCEV *LHS,
                                                   const SCEV *RHS) {
  if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
    return false;

  // Allowing an arbitrary number of activations of
  // isKnownPredicateViaSplitting on the stack can result in exponential time
  // complexity.
  SaveAndRestore Restore(ProvingSplitPredicate, true);

  // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
  //
  // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
  // isKnownPredicate. isKnownPredicate is more powerful, but also more
  // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
  // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
  // use isKnownPredicate later if needed.
  return isKnownNonNegative(RHS) &&
         isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
         isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
}

bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB,
                                        ICmpInst::Predicate Pred,
                                        const SCEV *LHS, const SCEV *RHS) {
  // No need to even try if we know the module has no guards.
  if (!HasGuards)
    return false;

  return any_of(*BB, [&](const Instruction &I) {
    using namespace llvm::PatternMatch;

    Value *Condition;
    return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
                         m_Value(Condition))) &&
           isImpliedCond(Pred, LHS, RHS, Condition, false);
  });
}

/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
/// protected by a conditional between LHS and RHS. This is used to
/// eliminate casts.
11216 ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11217 ICmpInst::Predicate Pred,
11218 const SCEV *LHS, const SCEV *RHS) {
11219 // Interpret a null as meaning no loop, where there is obviously no guard
11220 // (interprocedural conditions notwithstanding). Do not bother about
11221 // unreachable loops.
11222 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11226 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11227 "This cannot be done on broken IR!");
11230 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11233 BasicBlock *Latch = L->getLoopLatch();
11237 BranchInst *LoopContinuePredicate =
11238 dyn_cast<BranchInst>(Latch->getTerminator());
11239 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11240 isImpliedCond(Pred, LHS, RHS,
11241 LoopContinuePredicate->getCondition(),
11242 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11245 // We don't want more than one activation of the following loops on the stack
11246 // -- that can lead to O(n!) time complexity.
11247 if (WalkingBEDominatingConds)
11250 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11252 // See if we can exploit a trip count to prove the predicate.
11253 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11254 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11255 if (LatchBECount != getCouldNotCompute()) {
11256 // We know that Latch branches back to the loop header exactly
11257 // LatchBECount times. This means the backdege condition at Latch is
11258 // equivalent to "{0,+,1} u< LatchBECount".
    Type *Ty = LatchBECount->getType();
    auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
    const SCEV *LoopCounter =
        getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
    if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
                      LatchBECount))
      return true;
  }

  // Check conditions due to any @llvm.assume intrinsics.
  for (auto &AssumeVH : AC.assumptions()) {
    if (!AssumeVH)
      continue;
    auto *CI = cast<CallInst>(AssumeVH);
    if (!DT.dominates(CI, Latch->getTerminator()))
      continue;

    if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
      return true;
  }

  if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
    return true;

  for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
       DTN != HeaderDTN; DTN = DTN->getIDom()) {
    assert(DTN && "should reach the loop header before reaching the root!");

    BasicBlock *BB = DTN->getBlock();
    if (isImpliedViaGuard(BB, Pred, LHS, RHS))
      return true;

    BasicBlock *PBB = BB->getSinglePredecessor();
    if (!PBB)
      continue;

    BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
    if (!ContinuePredicate || !ContinuePredicate->isConditional())
      continue;

    Value *Condition = ContinuePredicate->getCondition();

    // If we have an edge `E` within the loop body that dominates the only
    // latch, the condition guarding `E` also guards the backedge.  This
    // reasoning works only for loops with a single latch.

    BasicBlockEdge DominatingEdge(PBB, BB);
    if (DominatingEdge.isSingleEdge()) {
      // We're constructively (and conservatively) enumerating edges within the
      // loop body that dominate the latch.  The dominator tree better agree
      // with us on this:
      assert(DT.dominates(DominatingEdge, Latch) && "should be!");

      if (isImpliedCond(Pred, LHS, RHS, Condition,
                        BB != ContinuePredicate->getSuccessor(0)))
        return true;
    }
  }

  return false;
}
bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
                                                     ICmpInst::Predicate Pred,
                                                     const SCEV *LHS,
                                                     const SCEV *RHS) {
  // Do not bother proving facts for unreachable code.
  if (!DT.isReachableFromEntry(BB))
    return true;
  if (VerifyIR)
    assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
           "This cannot be done on broken IR!");

  // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
  // the facts (a >= b && a != b) separately. A typical situation is when the
  // non-strict comparison is known from ranges and non-equality is known from
  // dominating predicates. If we are proving strict comparison, we always try
  // to prove non-equality and non-strict comparison separately.
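  // E.g. "a >u b" follows from "a >=u b" (say, known from ranges) together
  // with "a != b" (say, known from a dominating condition).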
  auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
  const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
  bool ProvedNonStrictComparison = false;
  bool ProvedNonEquality = false;

  auto SplitAndProve =
      [&](std::function<bool(ICmpInst::Predicate)> Fn) -> bool {
    if (!ProvedNonStrictComparison)
      ProvedNonStrictComparison = Fn(NonStrictPredicate);
    if (!ProvedNonEquality)
      ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
    if (ProvedNonStrictComparison && ProvedNonEquality)
      return true;
    return false;
  };

  if (ProvingStrictComparison) {
    auto ProofFn = [&](ICmpInst::Predicate P) {
      return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
    };
    if (SplitAndProve(ProofFn))
      return true;
  }

  // Try to prove (Pred, LHS, RHS) using isImpliedCond.
  auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
    const Instruction *CtxI = &BB->front();
    if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
      return true;
    if (ProvingStrictComparison) {
      auto ProofFn = [&](ICmpInst::Predicate P) {
        return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
      };
      if (SplitAndProve(ProofFn))
        return true;
    }
    return false;
  };

  // Starting at the block's predecessor, climb up the predecessor chain, as
  // long as we can find predecessors that have a unique successor leading to
  // the original block.
  const Loop *ContainingLoop = LI.getLoopFor(BB);
  const BasicBlock *PredBB;
  if (ContainingLoop && ContainingLoop->getHeader() == BB)
    PredBB = ContainingLoop->getLoopPredecessor();
  else
    PredBB = BB->getSinglePredecessor();
  for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
       Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
    const BranchInst *BlockEntryPredicate =
        dyn_cast<BranchInst>(Pair.first->getTerminator());
    if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
      continue;

    if (ProveViaCond(BlockEntryPredicate->getCondition(),
                     BlockEntryPredicate->getSuccessor(0) != Pair.second))
      return true;
  }

  // Check conditions due to any @llvm.assume intrinsics.
  for (auto &AssumeVH : AC.assumptions()) {
    if (!AssumeVH)
      continue;
    auto *CI = cast<CallInst>(AssumeVH);
    if (!DT.dominates(CI, BB))
      continue;

    if (ProveViaCond(CI->getArgOperand(0), false))
      return true;
  }

  // Check conditions due to any @llvm.experimental.guard intrinsics.
  auto *GuardDecl = F.getParent()->getFunction(
      Intrinsic::getName(Intrinsic::experimental_guard));
  if (GuardDecl)
    for (const auto *GU : GuardDecl->users())
      if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
        if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
          if (ProveViaCond(Guard->getArgOperand(0), false))
            return true;
  return false;
}
bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
                                               ICmpInst::Predicate Pred,
                                               const SCEV *LHS,
                                               const SCEV *RHS) {
  // Interpret a null as meaning no loop, where there is obviously no guard
  // (interprocedural conditions notwithstanding).
  if (!L)
    return false;

  // Both LHS and RHS must be available at loop entry.
  assert(isAvailableAtLoopEntry(LHS, L) &&
         "LHS is not available at Loop Entry");
  assert(isAvailableAtLoopEntry(RHS, L) &&
         "RHS is not available at Loop Entry");

  if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
    return true;

  return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
}
bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
                                    const SCEV *RHS,
                                    const Value *FoundCondValue, bool Inverse,
                                    const Instruction *CtxI) {
  // A false condition implies anything. Do not bother analyzing it further.
  if (FoundCondValue ==
      ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
    return true;

  if (!PendingLoopPredicates.insert(FoundCondValue).second)
    return false;

  auto ClearOnExit =
      make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });

  // Recursively handle And and Or conditions.
  const Value *Op0, *Op1;
  if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
    if (!Inverse)
      return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
             isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
  } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
    if (Inverse)
      return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
             isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
  }

  const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
  if (!ICI) return false;

  // We have now found a conditional branch that dominates the loop or controls
  // the loop latch. Check to see if it is the comparison we are looking for.
  ICmpInst::Predicate FoundPred;
  if (Inverse)
    FoundPred = ICI->getInversePredicate();
  else
    FoundPred = ICI->getPredicate();

  const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
  const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));

  return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
}
bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
                                    const SCEV *RHS,
                                    ICmpInst::Predicate FoundPred,
                                    const SCEV *FoundLHS, const SCEV *FoundRHS,
                                    const Instruction *CtxI) {
  // Balance the types.
  if (getTypeSizeInBits(LHS->getType()) <
      getTypeSizeInBits(FoundLHS->getType())) {
    // For unsigned and equality predicates, try to prove that both found
    // operands fit into narrow unsigned range. If so, try to prove facts in
    // narrow types.
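    // E.g. if LHS/RHS are i32 while FoundLHS/FoundRHS are i64 both known to
    // be u<= 0xFFFFFFFF, the found fact can be truncated to i32 losslessly
    // and applied in the narrow type.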
    if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
        !FoundRHS->getType()->isPointerTy()) {
      auto *NarrowType = LHS->getType();
      auto *WideType = FoundLHS->getType();
      auto BitWidth = getTypeSizeInBits(NarrowType);
      const SCEV *MaxValue = getZeroExtendExpr(
          getConstant(APInt::getMaxValue(BitWidth)), WideType);
      if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
                                          MaxValue) &&
          isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
                                          MaxValue)) {
        const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
        const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
        if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
                                       TruncFoundRHS, CtxI))
          return true;
      }
    }

    if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
      return false;
    if (CmpInst::isSigned(Pred)) {
      LHS = getSignExtendExpr(LHS, FoundLHS->getType());
      RHS = getSignExtendExpr(RHS, FoundLHS->getType());
    } else {
      LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
      RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
    }
  } else if (getTypeSizeInBits(LHS->getType()) >
             getTypeSizeInBits(FoundLHS->getType())) {
    if (FoundLHS->getType()->isPointerTy() ||
        FoundRHS->getType()->isPointerTy())
      return false;
    if (CmpInst::isSigned(FoundPred)) {
      FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
      FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
    } else {
      FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
      FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
    }
  }
  return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
                                    FoundRHS, CtxI);
}
bool ScalarEvolution::isImpliedCondBalancedTypes(
    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
    ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS,
    const Instruction *CtxI) {
  assert(getTypeSizeInBits(LHS->getType()) ==
             getTypeSizeInBits(FoundLHS->getType()) &&
         "Types should be balanced!");
  // Canonicalize the query to match the way instcombine will have
  // canonicalized the comparison.
  if (SimplifyICmpOperands(Pred, LHS, RHS))
    if (LHS == RHS)
      return CmpInst::isTrueWhenEqual(Pred);
  if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
    if (FoundLHS == FoundRHS)
      return CmpInst::isFalseWhenEqual(FoundPred);

  // Check to see if we can make the LHS or RHS match.
  if (LHS == FoundRHS || RHS == FoundLHS) {
    if (isa<SCEVConstant>(RHS)) {
      std::swap(FoundLHS, FoundRHS);
      FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
    } else {
      std::swap(LHS, RHS);
      Pred = ICmpInst::getSwappedPredicate(Pred);
    }
  }

  // Check whether the found predicate is the same as the desired predicate.
  if (FoundPred == Pred)
    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);

  // Check whether swapping the found predicate makes it the same as the
  // desired predicate.
  if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
    // We can write the implication
    // 0.  LHS Pred      RHS  <-   FoundLHS SwapPred  FoundRHS
    // using one of the following ways:
    // 1.  LHS Pred      RHS  <-   FoundRHS Pred      FoundLHS
    // 2.  RHS SwapPred  LHS  <-   FoundLHS SwapPred  FoundRHS
    // 3.  LHS Pred      RHS  <-  ~FoundLHS Pred     ~FoundRHS
    // 4. ~LHS SwapPred ~RHS  <-   FoundLHS SwapPred  FoundRHS
    // Forms 1. and 2. require swapping the operands of one condition. Don't
    // do this if it would break canonical constant/addrec ordering.
    if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
      return isImpliedCondOperands(FoundPred, RHS, LHS, FoundLHS, FoundRHS,
                                   CtxI);
    if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
      return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, CtxI);

    // There's no clear preference between forms 3. and 4., try both.  Avoid
    // forming getNotSCEV of pointer values as the resulting subtract is
    // not invertible.
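    // E.g. for form 3.: since ~X = -1 - X is strictly order-reversing (both
    // signed and unsigned), "FoundLHS SwapPred FoundRHS" is equivalent to
    // "~FoundLHS Pred ~FoundRHS".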
    if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
        isImpliedCondOperands(FoundPred, getNotSCEV(LHS), getNotSCEV(RHS),
                              FoundLHS, FoundRHS, CtxI))
      return true;

    if (!FoundLHS->getType()->isPointerTy() &&
        !FoundRHS->getType()->isPointerTy() &&
        isImpliedCondOperands(Pred, LHS, RHS, getNotSCEV(FoundLHS),
                              getNotSCEV(FoundRHS), CtxI))
      return true;

    return false;
  }
  auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
                                   CmpInst::Predicate P2) {
    assert(P1 != P2 && "Handled earlier!");
    return CmpInst::isRelational(P2) &&
           P1 == CmpInst::getFlippedSignednessPredicate(P2);
  };
  if (IsSignFlippedPredicate(Pred, FoundPred)) {
    // Unsigned comparison is the same as signed comparison when both the
    // operands are non-negative or negative.
    if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
        (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
      return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
    // Create local copies that we can freely swap and canonicalize our
    // conditions to "le/lt".
    ICmpInst::Predicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
    const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
               *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
    if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
      CanonicalPred = ICmpInst::getSwappedPredicate(CanonicalPred);
      CanonicalFoundPred = ICmpInst::getSwappedPredicate(CanonicalFoundPred);
      std::swap(CanonicalLHS, CanonicalRHS);
      std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
    }
    assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
           "Must be!");
    assert((ICmpInst::isLT(CanonicalFoundPred) ||
            ICmpInst::isLE(CanonicalFoundPred)) &&
           "Must be!");
    if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
      // Use implication:
      // x <u y && y >=s 0 --> x <s y.
      // If we can prove the left part, the right part is also proven.
      return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
                                   CanonicalRHS, CanonicalFoundLHS,
                                   CanonicalFoundRHS);
    if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
      // Use implication:
      // x <s y && y <s 0 --> x <u y.
      // If we can prove the left part, the right part is also proven.
      return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
                                   CanonicalRHS, CanonicalFoundLHS,
                                   CanonicalFoundRHS);
  }
  // Check if we can make progress by sharpening ranges.
  if (FoundPred == ICmpInst::ICMP_NE &&
      (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {

    const SCEVConstant *C = nullptr;
    const SCEV *V = nullptr;

    if (isa<SCEVConstant>(FoundLHS)) {
      C = cast<SCEVConstant>(FoundLHS);
      V = FoundRHS;
    } else {
      C = cast<SCEVConstant>(FoundRHS);
      V = FoundLHS;
    }

    // The guarding predicate tells us that C != V. If the known range
    // of V is [C, t), we can sharpen the range to [C + 1, t).  The
    // range we consider has to correspond to same signedness as the
    // predicate we're interested in folding.

    APInt Min = ICmpInst::isSigned(Pred) ?
        getSignedRangeMin(V) : getUnsignedRangeMin(V);

    if (Min == C->getAPInt()) {
      // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
      // This is true even if (Min + 1) wraps around -- in case of
      // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).

      APInt SharperMin = Min + 1;

      switch (Pred) {
      case ICmpInst::ICMP_SGE:
      case ICmpInst::ICMP_UGE:
        // We know V `Pred` SharperMin.  If this implies LHS `Pred`
        // RHS, we're done.
        if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
                                  CtxI))
          return true;
        [[fallthrough]];

      case ICmpInst::ICMP_SGT:
      case ICmpInst::ICMP_UGT:
        // We know from the range information that (V `Pred` Min ||
        // V == Min).  We know from the guarding condition that !(V
        // == Min).  This gives us
        //
        //       V `Pred` Min || V == Min && !(V == Min)
        //   =>  V `Pred` Min
        //
        // If V `Pred` Min implies LHS `Pred` RHS, we're done.

        if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
          return true;
        break;

      // `LHS < RHS` and `LHS <= RHS` are handled in the same way as
      // `RHS > LHS` and `RHS >= LHS` respectively.
      case ICmpInst::ICMP_SLE:
      case ICmpInst::ICMP_ULE:
        if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
                                  LHS, V, getConstant(SharperMin), CtxI))
          return true;
        [[fallthrough]];

      case ICmpInst::ICMP_SLT:
      case ICmpInst::ICMP_ULT:
        if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
                                  LHS, V, getConstant(Min), CtxI))
          return true;
        break;

      default:
        // No change
        break;
      }
    }
  }
  // Check whether the actual condition is beyond sufficient.
  if (FoundPred == ICmpInst::ICMP_EQ)
    if (ICmpInst::isTrueWhenEqual(Pred))
      if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
        return true;
  if (Pred == ICmpInst::ICMP_NE)
    if (!ICmpInst::isTrueWhenEqual(FoundPred))
      if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
        return true;

  // Otherwise assume the worst.
  return false;
}
bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
                                     const SCEV *&L, const SCEV *&R,
                                     SCEV::NoWrapFlags &Flags) {
  const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
  if (!AE || AE->getNumOperands() != 2)
    return false;

  L = AE->getOperand(0);
  R = AE->getOperand(1);
  Flags = AE->getNoWrapFlags();
  return true;
}
std::optional<APInt>
ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
  // We avoid subtracting expressions here because this function is usually
  // fairly deep in the call stack (i.e. is called many times).
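  // E.g. for More = (%x + 13) and Less = (%x + 5) the result is 8, computed
  // without materializing the subtract of the two expressions.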
  // X - X = 0.
  if (More == Less)
    return APInt(getTypeSizeInBits(More->getType()), 0);

  if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
    const auto *LAR = cast<SCEVAddRecExpr>(Less);
    const auto *MAR = cast<SCEVAddRecExpr>(More);

    if (LAR->getLoop() != MAR->getLoop())
      return std::nullopt;

    // We look at affine expressions only; not for correctness but to keep
    // getStepRecurrence cheap.
    if (!LAR->isAffine() || !MAR->isAffine())
      return std::nullopt;

    if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
      return std::nullopt;

    Less = LAR->getStart();
    More = MAR->getStart();

    // fall through
  }

  if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
    const auto &M = cast<SCEVConstant>(More)->getAPInt();
    const auto &L = cast<SCEVConstant>(Less)->getAPInt();
    return M - L;
  }

  SCEV::NoWrapFlags Flags;
  const SCEV *LLess = nullptr, *RLess = nullptr;
  const SCEV *LMore = nullptr, *RMore = nullptr;
  const SCEVConstant *C1 = nullptr, *C2 = nullptr;
  // Compare (X + C1) vs X.
  if (splitBinaryAdd(Less, LLess, RLess, Flags))
    if ((C1 = dyn_cast<SCEVConstant>(LLess)))
      if (RLess == More)
        return -(C1->getAPInt());

  // Compare X vs (X + C2).
  if (splitBinaryAdd(More, LMore, RMore, Flags))
    if ((C2 = dyn_cast<SCEVConstant>(LMore)))
      if (RMore == Less)
        return C2->getAPInt();

  // Compare (X + C1) vs (X + C2).
  if (C1 && C2 && RLess == RMore)
    return C2->getAPInt() - C1->getAPInt();

  return std::nullopt;
}
bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
    const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
  // Try to recognize the following pattern:
  //
  //   FoundRHS = ...
  //   ...
  //   loop:
  //     FoundLHS = {Start,+,W}
  //   context_bb: // Basic block from the same loop
  //     known(Pred, FoundLHS, FoundRHS)
  //
  // If some predicate is known in the context of a loop, it is also known on
  // each iteration of this loop, including the first iteration. Therefore, in
  // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
  // prove the original pred using this fact.
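  // E.g. if FoundLHS = {Start,+,1}<L> and "FoundLHS s< N" is known in a block
  // of L that runs on the first iteration, then in particular "Start s< N".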
  if (!CtxI)
    return false;
  const BasicBlock *ContextBB = CtxI->getParent();
  // Make sure AR varies in the context block.
  if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
    const Loop *L = AR->getLoop();
    // Make sure that context belongs to the loop and executes on 1st iteration
    // (if it ever executes at all).
    if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
      return false;
    if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
      return false;
    return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
  }

  if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
    const Loop *L = AR->getLoop();
    // Make sure that context belongs to the loop and executes on 1st iteration
    // (if it ever executes at all).
    if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
      return false;
    if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
      return false;
    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
  }

  return false;
}
bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
    const SCEV *FoundLHS, const SCEV *FoundRHS) {
  if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
    return false;

  const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!AddRecLHS)
    return false;

  const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
  if (!AddRecFoundLHS)
    return false;

  // We'd like to let SCEV reason about control dependencies, so we constrain
  // both the inequalities to be about add recurrences on the same loop.  This
  // way we can use isLoopEntryGuardedByCond later.

  const Loop *L = AddRecFoundLHS->getLoop();
  if (L != AddRecLHS->getLoop())
    return false;

  //  FoundLHS u< FoundRHS u< -C =>  (FoundLHS + C) u< (FoundRHS + C) ... (1)
  //
  //  FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
  //                                                                  ... (2)
  //
  // Informal proof for (2), assuming (1) [*]:
  //
  // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
  //
  // Then
  //
  //   FoundLHS s< FoundRHS s< INT_MIN - C
  // <=>  (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C   [ using (3) ]
  // <=>  (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
  // <=>  (FoundLHS + INT_MIN + C + INT_MIN) s<
  //                        (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
  // <=>  FoundLHS + C s< FoundRHS + C
  //
  // [*]: (1) can be proved by ruling out overflow.
  //
  // [**]: This can be proved by analyzing all the four possibilities:
  //    (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
  //    (A s>= 0, B s>= 0).
  //
  // Note:
  // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
  // will not sign underflow.  For instance, say FoundLHS = (i8 -128), FoundRHS
  // = (i8 -127) and C = (i8 -100).  Then INT_MIN - C = (i8 -28), and FoundRHS
  // s< (INT_MIN - C).  Lack of sign overflow / underflow in "FoundRHS + C" is
  // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
  // C)".
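  //
  // A concrete instance of (1) in i8: FoundLHS = 10, FoundRHS = 20, C = 30.
  // Then -C = 226 and 20 u< 226, so (1) gives (10 + 30) u< (20 + 30), i.e.
  // 40 u< 50.
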
  std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
  std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
  if (!LDiff || !RDiff || *LDiff != *RDiff)
    return false;

  if (LDiff->isMinValue())
    return true;

  APInt FoundRHSLimit;

  if (Pred == CmpInst::ICMP_ULT) {
    FoundRHSLimit = -(*RDiff);
  } else {
    assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
    FoundRHSLimit =
        APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
  }

  // Try to prove (1) or (2), as needed.
  return isAvailableAtLoopEntry(FoundRHS, L) &&
         isLoopEntryGuardedByCond(L, Pred, FoundRHS,
                                  getConstant(FoundRHSLimit));
}
bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
                                        const SCEV *LHS, const SCEV *RHS,
                                        const SCEV *FoundLHS,
                                        const SCEV *FoundRHS, unsigned Depth) {
  const PHINode *LPhi = nullptr, *RPhi = nullptr;

  auto ClearOnExit = make_scope_exit([&]() {
    if (LPhi) {
      bool Erased = PendingMerges.erase(LPhi);
      assert(Erased && "Failed to erase LPhi!");
      (void)Erased;
    }
    if (RPhi) {
      bool Erased = PendingMerges.erase(RPhi);
      assert(Erased && "Failed to erase RPhi!");
      (void)Erased;
    }
  });

  // Find the respective Phis and check that they are not already pending.
  if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
    if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
      if (!PendingMerges.insert(Phi).second)
        return false;
      LPhi = Phi;
    }
  if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
    if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
      // If we detect a loop of Phi nodes being processed by this method, for
      // example:
      //
      //   %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
      //   %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
      //
      // we don't want to deal with a case that complex, so return conservative
      // answer false.
      if (!PendingMerges.insert(Phi).second)
        return false;
      RPhi = Phi;
    }

  // If none of LHS, RHS is a Phi, nothing to do here.
  if (!LPhi && !RPhi)
    return false;

  // If there is a SCEVUnknown Phi we are interested in, make it left.
  if (!LPhi) {
    std::swap(LHS, RHS);
    std::swap(FoundLHS, FoundRHS);
    std::swap(LPhi, RPhi);
    Pred = ICmpInst::getSwappedPredicate(Pred);
  }

  assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
  const BasicBlock *LBB = LPhi->getParent();
  const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);

  auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
    return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
           isImpliedCondOperandsViaRanges(Pred, S1, S2, FoundLHS, FoundRHS) ||
           isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
  };

  if (RPhi && RPhi->getParent() == LBB) {
    // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
    // If we compare two Phis from the same block, and for each entry block
    // the predicate is true for incoming values from this block, then the
    // predicate is also true for the Phis.
    for (const BasicBlock *IncBB : predecessors(LBB)) {
      const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
      const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
      if (!ProvedEasily(L, R))
        return false;
    }
  } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
    // Case two: RHS is also a Phi from the same basic block, and it is an
    // AddRec. It means that there is a loop which has both AddRec and Unknown
    // PHIs; for it, we can compare incoming values of AddRec from above the
    // loop and latch with their respective incoming values of LPhi.
    // TODO: Generalize to handle loops with many inputs in a header.
    if (LPhi->getNumIncomingValues() != 2) return false;

    auto *RLoop = RAR->getLoop();
    auto *Predecessor = RLoop->getLoopPredecessor();
    assert(Predecessor && "Loop with AddRec with no predecessor?");
    const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
    if (!ProvedEasily(L1, RAR->getStart()))
      return false;
    auto *Latch = RLoop->getLoopLatch();
    assert(Latch && "Loop with AddRec with no latch?");
    const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
    if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
      return false;
  } else {
    // In all other cases go over inputs of LHS and compare each of them to
    // RHS, the predicate is true for (LHS, RHS) if it is true for all such
    // pairs. At this point RHS is either a non-Phi, or it is a Phi from some
    // block different from LBB.
    for (const BasicBlock *IncBB : predecessors(LBB)) {
      // Check that RHS is available in this block.
      if (!dominates(RHS, IncBB))
        return false;
      const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
      // Make sure L does not refer to a value from a potentially previous
      // iteration of a loop.
      if (!properlyDominates(L, LBB))
        return false;
      if (!ProvedEasily(L, RHS))
        return false;
    }
  }
  return true;
}
bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred,
                                                    const SCEV *LHS,
                                                    const SCEV *RHS,
                                                    const SCEV *FoundLHS,
                                                    const SCEV *FoundRHS) {
  // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue).  First, make
  // sure that we are dealing with the same LHS.
  if (RHS == FoundRHS) {
    std::swap(LHS, RHS);
    std::swap(FoundLHS, FoundRHS);
    Pred = ICmpInst::getSwappedPredicate(Pred);
  }
  if (LHS != FoundLHS)
    return false;

  auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
  if (!SUFoundRHS)
    return false;

  Value *Shiftee, *ShiftValue;

  using namespace PatternMatch;
  if (match(SUFoundRHS->getValue(),
            m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
    auto *ShifteeS = getSCEV(Shiftee);
    // Prove one of the following:
    // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
    // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
    // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
    //   ---> LHS <s RHS
    // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
    //   ---> LHS <=s RHS
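    // E.g. with FoundRHS = (%n lshr 2): from "LHS u< (%n lshr 2)" and
    // "%n u<= RHS" we may conclude "LHS u< RHS", because a logical shift
    // right never increases an unsigned value.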
    if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
      return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
    if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
      if (isKnownNonNegative(ShifteeS))
        return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
  }

  return false;
}
bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
                                            const SCEV *LHS, const SCEV *RHS,
                                            const SCEV *FoundLHS,
                                            const SCEV *FoundRHS,
                                            const Instruction *CtxI) {
  if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
    return true;

  if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
    return true;

  if (isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS))
    return true;

  if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
                                          CtxI))
    return true;

  return isImpliedCondOperandsHelper(Pred, LHS, RHS,
                                     FoundLHS, FoundRHS);
}
/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
template <typename MinMaxExprType>
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
                                 const SCEV *Candidate) {
  const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
  if (!MinMaxExpr)
    return false;

  return is_contained(MinMaxExpr->operands(), Candidate);
}
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
                                           ICmpInst::Predicate Pred,
                                           const SCEV *LHS, const SCEV *RHS) {
  // If both sides are affine addrecs for the same loop, with equal
  // steps, and we know the recurrences don't wrap, then we only
  // need to check the predicate on the starting values.

  if (!ICmpInst::isRelational(Pred))
    return false;

  const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!LAR)
    return false;
  const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
  if (!RAR)
    return false;
  if (LAR->getLoop() != RAR->getLoop())
    return false;
  if (!LAR->isAffine() || !RAR->isAffine())
    return false;

  if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
    return false;

  SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
                         SCEV::FlagNSW : SCEV::FlagNUW;
  if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
    return false;

  return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
}
/// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
/// expression?
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
                                        ICmpInst::Predicate Pred,
                                        const SCEV *LHS, const SCEV *RHS) {
  switch (Pred) {
  default:
    return false;

  case ICmpInst::ICMP_SGE:
    std::swap(LHS, RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_SLE:
    return
        // min(A, ...) <= A
        IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
        // A <= max(A, ...)
        IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);

  case ICmpInst::ICMP_UGE:
    std::swap(LHS, RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_ULE:
    return
        // min(A, ...) <= A
        // FIXME: what about umin_seq?
        IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
        // A <= max(A, ...)
        IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
  }

  llvm_unreachable("covered switch fell through?!");
}
bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
                                             const SCEV *LHS, const SCEV *RHS,
                                             const SCEV *FoundLHS,
                                             const SCEV *FoundRHS,
                                             unsigned Depth) {
  assert(getTypeSizeInBits(LHS->getType()) ==
             getTypeSizeInBits(RHS->getType()) &&
         "LHS and RHS have different sizes?");
  assert(getTypeSizeInBits(FoundLHS->getType()) ==
             getTypeSizeInBits(FoundRHS->getType()) &&
         "FoundLHS and FoundRHS have different sizes?");
  // We want to avoid hurting the compile time with analysis of too big trees.
  if (Depth > MaxSCEVOperationsImplicationDepth)
    return false;

  // We only want to work with GT comparison so far.
  if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
    Pred = CmpInst::getSwappedPredicate(Pred);
    std::swap(LHS, RHS);
    std::swap(FoundLHS, FoundRHS);
  }

  // For unsigned, try to reduce it to corresponding signed comparison.
  if (Pred == ICmpInst::ICMP_UGT)
    // We can replace unsigned predicate with its signed counterpart if all
    // involved values are non-negative.
    // TODO: We could have better support for unsigned.
    if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
      // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
      // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
      // use this fact to prove that LHS and RHS are non-negative.
      const SCEV *MinusOne = getMinusOne(LHS->getType());
      if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
                                FoundRHS) &&
          isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
                                FoundRHS))
        Pred = ICmpInst::ICMP_SGT;
    }

  if (Pred != ICmpInst::ICMP_SGT)
    return false;

  auto GetOpFromSExt = [&](const SCEV *S) {
    if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
      return Ext->getOperand();
    // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
    // the constant in some cases.
    return S;
  };

  // Acquire values from extensions.
  auto *OrigLHS = LHS;
  auto *OrigFoundLHS = FoundLHS;
  LHS = GetOpFromSExt(LHS);
  FoundLHS = GetOpFromSExt(FoundLHS);

  // Check whether the SGT predicate can be proved trivially or using the found
  // context.
  auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
    return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
           isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
                                  FoundRHS, Depth + 1);
  };
  if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
    // We want to avoid creation of any new non-constant SCEV. Since we are
    // going to compare the operands to RHS, we should be certain that we don't
    // need any size extensions for this. So let's decline all cases when the
    // sizes of types of LHS and RHS do not match.
    // TODO: Maybe try to get RHS from sext to catch more cases?
    if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
      return false;

    // Should not overflow.
    if (!LHSAddExpr->hasNoSignedWrap())
      return false;

    auto *LL = LHSAddExpr->getOperand(0);
    auto *LR = LHSAddExpr->getOperand(1);
    auto *MinusOne = getMinusOne(RHS->getType());

    // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
    auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
      return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
    };
    // Try to prove the following rule:
    // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
    // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
    if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
      return true;
  } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
    Value *LL, *LR;
    // FIXME: Once we have SDiv implemented, we can get rid of this matching.

    using namespace llvm::PatternMatch;

    if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
      // Rules for division.
      // We are going to perform some comparisons with Denominator and its
      // derivative expressions. In the general case, creating a SCEV for it
      // may lead to a complex analysis of the entire graph, and in particular
      // it can request trip count recalculation for the same loop. This would
      // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
      // this, we only want to create SCEVs that are constants in this section.
      // So we bail if Denominator is not a constant.
      if (!isa<ConstantInt>(LR))
        return false;

      auto *Denominator = cast<SCEVConstant>(getSCEV(LR));

      // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
      // then a SCEV for the numerator already exists and matches with
      // FoundLHS.
      auto *Numerator = getExistingSCEV(LL);
      if (!Numerator || Numerator->getType() != FoundLHS->getType())
        return false;

      // Make sure that the numerator matches with FoundLHS and the denominator
      // is positive.
      if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
        return false;

      auto *DTy = Denominator->getType();
      auto *FRHSTy = FoundRHS->getType();
      if (DTy->isPointerTy() != FRHSTy->isPointerTy())
        // One of the types is a pointer and the other is not. We cannot extend
        // them properly to a wider type, so let us just reject this case.
        // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
        // to avoid this check.
        return false;

      // Given that:
      // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
      auto *WTy = getWiderType(DTy, FRHSTy);
      auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
      auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);

      // Try to prove the following rule:
      // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
      // For example, given that FoundLHS > 2, it means that FoundLHS is at
      // least 3. If we divide it by Denominator < 4, we will have at least 1.
      auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
      if (isKnownNonPositive(RHS) &&
          IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
        return true;

      // Try to prove the following rule:
      // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
      // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
      // If we divide it by Denominator > 2, then:
      // 1. If FoundLHS is negative, then the result is 0.
      // 2. If FoundLHS is non-negative, then the result is non-negative.
      // Either way, the result is non-negative.
      auto *MinusOne = getMinusOne(WTy);
      auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
      if (isKnownNegative(RHS) &&
          IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
        return true;
    }
  }
  // If our expression contained SCEVUnknown Phis, and we split it down and now
  // need to prove something for them, try to prove the predicate for every
  // possible incoming value of those Phis.
  if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
    return true;

  return false;
}
static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
                                        const SCEV *LHS, const SCEV *RHS) {
  // zext x u<= sext x, sext x s<= zext x
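  // E.g. for i8 x = -1 widened to i16: zext x = 255 and sext x = -1 (0xFFFF),
  // so zext x u<= sext x (255 u<= 65535) and sext x s<= zext x (-1 s<= 255).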
  switch (Pred) {
  case ICmpInst::ICMP_SGE:
    std::swap(LHS, RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_SLE: {
    // If operand >=s 0 then ZExt == SExt.  If operand <s 0 then SExt <s ZExt.
    const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
    const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
    if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
      return true;
    return false;
  }
  case ICmpInst::ICMP_UGE:
    std::swap(LHS, RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_ULE: {
    // If operand >=s 0 then ZExt == SExt.  If operand <s 0 then ZExt <u SExt.
    const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
    const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
    if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
      return true;
    return false;
  }
  default:
    return false;
  };
}
bool
ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
                                                 const SCEV *LHS,
                                                 const SCEV *RHS) {
  return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
         isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
         IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
         IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
         isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
}
bool
ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
                                             const SCEV *LHS, const SCEV *RHS,
                                             const SCEV *FoundLHS,
                                             const SCEV *FoundRHS) {
  switch (Pred) {
  default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
  case ICmpInst::ICMP_EQ:
  case ICmpInst::ICMP_NE:
    if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
      return true;
    break;
  case ICmpInst::ICMP_SLT:
  case ICmpInst::ICMP_SLE:
    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
      return true;
    break;
  case ICmpInst::ICMP_SGT:
  case ICmpInst::ICMP_SGE:
    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
      return true;
    break;
  case ICmpInst::ICMP_ULT:
  case ICmpInst::ICMP_ULE:
    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
      return true;
    break;
  case ICmpInst::ICMP_UGT:
  case ICmpInst::ICMP_UGE:
    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
      return true;
    break;
  }

  // Maybe it can be proved via operations?
  if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
    return true;

  return false;
}
bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
                                                     const SCEV *LHS,
                                                     const SCEV *RHS,
                                                     const SCEV *FoundLHS,
                                                     const SCEV *FoundRHS) {
  if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
    // The restriction on `FoundRHS` can be lifted easily -- it exists only to
    // reduce the compile time impact of this optimization.
    return false;

  std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
  if (!Addend)
    return false;

  const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();

  // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
  // antecedent "`FoundLHS` `Pred` `FoundRHS`".
  ConstantRange FoundLHSRange =
      ConstantRange::makeExactICmpRegion(Pred, ConstFoundRHS);

  // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
  ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
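  // E.g. if the antecedent is "FoundLHS u< 8", FoundLHSRange is [0, 8); with
  // Addend = 1 this gives LHSRange = [1, 9), so a consequent like "LHS u< 10"
  // holds for every admissible value of LHS.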
  // We can also compute the range of values for `LHS` that satisfy the
  // consequent, "`LHS` `Pred` `RHS`":
  const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();

  // The antecedent implies the consequent if every value of `LHS` that
  // satisfies the antecedent also satisfies the consequent.
  return LHSRange.icmp(Pred, ConstRHS);
}
bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
                                        bool IsSigned) {
  assert(isKnownPositive(Stride) && "Positive stride expected!");

  unsigned BitWidth = getTypeSizeInBits(RHS->getType());
  const SCEV *One = getOne(Stride->getType());

  if (IsSigned) {
    APInt MaxRHS = getSignedRangeMax(RHS);
    APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
    APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));

    // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
    return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
  }

  APInt MaxRHS = getUnsignedRangeMax(RHS);
  APInt MaxValue = APInt::getMaxValue(BitWidth);
  APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));

  // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
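  // E.g. i8: if UMaxRHS = 250 and the stride can be as large as 10, then
  // 255 - 9 = 246 u< 250, so an IV just below RHS (say 249) could step past
  // 255 and wrap before the exit test fires.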
  return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
}
bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
                                        bool IsSigned) {

  unsigned BitWidth = getTypeSizeInBits(RHS->getType());
  const SCEV *One = getOne(Stride->getType());

  if (IsSigned) {
    APInt MinRHS = getSignedRangeMin(RHS);
    APInt MinValue = APInt::getSignedMinValue(BitWidth);
    APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));

    // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
    return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
  }

  APInt MinRHS = getUnsignedRangeMin(RHS);
  APInt MinValue = APInt::getMinValue(BitWidth);
  APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));

  // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
  return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
}
const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
  // umin(N, 1) + floor((N - umin(N, 1)) / D)
  // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
  // expression fixes the case of N=0.
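  // E.g. N = 7, D = 3: umin(7, 1) = 1 and (7 - 1) /u 3 = 2, giving
  // 3 = ceil(7/3); for N = 0 both terms are 0, avoiding the wrap in "N - 1".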
  const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
  const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
  return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
}
const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
                                                    const SCEV *Stride,
                                                    const SCEV *End,
                                                    unsigned BitWidth,
                                                    bool IsSigned) {
  // The logic in this function assumes we can represent a positive stride.
  // If we can't, the backedge-taken count must be zero.
  if (IsSigned && BitWidth == 1)
    return getZero(Stride->getType());

  // The code below has only been closely audited for negative strides in the
  // unsigned comparison case; it may be correct for signed comparisons, but
  // that needs to be established.
  if (IsSigned && isKnownNegative(Stride))
    return getCouldNotCompute();
  // Calculate the maximum backedge count based on the range of values
  // permitted by Start, End, and Stride.
  APInt MinStart =
      IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);

  APInt MinStride =
      IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);

  // We assume either the stride is positive, or the backedge-taken count
  // is zero. So force StrideForMaxBECount to be at least one.
  APInt One(BitWidth, 1);
  APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
                                       : APIntOps::umax(One, MinStride);

  APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
                            : APInt::getMaxValue(BitWidth);
  APInt Limit = MaxValue - (StrideForMaxBECount - 1);

  // Although End can be a MAX expression we estimate MaxEnd considering only
  // the case End = RHS of the loop termination condition. This is safe because
  // in the other case (End - Start) is zero, leading to a zero maximum backedge
  // taken count.
  APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
                          : APIntOps::umin(getUnsignedRangeMax(End), Limit);

  // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
  MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
                    : APIntOps::umax(MaxEnd, MinStart);
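  // E.g. with MinStart = 0, MaxEnd = 10, and StrideForMaxBECount = 3
  // (unsigned), the result below is ceil(10 / 3) = 4.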
  return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
                         getConstant(StrideForMaxBECount) /* Step */);
}
ScalarEvolution::ExitLimit
ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                  const Loop *L, bool IsSigned,
                                  bool ControlsOnlyExit, bool AllowPredicates) {
  SmallPtrSet<const SCEVPredicate *, 4> Predicates;

  const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
  bool PredicatedIV = false;

  auto canAssumeNoSelfWrap = [&](const SCEVAddRecExpr *AR) {
    // Can we prove this loop *must* be UB if overflow of IV occurs?
    // Reasoning goes as follows:
    // * Suppose the IV did self wrap.
    // * If Stride evenly divides the iteration space, then once wrap
    //   occurs, the loop must revisit the same values.
    // * We know that RHS is invariant, and that none of those values
    //   caused this exit to be taken previously.  Thus, this exit is
    //   dynamically dead.
    // * If this is the sole exit, then a dead exit implies the loop
    //   must be infinite if there are no abnormal exits.
    // * If the loop were infinite, then it must either not be mustprogress
    //   or have side effects. Otherwise, it must be UB.
    // * It can't (by assumption) be UB, so we have contradicted our
    //   premise and can conclude the IV did not in fact self-wrap.
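    // E.g. an i8 IV with a power-of-two stride such as 4 divides the
    // 256-value space evenly, so after a wrap it revisits exactly the values
    // already seen.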
    if (!isLoopInvariant(RHS, L))
      return false;

    auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
    if (!StrideC || !StrideC->getAPInt().isPowerOf2())
      return false;

    if (!ControlsOnlyExit || !loopHasNoAbnormalExits(L))
      return false;

    return loopIsFiniteByAssumption(L);
  };
  if (!IV) {
    if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
      const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
      if (AR && AR->getLoop() == L && AR->isAffine()) {
        auto canProveNUW = [&]() {
          if (!isLoopInvariant(RHS, L))
            return false;

          if (!isKnownNonZero(AR->getStepRecurrence(*this)))
            // We need the sequence defined by AR to strictly increase in the
            // unsigned integer domain for the logic below to hold.
            return false;

          const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
          const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
          // If RHS <=u Limit, then there must exist a value V in the sequence
          // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
          // V <=u UINT_MAX.  Thus, we must exit the loop before unsigned
          // overflow occurs.  This limit also implies that a signed comparison
          // (in the wide bitwidth) is equivalent to an unsigned comparison as
          // the high bits on both sides must be zero.
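          // E.g. with an i8 inner IV whose stride is at most 4, Limit =
          // 255 - 3 = 252; knowing RHS u<= 252 guarantees the IV must
          // strictly exceed RHS before it can wrap.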
          APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
          APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
          Limit = Limit.zext(OuterBitWidth);
          return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
        };
        auto Flags = AR->getNoWrapFlags();
        if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
          Flags = setFlags(Flags, SCEV::FlagNUW);

        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
        if (AR->hasNoUnsignedWrap()) {
          // Emulate what getZeroExtendExpr would have done during construction
          // if we'd been able to infer the fact just above at that time.
          const SCEV *Step = AR->getStepRecurrence(*this);
          Type *Ty = ZExt->getType();
          auto *S = getAddRecExpr(
              getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
              getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
          IV = dyn_cast<SCEVAddRecExpr>(S);
        }
      }
    }
  }
  if (!IV && AllowPredicates) {
    // Try to make this an AddRec using runtime tests, in the first X
    // iterations of this loop, where X is the SCEV expression found by the
    // algorithm below.
    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
    PredicatedIV = true;
  }

  // Avoid weird loops
  if (!IV || IV->getLoop() != L || !IV->isAffine())
    return getCouldNotCompute();
  // A precondition of this method is that the condition being analyzed
  // reaches an exiting branch which dominates the latch.  Given that, we can
  // assume that an increment which violates the nowrap specification and
  // produces poison must cause undefined behavior when the resulting poison
  // value is branched upon and thus we can conclude that the backedge is
  // taken no more often than would be required to produce that poison value.
  // Note that a well defined loop can exit on the iteration which violates
  // the nowrap specification if there is another exit (either explicit or
  // implicit/exceptional) which causes the loop to execute before the
  // exiting instruction we're analyzing would trigger UB.
  auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
  bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
  ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;

  const SCEV *Stride = IV->getStepRecurrence(*this);

  bool PositiveStride = isKnownPositive(Stride);
  // Avoid negative or zero stride values.
  if (!PositiveStride) {
    // We can compute the correct backedge taken count for loops with unknown
    // strides if we can prove that the loop is not an infinite loop with side
    // effects. Here's the loop structure we are trying to handle -
    //
    // i = start
    // do {
    //   A[i] = i;
    //   i += s;
    // } while (i < end);
    //
    // The backedge taken count for such loops is evaluated as -
    // (max(end, start + stride) - start - 1) /u stride
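    //
    // E.g. start = 0, stride = 5, end = 12: (max(12, 0 + 5) - 0 - 1) /u 5
    // = 11 /u 5 = 2; the body runs three times and the backedge is taken
    // twice.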
    //
    // The additional preconditions that we need to check to prove correctness
    // of the above formula are as follows -
    //
    // a) IV is either nuw or nsw depending upon signedness (indicated by the
    //    NoWrap flag).
    // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
    //    no side effects within the loop)
    // c) loop has a single static exit (with no abnormal exits)
    //
    // Precondition a) implies that if the stride is negative, this is a single
    // trip loop. The backedge taken count formula reduces to zero in this case.
    //
    // Precondition b) and c) combine to imply that if rhs is invariant in L,
    // then a zero stride means the backedge can't be taken without executing
    // undefined behavior.
    //
    // The positive stride case is the same as isKnownPositive(Stride) returning
    // true (original behavior of the function).
    if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
        !loopHasNoAbnormalExits(L))
      return getCouldNotCompute();
    if (!isKnownNonZero(Stride)) {
      // If we have a step of zero, and RHS isn't invariant in L, we don't know
      // if it might eventually be greater than start and if so, on which
      // iteration.  We can't even produce a useful upper bound.
      if (!isLoopInvariant(RHS, L))
        return getCouldNotCompute();

      // We allow a potentially zero stride, but we need to divide by stride
      // below.  Since the loop can't be infinite and this check must control
      // the sole exit, we can infer the exit must be taken on the first
      // iteration (e.g. backedge count = 0) if the stride is zero.  Given that,
      // we know the numerator in the divides below must be zero, so we can
      // pick an arbitrary non-zero value for the denominator (e.g. stride)
      // and produce the right result.
      // FIXME: Handle the case where Stride is poison?
      auto wouldZeroStrideBeUB = [&]() {
        // Proof by contradiction.  Suppose the stride were zero.  If we can
        // prove that the backedge *is* taken on the first iteration, then
        // since we know this condition controls the sole exit, we must have
        // an infinite loop.  We can't have a (well defined) infinite loop per
        // check just above.
        // Note: The (Start - Stride) term is used to get the start' term from
        // (start' + stride,+,stride). Remember that we only care about the
        // result of this expression when stride == 0 at runtime.
        auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
        return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
      };
      if (!wouldZeroStrideBeUB()) {
        Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
      }
    }
  } else if (!Stride->isOne() && !NoWrap) {
    auto isUBOnWrap = [&]() {
      // From no-self-wrap, we need to then prove no-(un)signed-wrap.  This
      // follows trivially from the fact that every (un)signed-wrapped, but
      // not self-wrapped, value must be strictly less than the last value
      // before the (un)signed wrap; we know that last value didn't exit, nor
      // will any smaller one.
      return canAssumeNoSelfWrap(IV);
    };

    // Avoid proven overflow cases: this will ensure that the backedge taken
    // count will not generate any unsigned overflow. Relaxed no-overflow
    // conditions exploit NoWrapFlags, allowing to optimize in presence of
    // undefined behaviors like the case of C language.
    if (canIVOverflowOnLT(RHS, Stride, IsSigned) && !isUBOnWrap())
      return getCouldNotCompute();
  }
  // On all paths just preceding, we established the following invariant:
  //   IV can be assumed not to overflow up to and including the exiting
  //   iteration.  We proved this in one of two ways:
  //   1) We can show overflow doesn't occur before the exiting iteration
  //      1a) canIVOverflowOnLT, and 1b) step of one
  //   2) We can show that if overflow occurs, the loop must execute UB
  //      before any possible exit.
  // Note that we have not yet proved RHS invariant (in general).
  const SCEV *Start = IV->getStart();

  // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
  // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
  // Use integer-typed versions for actual computation; we can't subtract
  // pointers in general.
  const SCEV *OrigStart = Start;
  const SCEV *OrigRHS = RHS;
  if (Start->getType()->isPointerTy()) {
    Start = getLosslessPtrToIntExpr(Start);
    if (isa<SCEVCouldNotCompute>(Start))
      return Start;
  }
  if (RHS->getType()->isPointerTy()) {
    RHS = getLosslessPtrToIntExpr(RHS);
    if (isa<SCEVCouldNotCompute>(RHS))
      return RHS;
  }
12811 // When the RHS is not invariant, we do not know the end bound of the loop and
12812 // cannot calculate the ExactBECount needed by ExitLimit. However, we can
12813 // calculate the MaxBECount, given the start, stride and max value for the end
12814 // bound of the loop (RHS), and the fact that IV does not overflow (which is
12815 // checked above).
12816 if (!isLoopInvariant(RHS, L)) {
12817 const SCEV *MaxBECount = computeMaxBECountForLT(
12818 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
12819 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
12820 MaxBECount, false /*MaxOrZero*/, Predicates);
12821 }
12823 // We use the expression (max(End,Start)-Start)/Stride to describe the
12824 // backedge count, as if the backedge is taken at least once max(End,Start)
12825 // is End and so the result is as above, and if not max(End,Start) is Start
12826 // so we get a backedge count of zero.
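// Worked example (added commentary, not from the original source): with
// Start = 0, Stride = 4 and RHS = 10, max(End,Start) = 10 and the count is
// ceil((10 - 0) / 4) = 3; with RHS <= Start, max(End,Start) collapses to
// Start and the count is 0.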
12827 const SCEV *BECount = nullptr;
12828 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
12829 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
12830 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
12831 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
12832 // Can we prove max(RHS,Start) > Start - Stride?
12833 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
12834 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
12835 // In this case, we can use a refined formula for computing backedge taken
12836 // count. The general formula remains:
12837 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
12838 // We want to use the alternate formula:
12839 // "((End - 1) - (Start - Stride)) /u Stride"
12840 // Let's do a quick case analysis to show these are equivalent under
12841 // our precondition that max(RHS,Start) > Start - Stride.
12842 // * For RHS <= Start, the backedge-taken count must be zero.
12843 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
12844 // "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
12845 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
12846 // of Stride. For 0 stride, we've use umin(1,Stride) above, reducing
12847 // this to the stride of 1 case.
12848 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil Stride".
12849 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
12850 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
12851 // "((RHS - (Start - Stride) - 1) /u Stride".
12852 // Our preconditions trivially imply no overflow in that form.
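// Concrete check of the alternate formula (added commentary, not from the
// original source): Start = 0, Stride = 4, RHS = End = 10 gives
// ((10 - 1) - (0 - 4)) /u 4 = 13 /u 4 = 3, matching ceil((10 - 0) / 4)
// from the general formula.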
12853 const SCEV *MinusOne = getMinusOne(Stride->getType());
12854 const SCEV *Numerator =
12855 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
12856 BECount = getUDivExpr(Numerator, Stride);
12857 }
12859 const SCEV *BECountIfBackedgeTaken = nullptr;
12860 if (!BECount) {
12861 auto canProveRHSGreaterThanEqualStart = [&]() {
12862 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
12863 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart))
12864 return true;
12866 // (RHS > Start - 1) implies RHS >= Start.
12867 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
12868 // "Start - 1" doesn't overflow.
12869 // * For signed comparison, if Start - 1 does overflow, it's equal
12870 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
12871 // * For unsigned comparison, if Start - 1 does overflow, it's equal
12872 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
12874 // FIXME: Should isLoopEntryGuardedByCond do this for us?
12875 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
12876 auto *StartMinusOne = getAddExpr(OrigStart,
12877 getMinusOne(OrigStart->getType()));
12878 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
12879 };
12881 // If we know that RHS >= Start in the context of loop, then we know that
12882 // max(RHS, Start) = RHS at this point.
12883 const SCEV *End;
12884 if (canProveRHSGreaterThanEqualStart()) {
12885 End = RHS;
12886 } else {
12887 // If RHS < Start, the backedge will be taken zero times. So in
12888 // general, we can write the backedge-taken count as:
12890 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
12892 // We convert it to the following to make it more convenient for SCEV:
12894 // ceil(max(RHS, Start) - Start) / Stride
12895 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
12897 // See what would happen if we assume the backedge is taken. This is
12898 // used to compute MaxBECount.
12899 BECountIfBackedgeTaken = getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
12900 }
12902 // At this point, we know:
12904 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
12905 // 2. The index variable doesn't overflow.
12907 // Therefore, we know N exists such that
12908 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
12909 // doesn't overflow.
12911 // Using this information, try to prove whether the addition in
12912 // "(End - Start) + (Stride - 1)" has unsigned overflow.
12913 const SCEV *One = getOne(Stride->getType());
12914 bool MayAddOverflow = [&] {
12915 if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
12916 if (StrideC->getAPInt().isPowerOf2()) {
12917 // Suppose Stride is a power of two, and Start/End are unsigned
12918 // integers. Let UMAX be the largest representable unsigned
12919 // integer.
12921 // By the preconditions of this function, we know
12922 // "(Start + Stride * N) >= End", and this doesn't overflow.
12925 // End <= (Start + Stride * N) <= UMAX
12927 // Subtracting Start from all the terms:
12929 // End - Start <= Stride * N <= UMAX - Start
12931 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
12933 // End - Start <= Stride * N <= UMAX
12935 // Stride * N is a multiple of Stride. Therefore,
12937 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
12939 // Since Stride is a power of two, UMAX + 1 is divisible by Stride.
12940 // Therefore, UMAX mod Stride == Stride - 1. So we can write:
12942 // End - Start <= Stride * N <= UMAX - (Stride - 1)
12944 // Dropping the middle term:
12946 // End - Start <= UMAX - (Stride - 1)
12948 // Adding Stride - 1 to both sides:
12950 // (End - Start) + (Stride - 1) <= UMAX
12952 // In other words, the addition doesn't have unsigned overflow.
12954 // A similar proof works if we treat Start/End as signed values.
12955 // Just rewrite steps before "End - Start <= Stride * N <= UMAX" to
12956 // use signed max instead of unsigned max. Note that we're trying
12957 // to prove a lack of unsigned overflow in either case.
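// Concrete check (added commentary, not from the original source): for an
// 8-bit IV with Stride = 4, UMAX = 255 and the largest multiple of 4 at
// most UMAX is 252 = UMAX - (Stride - 1); hence End - Start <= 252, and
// adding Stride - 1 = 3 stays within 255, so no unsigned overflow occurs.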
12958 return false;
12959 }
12960 }
12961 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
12962 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End - 1.
12963 // If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1 <u End.
12964 // If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End - 1 <s End.
12966 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 == End.
12967 return false;
12968 }
12969 return true;
12970 }();
12972 const SCEV *Delta = getMinusSCEV(End, Start);
12973 if (!MayAddOverflow) {
12974 // floor((D + (S - 1)) / S)
12975 // We prefer this formulation if it's legal because it's fewer operations.
12976 BECount =
12977 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
12978 } else {
12979 BECount = getUDivCeilSCEV(Delta, Stride);
12980 }
12981 }
12983 const SCEV *ConstantMaxBECount;
12984 bool MaxOrZero = false;
12985 if (isa<SCEVConstant>(BECount)) {
12986 ConstantMaxBECount = BECount;
12987 } else if (BECountIfBackedgeTaken &&
12988 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
12989 // If we know exactly how many times the backedge will be taken if it's
12990 // taken at least once, then the backedge count will either be that or
12991 // zero.
12992 ConstantMaxBECount = BECountIfBackedgeTaken;
12993 MaxOrZero = true;
12994 } else {
12995 ConstantMaxBECount = computeMaxBECountForLT(
12996 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
12997 }
12999 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13000 !isa<SCEVCouldNotCompute>(BECount))
13001 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13003 const SCEV *SymbolicMaxBECount =
13004 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13005 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13006 Predicates);
13007 }
13009 ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13010 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13011 bool ControlsOnlyExit, bool AllowPredicates) {
13012 SmallPtrSet<const SCEVPredicate *, 4> Predicates;
13013 // We handle only IV > Invariant
13014 if (!isLoopInvariant(RHS, L))
13015 return getCouldNotCompute();
13017 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13018 if (!IV && AllowPredicates)
13019 // Try to make this an AddRec using runtime tests, in the first X
13020 // iterations of this loop, where X is the SCEV expression found by the
13021 // algorithm below.
13022 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13024 // Avoid weird loops
13025 if (!IV || IV->getLoop() != L || !IV->isAffine())
13026 return getCouldNotCompute();
13028 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13029 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13030 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13032 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
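// Added commentary: for a counting-down IV such as {N,+,-2}, negating the
// step yields Stride = 2, letting the "greater than" exit be analyzed with
// a positive stride.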
13034 // Avoid negative or zero stride values
13035 if (!isKnownPositive(Stride))
13036 return getCouldNotCompute();
13038 // Avoid proven overflow cases: this will ensure that the backedge taken count
13039 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13040 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13041 // behaviors like the case of C language.
13042 if (!Stride->isOne() && !NoWrap)
13043 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13044 return getCouldNotCompute();
13046 const SCEV *Start = IV->getStart();
13047 const SCEV *End = RHS;
13048 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13049 // If we know that Start >= RHS in the context of loop, then we know that
13050 // min(RHS, Start) = RHS at this point.
13051 if (isLoopEntryGuardedByCond(
13052 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13053 End = RHS;
13054 else
13055 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13056 }
13058 if (Start->getType()->isPointerTy()) {
13059 Start = getLosslessPtrToIntExpr(Start);
13060 if (isa<SCEVCouldNotCompute>(Start))
13061 return Start;
13062 }
13063 if (End->getType()->isPointerTy()) {
13064 End = getLosslessPtrToIntExpr(End);
13065 if (isa<SCEVCouldNotCompute>(End))
13066 return End;
13067 }
13069 // Compute ((Start - End) + (Stride - 1)) / Stride.
13070 // FIXME: This can overflow. Holding off on fixing this for now;
13071 // howManyGreaterThans will hopefully be gone soon.
13072 const SCEV *One = getOne(Stride->getType());
13073 const SCEV *BECount = getUDivExpr(
13074 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13076 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13077 : getUnsignedRangeMax(Start);
13079 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13080 : getUnsignedRangeMin(Stride);
13082 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13083 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13084 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13086 // Although End can be a MIN expression, we estimate MinEnd considering only
13087 // the case End = RHS. This is safe because in the other case (Start - End)
13088 // is zero, leading to a zero maximum backedge taken count.
13089 APInt MinEnd =
13090 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13091 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13093 const SCEV *ConstantMaxBECount =
13094 isa<SCEVConstant>(BECount)
13095 ? BECount
13096 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13097 getConstant(MinStride));
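// Illustrative numbers (added commentary, not from the original source):
// with MaxStart = 100, MinEnd = 10 and MinStride = 3, the constant max
// backedge-taken count evaluates to ceil((100 - 10) / 3) = 30.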
13099 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13100 ConstantMaxBECount = BECount;
13101 const SCEV *SymbolicMaxBECount =
13102 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13104 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13105 Predicates);
13106 }
13108 const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
13109 ScalarEvolution &SE) const {
13110 if (Range.isFullSet()) // Infinite loop.
13111 return SE.getCouldNotCompute();
13113 // If the start is a non-zero constant, shift the range to simplify things.
13114 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13115 if (!SC->getValue()->isZero()) {
13116 SmallVector<const SCEV *, 4> Operands(operands());
13117 Operands[0] = SE.getZero(SC->getType());
13118 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13119 getNoWrapFlags(FlagNW));
13120 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13121 return ShiftedAddRec->getNumIterationsInRange(
13122 Range.subtract(SC->getAPInt()), SE);
13123 // This is strange and shouldn't happen.
13124 return SE.getCouldNotCompute();
13125 }
13127 // The only time we can solve this is when we have all constant indices.
13128 // Otherwise, we cannot determine the overflow conditions.
13129 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13130 return SE.getCouldNotCompute();
13132 // Okay at this point we know that all elements of the chrec are constants and
13133 // that the start element is zero.
13135 // First check to see if the range contains zero. If not, the first
13136 // iteration exits.
13137 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13138 if (!Range.contains(APInt(BitWidth, 0)))
13139 return SE.getZero(getType());
13141 if (isAffine()) {
13142 // If this is an affine expression then we have this situation:
13143 // Solve {0,+,A} in Range === Ax in Range
13145 // We know that zero is in the range. If A is positive then we know that
13146 // the upper value of the range must be the first possible exit value.
13147 // If A is negative then the lower of the range is the last possible loop
13148 // value. Also note that we already checked for a full range.
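// Worked example (added commentary, not from the original source): solving
// {0,+,3} against the range [0,10) gives A = 3 and End = 9, so
// ExitVal = (9 + 3) /u 3 = 4; evaluating the chrec at 4 yields 12, outside
// the range, while 3 yields 9, still inside, so 4 is returned below.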
13149 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13150 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13152 // The exit value should be (End+A)/A.
13153 APInt ExitVal = (End + A).udiv(A);
13154 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13156 // Evaluate at the exit value. If we really did fall out of the valid
13157 // range, then we computed our trip count, otherwise wrap around or other
13158 // things must have happened.
13159 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13160 if (Range.contains(Val->getValue()))
13161 return SE.getCouldNotCompute(); // Something strange happened
13163 // Ensure that the previous value is in the range.
13164 assert(Range.contains(
13165 EvaluateConstantChrecAtConstant(this,
13166 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13167 "Linear scev computation is off in a bad way!");
13168 return SE.getConstant(ExitValue);
13169 }
13171 if (isQuadratic()) {
13172 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13173 return SE.getConstant(*S);
13174 }
13176 return SE.getCouldNotCompute();
13177 }
13179 const SCEVAddRecExpr *
13180 SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
13181 assert(getNumOperands() > 1 && "AddRec with zero step?");
13182 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13183 // but in this case we cannot guarantee that the value returned will be an
13184 // AddRec because SCEV does not have a fixed point where it stops
13185 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13186 // may happen if we reach arithmetic depth limit while simplifying. So we
13187 // construct the returned value explicitly.
13188 SmallVector<const SCEV *, 3> Ops;
13189 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13190 // (this + Step) is {A+B,+,B+C,+...,+,N}.
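// E.g. (added commentary): the post-increment form of {0,+,2} is {2,+,2};
// the start advances by one step while the step itself is unchanged.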
13191 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13192 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13193 // We know that the last operand is not a constant zero (otherwise it would
13194 // have been popped out earlier). This guarantees us that if the result has
13195 // the same last operand, then it will also not be popped out, meaning that
13196 // the returned value will be an AddRec.
13197 const SCEV *Last = getOperand(getNumOperands() - 1);
13198 assert(!Last->isZero() && "Recurrence with zero step?");
13199 Ops.push_back(Last);
13200 return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
13201 SCEV::FlagAnyWrap));
13202 }
13204 // Return true when S contains at least an undef value.
13205 bool ScalarEvolution::containsUndefs(const SCEV *S) const {
13206 return SCEVExprContains(S, [](const SCEV *S) {
13207 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13208 return isa<UndefValue>(SU->getValue());
13209 return false;
13210 });
13211 }
13213 // Return true when S contains a value that is a nullptr.
13214 bool ScalarEvolution::containsErasedValue(const SCEV *S) const {
13215 return SCEVExprContains(S, [](const SCEV *S) {
13216 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13217 return SU->getValue() == nullptr;
13218 return false;
13219 });
13220 }
13222 /// Return the size of an element read or written by Inst.
13223 const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
13224 Type *Ty;
13225 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13226 Ty = Store->getValueOperand()->getType();
13227 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13228 Ty = Load->getType();
13229 else
13230 return nullptr;
13232 Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty));
13233 return getSizeOfExpr(ETy, Ty);
13234 }
13236 //===----------------------------------------------------------------------===//
13237 // SCEVCallbackVH Class Implementation
13238 //===----------------------------------------------------------------------===//
13240 void ScalarEvolution::SCEVCallbackVH::deleted() {
13241 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13242 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13243 SE->ConstantEvolutionLoopExitValue.erase(PN);
13244 SE->eraseValueFromMap(getValPtr());
13245 // this now dangles!
13246 }
13248 void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13249 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13251 // Forget all the expressions associated with users of the old value,
13252 // so that future queries will recompute the expressions using the new
13253 // value.
13254 Value *Old = getValPtr();
13255 SmallVector<User *, 16> Worklist(Old->users());
13256 SmallPtrSet<User *, 8> Visited;
13257 while (!Worklist.empty()) {
13258 User *U = Worklist.pop_back_val();
13259 // Deleting the Old value will cause this to dangle. Postpone
13260 // that until everything else is done.
13261 if (U == Old)
13262 continue;
13263 if (!Visited.insert(U).second)
13264 continue;
13265 if (PHINode *PN = dyn_cast<PHINode>(U))
13266 SE->ConstantEvolutionLoopExitValue.erase(PN);
13267 SE->eraseValueFromMap(U);
13268 llvm::append_range(Worklist, U->users());
13269 }
13270 // Delete the Old value.
13271 if (PHINode *PN = dyn_cast<PHINode>(Old))
13272 SE->ConstantEvolutionLoopExitValue.erase(PN);
13273 SE->eraseValueFromMap(Old);
13274 // this now dangles!
13275 }
13277 ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13278 : CallbackVH(V), SE(se) {}
13280 //===----------------------------------------------------------------------===//
13281 // ScalarEvolution Class Implementation
13282 //===----------------------------------------------------------------------===//
13284 ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
13285 AssumptionCache &AC, DominatorTree &DT,
13286 LoopInfo &LI)
13287 : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
13288 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13289 LoopDispositions(64), BlockDispositions(64) {
13290 // To use guards for proving predicates, we need to scan every instruction in
13291 // relevant basic blocks, and not just terminators. Doing this is a waste of
13292 // time if the IR does not actually contain any calls to
13293 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13295 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13296 // to _add_ guards to the module when there weren't any before, and wants
13297 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13298 // efficient in lieu of being smart in that rather obscure case.
13300 auto *GuardDecl = F.getParent()->getFunction(
13301 Intrinsic::getName(Intrinsic::experimental_guard));
13302 HasGuards = GuardDecl && !GuardDecl->use_empty();
13303 }
13305 ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
13306 : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
13307 LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13308 ValueExprMap(std::move(Arg.ValueExprMap)),
13309 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13310 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13311 PendingMerges(std::move(Arg.PendingMerges)),
13312 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13313 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13314 PredicatedBackedgeTakenCounts(
13315 std::move(Arg.PredicatedBackedgeTakenCounts)),
13316 BECountUsers(std::move(Arg.BECountUsers)),
13317 ConstantEvolutionLoopExitValue(
13318 std::move(Arg.ConstantEvolutionLoopExitValue)),
13319 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13320 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13321 LoopDispositions(std::move(Arg.LoopDispositions)),
13322 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13323 BlockDispositions(std::move(Arg.BlockDispositions)),
13324 SCEVUsers(std::move(Arg.SCEVUsers)),
13325 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13326 SignedRanges(std::move(Arg.SignedRanges)),
13327 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13328 UniquePreds(std::move(Arg.UniquePreds)),
13329 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13330 LoopUsers(std::move(Arg.LoopUsers)),
13331 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13332 FirstUnknown(Arg.FirstUnknown) {
13333 Arg.FirstUnknown = nullptr;
13334 }
13336 ScalarEvolution::~ScalarEvolution() {
13337 // Iterate through all the SCEVUnknown instances and call their
13338 // destructors, so that they release their references to their values.
13339 for (SCEVUnknown *U = FirstUnknown; U;) {
13340 SCEVUnknown *Tmp = U;
13341 U = U->Next;
13342 Tmp->~SCEVUnknown();
13343 }
13344 FirstUnknown = nullptr;
13346 ExprValueMap.clear();
13347 ValueExprMap.clear();
13349 BackedgeTakenCounts.clear();
13350 PredicatedBackedgeTakenCounts.clear();
13352 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13353 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13354 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13355 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13356 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13357 }
13359 bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
13360 return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
13361 }
13363 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13364 const Loop *L) {
13365 // Print all inner loops first
13366 for (Loop *I : *L)
13367 PrintLoopInfo(OS, SE, I);
13369 OS << "Loop ";
13370 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13371 OS << ": ";
13373 SmallVector<BasicBlock *, 8> ExitingBlocks;
13374 L->getExitingBlocks(ExitingBlocks);
13375 if (ExitingBlocks.size() != 1)
13376 OS << "<multiple exits> ";
13378 if (SE->hasLoopInvariantBackedgeTakenCount(L))
13379 OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L) << "\n";
13380 else
13381 OS << "Unpredictable backedge-taken count.\n";
13383 if (ExitingBlocks.size() > 1)
13384 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13385 OS << " exit count for " << ExitingBlock->getName() << ": "
13386 << *SE->getExitCount(L, ExitingBlock) << "\n";
13387 }
13389 OS << "Loop ";
13390 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13391 OS << ": ";
13393 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13394 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13395 OS << "constant max backedge-taken count is " << *ConstantBTC;
13396 if (SE->isBackedgeTakenCountMaxOrZero(L))
13397 OS << ", actual taken count either this or zero.";
13399 OS << "Unpredictable constant max backedge-taken count. ";
13404 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13407 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13408 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13409 OS << "symbolic max backedge-taken count is " << *SymbolicBTC;
13410 if (SE->isBackedgeTakenCountMaxOrZero(L))
13411 OS << ", actual taken count either this or zero.";
13413 OS << "Unpredictable symbolic max backedge-taken count. ";
13417 if (ExitingBlocks.size() > 1)
13418 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13419 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": "
13420 << *SE->getExitCount(L, ExitingBlock, ScalarEvolution::SymbolicMaximum)
13421 << "\n";
13422 }
13424 OS << "Loop ";
13425 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13426 OS << ": ";
13428 SmallVector<const SCEVPredicate *, 4> Preds;
13429 auto PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13430 if (!isa<SCEVCouldNotCompute>(PBT)) {
13431 OS << "Predicated backedge-taken count is " << *PBT << "\n";
13432 OS << " Predicates:\n";
13433 for (const auto *P : Preds)
13434 P->print(OS, 4);
13435 } else {
13436 OS << "Unpredictable predicated backedge-taken count. ";
13437 }
13438 OS << "\n";
13440 if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
13441 OS << "Loop ";
13442 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13443 OS << ": ";
13444 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13445 }
13446 }
13448 namespace llvm {
13449 raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::LoopDisposition LD) {
13450 switch (LD) {
13451 case ScalarEvolution::LoopVariant:
13452 OS << "Variant";
13453 break;
13454 case ScalarEvolution::LoopInvariant:
13455 OS << "Invariant";
13456 break;
13457 case ScalarEvolution::LoopComputable:
13458 OS << "Computable";
13459 break;
13460 }
13461 return OS;
13462 }
13464 raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::BlockDisposition BD) {
13465 switch (BD) {
13466 case ScalarEvolution::DoesNotDominateBlock:
13467 OS << "DoesNotDominate";
13468 break;
13469 case ScalarEvolution::DominatesBlock:
13470 OS << "Dominates";
13471 break;
13472 case ScalarEvolution::ProperlyDominatesBlock:
13473 OS << "ProperlyDominates";
13474 break;
13475 }
13476 return OS;
13477 }
13478 } // namespace llvm
13480 void ScalarEvolution::print(raw_ostream &OS) const {
13481 // ScalarEvolution's implementation of the print method is to print
13482 // out SCEV values of all instructions that are interesting. Doing
13483 // this potentially causes it to create new SCEV objects though,
13484 // which technically conflicts with the const qualifier. This isn't
13485 // observable from outside the class though, so casting away the
13486 // const isn't dangerous.
13487 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13489 if (ClassifyExpressions) {
13490 OS << "Classifying expressions for: ";
13491 F.printAsOperand(OS, /*PrintType=*/false);
13492 OS << "\n";
13493 for (Instruction &I : instructions(F))
13494 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
13495 OS << I << '\n';
13496 OS << "  -->  ";
13497 const SCEV *SV = SE.getSCEV(&I);
13498 SV->print(OS);
13499 if (!isa<SCEVCouldNotCompute>(SV)) {
13500 OS << " U: ";
13501 SE.getUnsignedRange(SV).print(OS);
13502 OS << " S: ";
13503 SE.getSignedRange(SV).print(OS);
13504 }
13506 const Loop *L = LI.getLoopFor(I.getParent());
13508 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
13509 if (AtUse != SV) {
13510 OS << "  -->  ";
13511 AtUse->print(OS);
13512 if (!isa<SCEVCouldNotCompute>(AtUse)) {
13513 OS << " U: ";
13514 SE.getUnsignedRange(AtUse).print(OS);
13515 OS << " S: ";
13516 SE.getSignedRange(AtUse).print(OS);
13517 }
13518 }
13520 if (L) {
13521 OS << "\t\t" "Exits: ";
13522 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
13523 if (!SE.isLoopInvariant(ExitValue, L)) {
13524 OS << "<<Unknown>>";
13525 } else {
13526 OS << *ExitValue;
13527 }
13529 bool First = true;
13530 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
13531 if (First) {
13532 OS << "\t\t" "LoopDispositions: { ";
13533 First = false;
13534 } else {
13535 OS << ", ";
13536 }
13538 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13539 OS << ": " << SE.getLoopDisposition(SV, Iter);
13540 }
13542 for (const auto *InnerL : depth_first(L)) {
13543 if (InnerL == L)
13544 continue;
13545 if (First) {
13546 OS << "\t\t" "LoopDispositions: { ";
13547 First = false;
13548 } else {
13549 OS << ", ";
13550 }
13552 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13553 OS << ": " << SE.getLoopDisposition(SV, InnerL);
13554 }
13556 OS << " }";
13557 }
13559 OS << "\n";
13560 }
13561 }
13563 OS << "Determining loop execution counts for: ";
13564 F.printAsOperand(OS, /*PrintType=*/false);
13565 OS << "\n";
13566 for (Loop *I : LI)
13567 PrintLoopInfo(OS, &SE, I);
13568 }
13570 ScalarEvolution::LoopDisposition
13571 ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
13572 auto &Values = LoopDispositions[S];
13573 for (auto &V : Values) {
13574 if (V.getPointer() == L)
13575 return V.getInt();
13576 }
13577 Values.emplace_back(L, LoopVariant);
13578 LoopDisposition D = computeLoopDisposition(S, L);
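// Added commentary: computeLoopDisposition can itself query and grow
// LoopDispositions, invalidating the reference taken above, so the entry
// is looked up again before the cached placeholder is updated.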
13579 auto &Values2 = LoopDispositions[S];
13580 for (auto &V : llvm::reverse(Values2)) {
13581 if (V.getPointer() == L) {
13582 V.setInt(D);
13583 break;
13584 }
13585 }
13586 return D;
13587 }
13589 ScalarEvolution::LoopDisposition
13590 ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
13591 switch (S->getSCEVType()) {
13592 case scConstant:
13593 case scVScale:
13594 return LoopInvariant;
13595 case scAddRecExpr: {
13596 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13598 // If L is the addrec's loop, it's computable.
13599 if (AR->getLoop() == L)
13600 return LoopComputable;
13602 // Add recurrences are never invariant in the function-body (null loop).
13603 if (!L)
13604 return LoopVariant;
13606 // Everything that is not defined at loop entry is variant.
13607 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
13608 return LoopVariant;
13609 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
13610 " dominate the contained loop's header?");
13612 // This recurrence is invariant w.r.t. L if AR's loop contains L.
13613 if (AR->getLoop()->contains(L))
13614 return LoopInvariant;
13616 // This recurrence is variant w.r.t. L if any of its operands
13617 // are variant.
13618 for (const auto *Op : AR->operands())
13619 if (!isLoopInvariant(Op, L))
13620 return LoopVariant;
13622 // Otherwise it's loop-invariant.
13623 return LoopInvariant;
13624 }
13625 case scTruncate:
13626 case scZeroExtend:
13627 case scSignExtend:
13628 case scPtrToInt:
13629 case scAddExpr:
13630 case scMulExpr:
13631 case scUDivExpr:
13632 case scUMaxExpr:
13633 case scSMaxExpr:
13634 case scUMinExpr:
13635 case scSMinExpr:
13636 case scSequentialUMinExpr: {
13637 bool HasVarying = false;
13638 for (const auto *Op : S->operands()) {
13639 LoopDisposition D = getLoopDisposition(Op, L);
13640 if (D == LoopVariant)
13641 return LoopVariant;
13642 if (D == LoopComputable)
13643 HasVarying = true;
13644 }
13645 return HasVarying ? LoopComputable : LoopInvariant;
13646 }
13647 case scUnknown:
13648 // All non-instruction values are loop invariant. All instructions are loop
13649 // invariant if they are not contained in the specified loop.
13650 // Instructions are never considered invariant in the function body
13651 // (null loop) because they are defined within the "loop".
13652 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
13653 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
13654 return LoopInvariant;
13655 case scCouldNotCompute:
13656 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
13657 }
13658 llvm_unreachable("Unknown SCEV kind!");
13659 }
13661 bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
13662 return getLoopDisposition(S, L) == LoopInvariant;
13663 }
13665 bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
13666 return getLoopDisposition(S, L) == LoopComputable;
13667 }
13669 ScalarEvolution::BlockDisposition
13670 ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13671 auto &Values = BlockDispositions[S];
13672 for (auto &V : Values) {
13673 if (V.getPointer() == BB)
13674 return V.getInt();
13675 }
13676 Values.emplace_back(BB, DoesNotDominateBlock);
13677 BlockDisposition D = computeBlockDisposition(S, BB);
13678 auto &Values2 = BlockDispositions[S];
13679 for (auto &V : llvm::reverse(Values2)) {
13680 if (V.getPointer() == BB) {
13681 V.setInt(D);
13682 break;
13683 }
13684 }
13685 return D;
13686 }
13688 ScalarEvolution::BlockDisposition
13689 ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13690 switch (S->getSCEVType()) {
13691 case scConstant:
13692 case scVScale:
13693 return ProperlyDominatesBlock;
13694 case scAddRecExpr: {
13695 // This uses a "dominates" query instead of "properly dominates" query
13696 // to test for proper dominance too, because the instruction which
13697 // produces the addrec's value is a PHI, and a PHI effectively properly
13698 // dominates its entire containing block.
13699 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13700 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
13701 return DoesNotDominateBlock;
13703 // Fall through into SCEVNAryExpr handling.
13704 [[fallthrough]];
13705 }
13706 case scTruncate:
13707 case scZeroExtend:
13708 case scSignExtend:
13709 case scPtrToInt:
13710 case scAddExpr:
13711 case scMulExpr:
13712 case scUDivExpr:
13713 case scUMaxExpr:
13714 case scSMaxExpr:
13715 case scUMinExpr:
13716 case scSMinExpr:
13717 case scSequentialUMinExpr: {
13718 bool Proper = true;
13719 for (const SCEV *NAryOp : S->operands()) {
13720 BlockDisposition D = getBlockDisposition(NAryOp, BB);
13721 if (D == DoesNotDominateBlock)
13722 return DoesNotDominateBlock;
13723 if (D == DominatesBlock)
13724 Proper = false;
13725 }
13726 return Proper ? ProperlyDominatesBlock : DominatesBlock;
13727 }
13728 case scUnknown:
13729 if (Instruction *I =
13730 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
13731 if (I->getParent() == BB)
13732 return DominatesBlock;
13733 if (DT.properlyDominates(I->getParent(), BB))
13734 return ProperlyDominatesBlock;
13735 return DoesNotDominateBlock;
13736 }
13737 return ProperlyDominatesBlock;
13738 case scCouldNotCompute:
13739 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
13740 }
13741 llvm_unreachable("Unknown SCEV kind!");
13742 }
13744 bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
13745 return getBlockDisposition(S, BB) >= DominatesBlock;
13746 }
13748 bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
13749 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
13750 }
13752 bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
13753 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
13754 }
13756 void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
13757 bool Predicated) {
13758 auto &BECounts =
13759 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
13760 auto It = BECounts.find(L);
13761 if (It != BECounts.end()) {
13762 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
13763 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
13764 if (!isa<SCEVConstant>(S)) {
13765 auto UserIt = BECountUsers.find(S);
13766 assert(UserIt != BECountUsers.end());
13767 UserIt->second.erase({L, Predicated});
13768 }
13769 }
13770 }
13771 BECounts.erase(It);
13772 }
13773 }
13775 void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
13776 SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
13777 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
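// Added commentary: the worklist walk below takes the transitive closure
// over SCEVUsers, so forgetting an expression also forgets every
// expression built on top of it.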
13779 while (!Worklist.empty()) {
13780 const SCEV *Curr = Worklist.pop_back_val();
13781 auto Users = SCEVUsers.find(Curr);
13782 if (Users != SCEVUsers.end())
13783 for (const auto *User : Users->second)
13784 if (ToForget.insert(User).second)
13785 Worklist.push_back(User);
13786 }
13788 for (const auto *S : ToForget)
13789 forgetMemoizedResultsImpl(S);
13791 for (auto I = PredicatedSCEVRewrites.begin();
13792 I != PredicatedSCEVRewrites.end();) {
13793 std::pair<const SCEV *, const Loop *> Entry = I->first;
13794 if (ToForget.count(Entry.first))
13795 PredicatedSCEVRewrites.erase(I++);
13796 else
13797 ++I;
13798 }
13799 }
13801 void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
13802 LoopDispositions.erase(S);
13803 BlockDispositions.erase(S);
13804 UnsignedRanges.erase(S);
13805 SignedRanges.erase(S);
13806 HasRecMap.erase(S);
13807 ConstantMultipleCache.erase(S);
13809 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
13810 UnsignedWrapViaInductionTried.erase(AR);
13811 SignedWrapViaInductionTried.erase(AR);
13812 }
13814 auto ExprIt = ExprValueMap.find(S);
13815 if (ExprIt != ExprValueMap.end()) {
13816 for (Value *V : ExprIt->second) {
13817 auto ValueIt = ValueExprMap.find_as(V);
13818 if (ValueIt != ValueExprMap.end())
13819 ValueExprMap.erase(ValueIt);
13820 }
13821 ExprValueMap.erase(ExprIt);
13822 }
13824 auto ScopeIt = ValuesAtScopes.find(S);
13825 if (ScopeIt != ValuesAtScopes.end()) {
13826 for (const auto &Pair : ScopeIt->second)
13827 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
13828 erase_value(ValuesAtScopesUsers[Pair.second],
13829 std::make_pair(Pair.first, S));
13830 ValuesAtScopes.erase(ScopeIt);
13831 }
13833 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
13834 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
13835 for (const auto &Pair : ScopeUserIt->second)
13836 erase_value(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
13837 ValuesAtScopesUsers.erase(ScopeUserIt);
13838 }
13840 auto BEUsersIt = BECountUsers.find(S);
13841 if (BEUsersIt != BECountUsers.end()) {
13842 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
13843 auto Copy = BEUsersIt->second;
13844 for (const auto &Pair : Copy)
13845 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
13846 BECountUsers.erase(BEUsersIt);
13847 }
13849 auto FoldUser = FoldCacheUser.find(S);
13850 if (FoldUser != FoldCacheUser.end())
13851 for (auto &KV : FoldUser->second)
13852 FoldCache.erase(KV);
13853 FoldCacheUser.erase(S);
13854 }
13856 void
13857 ScalarEvolution::getUsedLoops(const SCEV *S,
13858 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
13859 struct FindUsedLoops {
13860 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
13861 : LoopsUsed(LoopsUsed) {}
13862 SmallPtrSetImpl<const Loop *> &LoopsUsed;
13863 bool follow(const SCEV *S) {
13864 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
13865 LoopsUsed.insert(AR->getLoop());
13866 return true;
13867 }
13869 bool isDone() const { return false; }
13870 };
13872 FindUsedLoops F(LoopsUsed);
13873 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
13874 }
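// Added commentary: reachability below is computed over a pruned CFG in
// which conditional branches whose condition SCEV can prove constant are
// followed only along the feasible edge, so provably dead blocks are not
// visited.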
13876 void ScalarEvolution::getReachableBlocks(
13877 SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
13878 SmallVector<BasicBlock *> Worklist;
13879 Worklist.push_back(&F.getEntryBlock());
13880 while (!Worklist.empty()) {
13881 BasicBlock *BB = Worklist.pop_back_val();
13882 if (!Reachable.insert(BB).second)
13883 continue;
13885 Value *Cond;
13886 BasicBlock *TrueBB, *FalseBB;
13887 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
13888 m_BasicBlock(FalseBB)))) {
13889 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
13890 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
13891 continue;
13892 }
13894 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
13895 const SCEV *L = getSCEV(Cmp->getOperand(0));
13896 const SCEV *R = getSCEV(Cmp->getOperand(1));
13897 if (isKnownPredicateViaConstantRanges(Cmp->getPredicate(), L, R)) {
13898 Worklist.push_back(TrueBB);
13899 continue;
13900 }
13901 if (isKnownPredicateViaConstantRanges(Cmp->getInversePredicate(), L,
13902 R)) {
13903 Worklist.push_back(FalseBB);
13904 continue;
13905 }
13906 }
13907 }
13909 append_range(Worklist, successors(BB));
13910 }
13911 }
13913 void ScalarEvolution::verify() const {
13914 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13915 ScalarEvolution SE2(F, TLI, AC, DT, LI);
13917 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
13919 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
13920 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
13921 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
13923 const SCEV *visitConstant(const SCEVConstant *Constant) {
13924 return SE.getConstant(Constant->getAPInt());
13925 }
13927 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
13928 return SE.getUnknown(Expr->getValue());
13929 }
13931 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
13932 return SE.getCouldNotCompute();
13933 }
13934 };
13936 SCEVMapper SCM(SE2);
13937 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
13938 SE2.getReachableBlocks(ReachableBlocks, F);
13940 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
13941 if (containsUndefs(Old) || containsUndefs(New)) {
13942 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
13943 // not propagate undef aggressively). This means we can (and do) fail
13944 // verification in cases where a transform makes a value go from "undef"
13945 // to "undef+1" (say). The transform is fine, since in both cases the
13946 // result is "undef", but SCEV thinks the value increased by 1.
13947 return nullptr;
13948 }
13950 // Unless VerifySCEVStrict is set, we only compare constant deltas.
13951 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
13952 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
13953 return nullptr;
13954 return Delta;
13955 };
13958 while (!LoopStack.empty()) {
13959 auto *L = LoopStack.pop_back_val();
13960 llvm::append_range(LoopStack, *L);
13962 // Only verify BECounts in reachable loops. For an unreachable loop,
13963 // any BECount is legal.
13964 if (!ReachableBlocks.contains(L->getHeader()))
13965 continue;
13967 // Only verify cached BECounts. Computing new BECounts may change the
13968 // results of subsequent SCEV uses.
13969 auto It = BackedgeTakenCounts.find(L);
13970 if (It == BackedgeTakenCounts.end())
13971 continue;
13973 auto *CurBECount =
13974 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
13975 auto *NewBECount = SE2.getBackedgeTakenCount(L);
13977 if (CurBECount == SE2.getCouldNotCompute() ||
13978 NewBECount == SE2.getCouldNotCompute()) {
13979 // NB! This situation is legal, but is very suspicious -- whatever pass
13980 // changed the loop to make a trip count go from could not compute to
13981 // computable or vice-versa *should have* invalidated SCEV. However, we
13982 // choose not to assert here (for now) since we don't want false
13983 // positives.
13984 continue;
13985 }
13987 if (SE.getTypeSizeInBits(CurBECount->getType()) >
13988 SE.getTypeSizeInBits(NewBECount->getType()))
13989 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
13990 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
13991 SE.getTypeSizeInBits(NewBECount->getType()))
13992 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
13994 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
13995 if (Delta && !Delta->isZero()) {
13996 dbgs() << "Trip Count for " << *L << " Changed!\n";
13997 dbgs() << "Old: " << *CurBECount << "\n";
13998 dbgs() << "New: " << *NewBECount << "\n";
13999 dbgs() << "Delta: " << *Delta << "\n";
14004 // Collect all valid loops currently in LoopInfo.
14005 SmallPtrSet<Loop *, 32> ValidLoops;
14006 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14007 while (!Worklist.empty()) {
14008 Loop *L = Worklist.pop_back_val();
14009 if (ValidLoops.insert(L).second)
14010 Worklist.append(L->begin(), L->end());
14011 }
14012 for (const auto &KV : ValueExprMap) {
14013 #ifndef NDEBUG
14014 // Check for SCEV expressions referencing invalid/deleted loops.
14015 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14016 assert(ValidLoops.contains(AR->getLoop()) &&
14017 "AddRec references invalid loop");
14021 // Check that the value is also part of the reverse map.
14022 auto It = ExprValueMap.find(KV.second);
14023 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14024 dbgs() << "Value " << *KV.first
14025 << " is in ValueExprMap but not in ExprValueMap\n";
14029 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14030 if (!ReachableBlocks.contains(I->getParent()))
14032 const SCEV *OldSCEV = SCM.visit(KV.second);
14033 const SCEV *NewSCEV = SE2.getSCEV(I);
14034 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14035 if (Delta && !Delta->isZero()) {
14036 dbgs() << "SCEV for value " << *I << " changed!\n"
14037 << "Old: " << *OldSCEV << "\n"
14038 << "New: " << *NewSCEV << "\n"
14039 << "Delta: " << *Delta << "\n";
14045 for (const auto &KV : ExprValueMap) {
14046 for (Value *V : KV.second) {
14047 auto It = ValueExprMap.find_as(V);
14048 if (It == ValueExprMap.end()) {
14049 dbgs() << "Value " << *V
14050 << " is in ExprValueMap but not in ValueExprMap\n";
14053 if (It->second != KV.first) {
14054 dbgs() << "Value " << *V << " mapped to " << *It->second
14055 << " rather than " << *KV.first << "\n";
14061 // Verify integrity of SCEV users.
14062 for (const auto &S : UniqueSCEVs) {
14063 for (const auto *Op : S.operands()) {
14064 // We do not store dependencies of constants.
14065 if (isa<SCEVConstant>(Op))
14067 auto It = SCEVUsers.find(Op);
14068 if (It != SCEVUsers.end() && It->second.count(&S))
14069 continue;
14070 dbgs() << "Use of operand " << *Op << " by user " << S
14071 << " is not being tracked!\n";
14076 // Verify integrity of ValuesAtScopes users.
14077 for (const auto &ValueAndVec : ValuesAtScopes) {
14078 const SCEV *Value = ValueAndVec.first;
14079 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14080 const Loop *L = LoopAndValueAtScope.first;
14081 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14082 if (!isa<SCEVConstant>(ValueAtScope)) {
14083 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14084 if (It != ValuesAtScopesUsers.end() &&
14085 is_contained(It->second, std::make_pair(L, Value)))
14086 continue;
14087 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14088 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14089 std::abort();
14090 }
14091 }
14092 }
14094 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14095 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14096 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14097 const Loop *L = LoopAndValue.first;
14098 const SCEV *Value = LoopAndValue.second;
14099 assert(!isa<SCEVConstant>(Value));
14100 auto It = ValuesAtScopes.find(Value);
14101 if (It != ValuesAtScopes.end() &&
14102 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14103 continue;
14104 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14105 << *ValueAtScope << " missing in ValuesAtScopes\n";
14106 std::abort();
14107 }
14108 }
14110 // Verify integrity of BECountUsers.
14111 auto VerifyBECountUsers = [&](bool Predicated) {
14112 auto &BECounts =
14113 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14114 for (const auto &LoopAndBEInfo : BECounts) {
14115 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14116 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14117 if (!isa<SCEVConstant>(S)) {
14118 auto UserIt = BECountUsers.find(S);
14119 if (UserIt != BECountUsers.end() &&
14120 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14121 continue;
14122 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14123 << " missing from BECountUsers\n";
14130 VerifyBECountUsers(/* Predicated */ false);
14131 VerifyBECountUsers(/* Predicated */ true);
14133 // Verify intergity of loop disposition cache.
14134 for (auto &[S, Values] : LoopDispositions) {
14135 for (auto [Loop, CachedDisposition] : Values) {
14136 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14137 if (CachedDisposition != RecomputedDisposition) {
14138 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14139 << " is incorrect: cached " << CachedDisposition << ", actual "
14140 << RecomputedDisposition << "\n";
14141 std::abort();
14142 }
14143 }
14144 }
14146 // Verify integrity of the block disposition cache.
14147 for (auto &[S, Values] : BlockDispositions) {
14148 for (auto [BB, CachedDisposition] : Values) {
14149 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14150 if (CachedDisposition != RecomputedDisposition) {
14151 dbgs() << "Cached disposition of " << *S << " for block %"
14152 << BB->getName() << " is incorrect: cached " << CachedDisposition
14153 << ", actual " << RecomputedDisposition << "\n";
14159 // Verify FoldCache/FoldCacheUser caches.
14160 for (auto [FoldID, Expr] : FoldCache) {
14161 auto I = FoldCacheUser.find(Expr);
14162 if (I == FoldCacheUser.end()) {
14163 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14167 if (!is_contained(I->second, FoldID)) {
14168 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14172 for (auto [Expr, IDs] : FoldCacheUser) {
14173 for (auto &FoldID : IDs) {
14174 auto I = FoldCache.find(FoldID);
14175 if (I == FoldCache.end()) {
14176 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14180 if (I->second != Expr) {
14181 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: "
14182 << *I->second << " != " << *Expr << "!\n";
14188 // Verify that ConstantMultipleCache computations are correct. We check that
14189 // cached multiples and recomputed multiples are multiples of each other to
14190 // verify correctness. It is possible that a recomputed multiple is different
14191 // from the cached multiple due to strengthened no wrap flags or changes in
14192 // KnownBits computations.
14193 for (auto [S, Multiple] : ConstantMultipleCache) {
14194 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14195 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14196 Multiple.urem(RecomputedMultiple) != 0 &&
14197 RecomputedMultiple.urem(Multiple) != 0)) {
14198 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14199 << *S << " : Computed " << RecomputedMultiple
14200 << " but cache contains " << Multiple << "!\n";
14206 bool ScalarEvolution::invalidate(
14207 Function &F, const PreservedAnalyses &PA,
14208 FunctionAnalysisManager::Invalidator &Inv) {
14209 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14210 // of its dependencies is invalidated.
14211 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14212 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14213 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14214 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14215 Inv.invalidate<LoopAnalysis>(F, PA);
14216 }
14218 AnalysisKey ScalarEvolutionAnalysis::Key;
14220 ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14221 FunctionAnalysisManager &AM) {
14222 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14223 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14224 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14225 auto &LI = AM.getResult<LoopAnalysis>(F);
14226 return ScalarEvolution(F, TLI, AC, DT, LI);
14227 }
14229 PreservedAnalyses
14230 ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
14231 AM.getResult<ScalarEvolutionAnalysis>(F).verify();
14232 return PreservedAnalyses::all();
14233 }
14235 PreservedAnalyses
14236 ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
14237 // For compatibility with opt's -analyze feature under legacy pass manager
14238 // which was not ported to NPM. This keeps tests using
14239 // update_analyze_test_checks.py working.
14240 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14241 << F.getName() << "':\n";
14242 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
14243 return PreservedAnalyses::all();
14244 }
14246 INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution",
14247 "Scalar Evolution Analysis", false, true)
14248 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
14249 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
14250 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
14251 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
14252 INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
14253 "Scalar Evolution Analysis", false, true)
14255 char ScalarEvolutionWrapperPass::ID = 0;
14257 ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {
14258 initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry());
14259 }
14261 bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
14262 SE.reset(new ScalarEvolution(
14263 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14264 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14265 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14266 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14267 return false;
14268 }
14270 void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
14272 void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
14273 SE->print(OS);
14274 }
14276 void ScalarEvolutionWrapperPass::verifyAnalysis() const {
14277 if (!VerifySCEV)
14278 return;
14280 SE->verify();
14281 }
14283 void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
14284 AU.setPreservesAll();
14285 AU.addRequiredTransitive<AssumptionCacheTracker>();
14286 AU.addRequiredTransitive<LoopInfoWrapperPass>();
14287 AU.addRequiredTransitive<DominatorTreeWrapperPass>();
14288 AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
14289 }
14291 const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14292 const SCEV *RHS) {
14293 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14294 }
14296 const SCEVPredicate *
14297 ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14298 const SCEV *LHS, const SCEV *RHS) {
14299 FoldingSetNodeID ID;
14300 assert(LHS->getType() == RHS->getType() &&
14301 "Type mismatch between LHS and RHS");
14302 // Unique this node based on the arguments
14303 ID.AddInteger(SCEVPredicate::P_Compare);
14304 ID.AddInteger(Pred);
14305 ID.AddPointer(LHS);
14306 ID.AddPointer(RHS);
14307 void *IP = nullptr;
14308 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14309 return S;
14310 SCEVComparePredicate *Eq = new (SCEVAllocator)
14311 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14312 UniquePreds.InsertNode(Eq, IP);
14313 return Eq;
14314 }
14316 const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14317 const SCEVAddRecExpr *AR,
14318 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14319 FoldingSetNodeID ID;
14320 // Unique this node based on the arguments
14321 ID.AddInteger(SCEVPredicate::P_Wrap);
14322 ID.AddPointer(AR);
14323 ID.AddInteger(AddedFlags);
14324 void *IP = nullptr;
14325 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14326 return S;
14327 auto *OF = new (SCEVAllocator)
14328 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14329 UniquePreds.InsertNode(OF, IP);
14330 return OF;
14331 }
14333 namespace {
14335 class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14336 public:
14338 /// Rewrites \p S in the context of a loop L and the SCEV predication
14339 /// infrastructure.
14341 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14342 /// equivalences present in \p Pred.
14344 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14345 /// \p NewPreds such that the result will be an AddRecExpr.
14346 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14347 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
14348 const SCEVPredicate *Pred) {
14349 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14350 return Rewriter.visit(S);
14353 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14354 if (Pred) {
14355 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14356 for (const auto *Pred : U->getPredicates())
14357 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14358 if (IPred->getLHS() == Expr &&
14359 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14360 return IPred->getRHS();
14361 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14362 if (IPred->getLHS() == Expr &&
14363 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14364 return IPred->getRHS();
14365 }
14366 }
14367 return convertToAddRecWithPreds(Expr);
14368 }
14370 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14371 const SCEV *Operand = visit(Expr->getOperand());
14372 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14373 if (AR && AR->getLoop() == L && AR->isAffine()) {
14374 // This couldn't be folded because the operand didn't have the nuw
14375 // flag. Add the nusw flag as an assumption that we could make.
14376 const SCEV *Step = AR->getStepRecurrence(SE);
14377 Type *Ty = Expr->getType();
14378 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14379 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14380 SE.getSignExtendExpr(Step, Ty), L,
14381 AR->getNoWrapFlags());
14382 }
14383 return SE.getZeroExtendExpr(Operand, Expr->getType());
14384 }
14386 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14387 const SCEV *Operand = visit(Expr->getOperand());
14388 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14389 if (AR && AR->getLoop() == L && AR->isAffine()) {
14390 // This couldn't be folded because the operand didn't have the nsw
14391 // flag. Add the nssw flag as an assumption that we could make.
14392 const SCEV *Step = AR->getStepRecurrence(SE);
14393 Type *Ty = Expr->getType();
14394 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14395 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14396 SE.getSignExtendExpr(Step, Ty), L,
14397 AR->getNoWrapFlags());
14398 }
14399 return SE.getSignExtendExpr(Operand, Expr->getType());
14400 }
14402 private:
14403 explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
14404 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
14405 const SCEVPredicate *Pred)
14406 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14408 bool addOverflowAssumption(const SCEVPredicate *P) {
14409 if (!NewPreds) {
14410 // Check if we've already made this assumption.
14411 return Pred && Pred->implies(P);
14412 }
14413 NewPreds->insert(P);
14414 return true;
14415 }
14417 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14418 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14419 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14420 return addOverflowAssumption(A);
14421 }
14423 // If \p Expr represents a PHINode, we try to see if it can be represented
14424 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14425 // to add this predicate as a runtime overflow check, we return the AddRec.
14426 // If \p Expr does not meet these conditions (is not a PHI node, or we
14427 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14428 // return \p Expr.
14429 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14430 if (!isa<PHINode>(Expr->getValue()))
14431 return Expr;
14432 std::optional<
14433 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14434 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14435 if (!PredicatedRewrite)
14436 return Expr;
14437 for (const auto *P : PredicatedRewrite->second) {
14438 // Wrap predicates from outer loops are not supported.
14439 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14440 if (L != WP->getExpr()->getLoop())
14441 return Expr;
14442 }
14443 if (!addOverflowAssumption(P))
14444 return Expr;
14445 }
14446 return PredicatedRewrite->first;
14447 }
14449 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
14450 const SCEVPredicate *Pred;
14451 const Loop *L;
14452 };
14454 } // end anonymous namespace
const SCEV *
ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
                                       const SCEVPredicate &Preds) {
  return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
}

const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
    const SCEV *S, const Loop *L,
    SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
  SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
  S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
  auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
  if (!AddRec)
    return nullptr;

  // Since the transformation was successful, we can now transfer the SCEV
  // predicates.
  for (const auto *P : TransformPreds)
    Preds.insert(P);

  return AddRec;
}

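// Illustrative use (mirroring PredicatedScalarEvolution::getAsAddRec below):
//   SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
//   if (auto *AR = SE.convertSCEVToAddRecWithPredicates(S, L, NewPreds))
//     ; // AR is valid provided runtime checks for NewPreds are emitted.
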
/// SCEV predicates
SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
                             SCEVPredicateKind Kind)
    : FastID(ID), Kind(Kind) {}

SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
                                           const ICmpInst::Predicate Pred,
                                           const SCEV *LHS, const SCEV *RHS)
    : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
  assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
  assert(LHS != RHS && "LHS and RHS are the same SCEV");
}

bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
  const auto *Op = dyn_cast<SCEVComparePredicate>(N);

  if (!Op)
    return false;

  if (Pred != ICmpInst::ICMP_EQ)
    return false;

  return Op->LHS == LHS && Op->RHS == RHS;
}

bool SCEVComparePredicate::isAlwaysTrue() const { return false; }

void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
  if (Pred == ICmpInst::ICMP_EQ)
    OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
  else
    OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
                     << *RHS << "\n";
}

SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
                                     const SCEVAddRecExpr *AR,
                                     IncrementWrapFlags Flags)
    : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}

const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }

bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
  const auto *Op = dyn_cast<SCEVWrapPredicate>(N);

  return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
}

bool SCEVWrapPredicate::isAlwaysTrue() const {
  SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
  IncrementWrapFlags IFlags = Flags;

  if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
    IFlags = clearFlags(IFlags, IncrementNSSW);

  return IFlags == IncrementAnyWrap;
}

void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
  OS.indent(Depth) << *getExpr() << " Added Flags: ";
  if (SCEVWrapPredicate::IncrementNUSW & getFlags())
    OS << "<nusw>";
  if (SCEVWrapPredicate::IncrementNSSW & getFlags())
    OS << "<nssw>";
  OS << "\n";
}

SCEVWrapPredicate::IncrementWrapFlags
SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
                                   ScalarEvolution &SE) {
  IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
  SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();

  // We can safely transfer the NSW flag as NSSW.
  if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
    ImpliedFlags = IncrementNSSW;

  if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
    // If the increment is positive, the SCEV NUW flag will also imply the
    // WrapPredicate NUSW flag.
    if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
      if (Step->getValue()->getValue().isNonNegative())
        ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
  }

  return ImpliedFlags;
}

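// E.g., for {0,+,1}<nuw><nsw><%L> the NSW flag yields IncrementNSSW, and
// because the flags also include NUW with a non-negative constant step, the
// result additionally carries IncrementNUSW.
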
/// Union predicates don't get cached so create a dummy set ID for it.
SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
    : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
  for (const auto *P : Preds)
    add(P);
}

bool SCEVUnionPredicate::isAlwaysTrue() const {
  return all_of(Preds,
                [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
}

bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
  if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
    return all_of(Set->Preds,
                  [this](const SCEVPredicate *I) { return this->implies(I); });

  return any_of(Preds,
                [N](const SCEVPredicate *I) { return I->implies(N); });
}

void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
  for (const auto *Pred : Preds)
    Pred->print(OS, Depth);
}

void SCEVUnionPredicate::add(const SCEVPredicate *N) {
  if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
    for (const auto *Pred : Set->Preds)
      add(Pred);
    return;
  }

  Preds.push_back(N);
}

PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
                                                     Loop &L)
    : SE(SE), L(L) {
  SmallVector<const SCEVPredicate *, 4> Empty;
  Preds = std::make_unique<SCEVUnionPredicate>(Empty);
}

void ScalarEvolution::registerUser(const SCEV *User,
                                   ArrayRef<const SCEV *> Ops) {
  for (const auto *Op : Ops)
    // We do not expect that forgetting cached data for SCEVConstants will ever
    // open any prospects for sharpening or introduce any correctness issues,
    // so we don't bother storing their dependencies.
    if (!isa<SCEVConstant>(Op))
      SCEVUsers[Op].insert(User);
}

const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
  const SCEV *Expr = SE.getSCEV(V);
  RewriteEntry &Entry = RewriteMap[Expr];

  // If we already have an entry and the version matches, return it.
  if (Entry.second && Generation == Entry.first)
    return Entry.second;

  // We found an entry but it's stale. Rewrite the stale entry
  // according to the current predicate.
  if (Entry.second)
    Expr = Entry.second;

  const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
  Entry = {Generation, NewSCEV};

  return NewSCEV;
}

const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
  if (!BackedgeCount) {
    SmallVector<const SCEVPredicate *, 4> Preds;
    BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
    for (const auto *P : Preds)
      addPredicate(*P);
  }
  return BackedgeCount;
}

void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
  if (Preds->implies(&Pred))
    return;

  auto &OldPreds = Preds->getPredicates();
  SmallVector<const SCEVPredicate *, 4> NewPreds(OldPreds.begin(),
                                                 OldPreds.end());
  NewPreds.push_back(&Pred);
  Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
  updateGeneration();
}

const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
  return *Preds;
}

void PredicatedScalarEvolution::updateGeneration() {
  // If the generation number wrapped, recompute everything.
  if (++Generation == 0) {
    for (auto &II : RewriteMap) {
      const SCEV *Rewritten = II.second.second;
      II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
    }
  }
}

void PredicatedScalarEvolution::setNoOverflow(
    Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
  const SCEV *Expr = getSCEV(V);
  const auto *AR = cast<SCEVAddRecExpr>(Expr);

  auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);

  // Clear the statically implied flags.
  Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
  addPredicate(*SE.getWrapPredicate(AR, Flags));

  auto II = FlagsMap.insert({V, Flags});
  if (!II.second)
    II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
}

bool PredicatedScalarEvolution::hasNoOverflow(
    Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
  const SCEV *Expr = getSCEV(V);
  const auto *AR = cast<SCEVAddRecExpr>(Expr);

  Flags = SCEVWrapPredicate::clearFlags(
      Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));

  auto II = FlagsMap.find(V);

  if (II != FlagsMap.end())
    Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);

  return Flags == SCEVWrapPredicate::IncrementAnyWrap;
}

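// Illustratively, a client such as the loop vectorizer can record
//   PSE.setNoOverflow(IV, SCEVWrapPredicate::IncrementNUSW);
// and a later PSE.hasNoOverflow(IV, SCEVWrapPredicate::IncrementNUSW)
// returns true, either because the flag is statically implied or because it
// was added as a runtime predicate above.
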
const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
  const SCEV *Expr = this->getSCEV(V);
  SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
  auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);

  if (!New)
    return nullptr;

  for (const auto *P : NewPreds)
    addPredicate(*P);

  RewriteMap[SE.getSCEV(V)] = {Generation, New};
  return New;
}

PredicatedScalarEvolution::PredicatedScalarEvolution(
    const PredicatedScalarEvolution &Init)
    : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
      Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
      Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
  for (auto I : Init.FlagsMap)
    FlagsMap.insert(I);
}

void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
  // For each block.
  for (auto *BB : L.getBlocks())
    for (auto &I : *BB) {
      if (!SE.isSCEVable(I.getType()))
        continue;

      auto *Expr = SE.getSCEV(&I);
      auto II = RewriteMap.find(Expr);

      if (II == RewriteMap.end())
        continue;

      // Don't print things that are not interesting.
      if (II->second.second == Expr)
        continue;

      OS.indent(Depth) << "[PSE]" << I << ":\n";
      OS.indent(Depth + 2) << *Expr << "\n";
      OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
    }
}

// Match the mathematical pattern A - (A / B) * B, where A and B can be
// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
// for URem with constant power-of-2 second operands.
// It's not always easy, as A and B can be folded (imagine A is X / 2, and B
// is 4, so A / B becomes X / 8).
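// E.g., for an i64 %x, %x urem 5 is built as ((-5 * (%x /u 5)) + %x), which
// the Mul cases below recover with LHS = %x and RHS = 5, while %x urem 8
// (a power of 2) is built as (zext i3 (trunc i64 %x to i3) to i64) and is
// handled by the zext/trunc match at the top of the function.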
bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
                                const SCEV *&RHS) {
  // Try to match 'zext (trunc A to iB) to iY', which is used
  // for URem with constant power-of-2 second operands. Make sure the size of
  // the operand A matches the size of the whole expression.
  if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
    if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
      LHS = Trunc->getOperand();
      // Bail out if the type of the LHS is larger than the type of the
      // expression for now.
      if (getTypeSizeInBits(LHS->getType()) >
          getTypeSizeInBits(Expr->getType()))
        return false;
      if (LHS->getType() != Expr->getType())
        LHS = getZeroExtendExpr(LHS, Expr->getType());
      RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
                        << getTypeSizeInBits(Trunc->getType()));
      return true;
    }
  const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
  if (Add == nullptr || Add->getNumOperands() != 2)
    return false;

  const SCEV *A = Add->getOperand(1);
  const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));

  if (Mul == nullptr)
    return false;

  const auto MatchURemWithDivisor = [&](const SCEV *B) {
    // (SomeExpr + (-(SomeExpr / B) * B)).
    if (Expr == getURemExpr(A, B)) {
      LHS = A;
      RHS = B;
      return true;
    }
    return false;
  };

  // (SomeExpr + (-1 * (SomeExpr / B) * B)).
  if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
    return MatchURemWithDivisor(Mul->getOperand(1)) ||
           MatchURemWithDivisor(Mul->getOperand(2));

  // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
  if (Mul->getNumOperands() == 2)
    return MatchURemWithDivisor(Mul->getOperand(1)) ||
           MatchURemWithDivisor(Mul->getOperand(0)) ||
           MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
           MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
  return false;
}

const SCEV *
ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) {
  SmallVector<BasicBlock *, 16> ExitingBlocks;
  L->getExitingBlocks(ExitingBlocks);

  // Form an expression for the maximum exit count possible for this loop. We
  // merge the max and exact information to approximate a version of
  // getConstantMaxBackedgeTakenCount which isn't restricted to just constants.
  SmallVector<const SCEV *, 4> ExitCounts;
  for (BasicBlock *ExitingBB : ExitingBlocks) {
    const SCEV *ExitCount =
        getExitCount(L, ExitingBB, ScalarEvolution::SymbolicMaximum);
    if (!isa<SCEVCouldNotCompute>(ExitCount)) {
      assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
             "We should only have known counts for exiting blocks that "
             "dominate latch!");
      ExitCounts.push_back(ExitCount);
    }
  }
  if (ExitCounts.empty())
    return getCouldNotCompute();
  return getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
}

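// E.g., for a loop with two analyzable exits whose counts are %n and 100,
// the symbolic maximum backedge-taken count is umin_seq(%n, 100).
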
/// A rewriter to replace SCEV expressions in Map with the corresponding entry
/// in the map. It skips AddRecExpr because we cannot guarantee that the
/// replacement is loop invariant in the loop of the AddRec.
class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
  const DenseMap<const SCEV *, const SCEV *> &Map;

public:
  SCEVLoopGuardRewriter(ScalarEvolution &SE,
                        DenseMap<const SCEV *, const SCEV *> &M)
      : SCEVRewriteVisitor(SE), Map(M) {}

  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }

  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
    auto I = Map.find(Expr);
    if (I == Map.end())
      return Expr;
    return I->second;
  }

  const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
    auto I = Map.find(Expr);
    if (I == Map.end()) {
      // If we didn't find the exact ZExt expr in the map, check if there's an
      // entry for a smaller ZExt we can use instead.
      Type *Ty = Expr->getType();
      const SCEV *Op = Expr->getOperand(0);
      unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
      while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
             Bitwidth > Op->getType()->getScalarSizeInBits()) {
        Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
        auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
        auto I = Map.find(NarrowExt);
        if (I != Map.end())
          return SE.getZeroExtendExpr(I->second, Ty);
        Bitwidth = Bitwidth / 2;
      }

      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
          Expr);
    }
    return I->second;
  }

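  // E.g., when rewriting (zext i8 %x to i64) and the map only has an entry
  // for (zext i8 %x to i32), the narrower entry is found at bitwidth 32 and
  // the result is that mapped value zero-extended to i64.
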
  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
    auto I = Map.find(Expr);
    if (I == Map.end())
      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
          Expr);
    return I->second;
  }

  const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
    auto I = Map.find(Expr);
    if (I == Map.end())
      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
    return I->second;
  }

  const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
    auto I = Map.find(Expr);
    if (I == Map.end())
      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
    return I->second;
  }
};

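// E.g., given a dominating guard 'if (%n > 0)' on entry to a loop %L,
// applyLoopGuards rewrites the SCEV %n to smax(%n, 1), which can in turn
// tighten the trip count computed for %L.
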
const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
  SmallVector<const SCEV *> ExprsToRewrite;
  auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
                              const SCEV *RHS,
                              DenseMap<const SCEV *, const SCEV *>
                                  &RewriteMap) {
    // WARNING: It is generally unsound to apply any wrap flags to the proposed
    // replacement SCEV which isn't directly implied by the structure of that
    // SCEV. In particular, using contextual facts to imply flags is *NOT*
    // legal. See the scoping rules for flags in the header to understand why.

    // If LHS is a constant, apply information to the other expression.
    if (isa<SCEVConstant>(LHS)) {
      std::swap(LHS, RHS);
      Predicate = CmpInst::getSwappedPredicate(Predicate);
    }

    // Check for a condition of the form (-C1 + X < C2). InstCombine will
    // create this form when combining two checks of the form (X u< C2 + C1)
    // and (X >=u C1).
    auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap,
                                 &ExprsToRewrite]() {
      auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
      if (!AddExpr || AddExpr->getNumOperands() != 2)
        return false;

      auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
      auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
      auto *C2 = dyn_cast<SCEVConstant>(RHS);
      if (!C1 || !C2 || !LHSUnknown)
        return false;

      auto ExactRegion =
          ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
              .sub(C1->getAPInt());

      // Bail out, unless we have a non-wrapping, monotonic range.
      if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
        return false;
      auto I = RewriteMap.find(LHSUnknown);
      const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
      RewriteMap[LHSUnknown] = getUMaxExpr(
          getConstant(ExactRegion.getUnsignedMin()),
          getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax())));
      ExprsToRewrite.push_back(LHSUnknown);
      return true;
    };
    if (MatchRangeCheckIdiom())
      return;

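    // E.g., the pair of checks (%x u>= 4) && (%x u< 14) is combined by
    // InstCombine into ((-4 + %x) u< 10); the exact region recovered by
    // MatchRangeCheckIdiom above is [4, 14), so %x is rewritten to
    // umax(4, umin(%x, 13)).
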
    // Return true if \p Expr is a MinMax SCEV expression with a non-negative
    // constant operand. If so, return in \p SCTy the SCEV type, in \p LHS the
    // constant operand, and in \p RHS the non-constant operand.
    auto IsMinMaxSCEVWithNonNegativeConstant =
        [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
            const SCEV *&RHS) {
          if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
            if (MinMax->getNumOperands() != 2)
              return false;
            if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
              if (C->getAPInt().isNegative())
                return false;
              SCTy = MinMax->getSCEVType();
              LHS = MinMax->getOperand(0);
              RHS = MinMax->getOperand(1);
              return true;
            }
          }
          return false;
        };

    // Checks whether Expr is a non-negative constant and Divisor is a positive
    // constant, and returns their APInt in ExprVal and in DivisorVal.
    auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
                                          APInt &ExprVal, APInt &DivisorVal) {
      auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
      auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
      if (!ConstExpr || !ConstDivisor)
        return false;
      ExprVal = ConstExpr->getAPInt();
      DivisorVal = ConstDivisor->getAPInt();
      return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
    };

    // Return a new SCEV that modifies \p Expr to the closest number that
    // divides by \p Divisor and is greater than or equal to \p Expr.
    // For now, only handle constant Expr and Divisor.
    auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
                                           const SCEV *Divisor) {
      APInt ExprVal;
      APInt DivisorVal;
      if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
        return Expr;
      APInt Rem = ExprVal.urem(DivisorVal);
      if (!Rem.isZero())
        // return the SCEV: Expr + Divisor - Expr % Divisor
        return getConstant(ExprVal + DivisorVal - Rem);
      return Expr;
    };

    // Return a new SCEV that modifies \p Expr to the closest number that
    // divides by \p Divisor and is less than or equal to \p Expr.
    // For now, only handle constant Expr and Divisor.
    auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
                                               const SCEV *Divisor) {
      APInt ExprVal;
      APInt DivisorVal;
      if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
        return Expr;
      APInt Rem = ExprVal.urem(DivisorVal);
      // return the SCEV: Expr - Expr % Divisor
      return getConstant(ExprVal - Rem);
    };

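    // E.g., with Expr = 5 and Divisor = 8, GetNextSCEVDividesByDivisor
    // returns 8 (5 + 8 - 5) and GetPreviousSCEVDividesByDivisor returns 0
    // (5 - 5).
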
    // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
    // recursively. This is done by aligning up/down the constant value to the
    // Divisor.
    std::function<const SCEV *(const SCEV *, const SCEV *)>
        ApplyDivisibilityOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
                                            const SCEV *Divisor) {
          const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
          SCEVTypes SCTy;
          if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
                                                   MinMaxRHS))
            return MinMaxExpr;
          auto IsMin =
              isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
          assert(isKnownNonNegative(MinMaxLHS) &&
                 "Expected non-negative operand!");
          auto *DivisibleExpr =
              IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
                    : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
          SmallVector<const SCEV *> Ops = {
              ApplyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor),
              DivisibleExpr};
          return getMinMaxExpr(SCTy, Ops);
        };

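    // E.g., applying divisibility by 8 to umin(13, %a) aligns the constant
    // down to 8 (it is a min), yielding umin(%a, 8).
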
    // If we have LHS == 0, check if LHS is computing a property of some
    // unknown SCEV %v which we can rewrite %v to express explicitly.
    const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
    if (Predicate == CmpInst::ICMP_EQ && RHSC &&
        RHSC->getValue()->isNullValue()) {
      // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
      // explicitly express that.
      const SCEV *URemLHS = nullptr;
      const SCEV *URemRHS = nullptr;
      if (matchURem(LHS, URemLHS, URemRHS)) {
        if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
          auto I = RewriteMap.find(LHSUnknown);
          const SCEV *RewrittenLHS =
              I != RewriteMap.end() ? I->second : LHSUnknown;
          RewrittenLHS = ApplyDivisibilityOnMinMaxExpr(RewrittenLHS, URemRHS);
          const auto *Multiple =
              getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
          RewriteMap[LHSUnknown] = Multiple;
          ExprsToRewrite.push_back(LHSUnknown);
          return;
        }
      }
    }

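    // E.g., the guard (%x urem 8) == 0 rewrites %x to (%x /u 8) * 8, making
    // the divisibility of %x by 8 visible to later simplifications.
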
    // Do not apply information for constants or if RHS contains an AddRec.
    if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
      return;

    // If RHS is SCEVUnknown, make sure the information is applied to it.
    if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
      std::swap(LHS, RHS);
      Predicate = CmpInst::getSwappedPredicate(Predicate);
    }

    // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
    // and \p FromRewritten are the same (i.e. there has been no rewrite
    // registered for \p From), then puts this value in the list of rewritten
    // expressions.
    auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
                          const SCEV *To) {
      if (From == FromRewritten)
        ExprsToRewrite.push_back(From);
      RewriteMap[From] = To;
    };

    // Checks whether \p S has already been rewritten. In that case returns the
    // existing rewrite because we want to chain further rewrites onto the
    // already rewritten value. Otherwise returns \p S.
    auto GetMaybeRewritten = [&](const SCEV *S) {
      auto I = RewriteMap.find(S);
      return I != RewriteMap.end() ? I->second : S;
    };

    // Check for the SCEV expression (A /u B) * B while B is a constant, inside
    // \p Expr. The check is done recursively on \p Expr, which is assumed to
    // be a composition of Min/Max SCEVs. Return whether the SCEV expression
    // (A /u B) * B was found, and return the divisor B in \p DividesBy. For
    // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
    // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in
    // \p DividesBy.
    std::function<bool(const SCEV *, const SCEV *&)> HasDivisibilityInfo =
        [&](const SCEV *Expr, const SCEV *&DividesBy) {
          if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
            if (Mul->getNumOperands() != 2)
              return false;
            auto *MulLHS = Mul->getOperand(0);
            auto *MulRHS = Mul->getOperand(1);
            if (isa<SCEVConstant>(MulLHS))
              std::swap(MulLHS, MulRHS);
            if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
              if (Div->getOperand(1) == MulRHS) {
                DividesBy = MulRHS;
                return true;
              }
          }
          if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
            return HasDivisibilityInfo(MinMax->getOperand(0), DividesBy) ||
                   HasDivisibilityInfo(MinMax->getOperand(1), DividesBy);
          return false;
        };

    // Return true if \p Expr is known to be divisible by \p DividesBy.
    std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
        [&](const SCEV *Expr, const SCEV *DividesBy) {
          if (getURemExpr(Expr, DividesBy)->isZero())
            return true;
          if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
            return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
                   IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
          return false;
        };

    const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
    const SCEV *DividesBy = nullptr;
    if (HasDivisibilityInfo(RewrittenLHS, DividesBy))
      // Check that the whole expression is divided by DividesBy.
      DividesBy =
          IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;

    // Collect rewrites for LHS and its transitive operands based on the
    // condition.
    // For min/max expressions, also apply the guard to its operands:
    // 'min(a, b) >= c'  ->  '(a >= c) and (b >= c)',
    // 'min(a, b) > c'   ->  '(a > c) and (b > c)',
    // 'max(a, b) <= c'  ->  '(a <= c) and (b <= c)',
    // 'max(a, b) < c'   ->  '(a < c) and (b < c)'.

    // We cannot express strict predicates in SCEV, so instead we replace them
    // with non-strict ones against plus or minus one of RHS depending on the
    // predicate.
    const SCEV *One = getOne(RHS->getType());
    switch (Predicate) {
    case CmpInst::ICMP_ULT:
      if (RHS->getType()->isPointerTy())
        return;
      RHS = getUMaxExpr(RHS, One);
      [[fallthrough]];
    case CmpInst::ICMP_SLT: {
      RHS = getMinusSCEV(RHS, One);
      RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
      break;
    }
    case CmpInst::ICMP_UGT:
    case CmpInst::ICMP_SGT:
      RHS = getAddExpr(RHS, One);
      RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
      break;
    case CmpInst::ICMP_ULE:
    case CmpInst::ICMP_SLE:
      RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
      break;
    case CmpInst::ICMP_UGE:
    case CmpInst::ICMP_SGE:
      RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
      break;
    default:
      break;
    }

    SmallVector<const SCEV *, 16> Worklist(1, LHS);
    SmallPtrSet<const SCEV *, 16> Visited;

    auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
      append_range(Worklist, S->operands());
    };

    while (!Worklist.empty()) {
      const SCEV *From = Worklist.pop_back_val();
      if (isa<SCEVConstant>(From))
        continue;
      if (!Visited.insert(From).second)
        continue;
      const SCEV *FromRewritten = GetMaybeRewritten(From);
      const SCEV *To = nullptr;

      switch (Predicate) {
      case CmpInst::ICMP_ULT:
      case CmpInst::ICMP_ULE:
        To = getUMinExpr(FromRewritten, RHS);
        if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
          EnqueueOperands(UMax);
        break;
      case CmpInst::ICMP_SLT:
      case CmpInst::ICMP_SLE:
        To = getSMinExpr(FromRewritten, RHS);
        if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
          EnqueueOperands(SMax);
        break;
      case CmpInst::ICMP_UGT:
      case CmpInst::ICMP_UGE:
        To = getUMaxExpr(FromRewritten, RHS);
        if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
          EnqueueOperands(UMin);
        break;
      case CmpInst::ICMP_SGT:
      case CmpInst::ICMP_SGE:
        To = getSMaxExpr(FromRewritten, RHS);
        if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
          EnqueueOperands(SMin);
        break;
      case CmpInst::ICMP_EQ:
        if (isa<SCEVConstant>(RHS))
          To = RHS;
        break;
      case CmpInst::ICMP_NE:
        if (isa<SCEVConstant>(RHS) &&
            cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
          const SCEV *OneAlignedUp =
              DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
          To = getUMaxExpr(FromRewritten, OneAlignedUp);
        }
        break;
      default:
        break;
      }

      if (To)
        AddRewrite(From, FromRewritten, To);
    }
  };

  BasicBlock *Header = L->getHeader();
  SmallVector<PointerIntPair<Value *, 1, bool>> Terms;
  // First, collect information from assumptions dominating the loop.
  for (auto &AssumeVH : AC.assumptions()) {
    if (!AssumeVH)
      continue;
    auto *AssumeI = cast<CallInst>(AssumeVH);
    if (!DT.dominates(AssumeI, Header))
      continue;
    Terms.emplace_back(AssumeI->getOperand(0), true);
  }

  // Second, collect information from llvm.experimental.guards dominating the
  // loop.
  auto *GuardDecl = F.getParent()->getFunction(
      Intrinsic::getName(Intrinsic::experimental_guard));
  if (GuardDecl)
    for (const auto *GU : GuardDecl->users())
      if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
        if (Guard->getFunction() == Header->getParent() &&
            DT.dominates(Guard, Header))
          Terms.emplace_back(Guard->getArgOperand(0), true);

  // Third, collect conditions from dominating branches. Starting at the loop
  // predecessor, climb up the predecessor chain, as long as there are
  // predecessors that can be found that have unique successors leading to the
  // original header.
  // TODO: share this logic with isLoopEntryGuardedByCond.
  for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
           L->getLoopPredecessor(), Header);
       Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
    const BranchInst *LoopEntryPredicate =
        dyn_cast<BranchInst>(Pair.first->getTerminator());
    if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
      continue;

    Terms.emplace_back(LoopEntryPredicate->getCondition(),
                       LoopEntryPredicate->getSuccessor(0) == Pair.second);
  }

  // Now apply the information from the collected conditions to RewriteMap.
  // Conditions are processed in reverse order, so the earliest conditions are
  // processed first. This ensures the SCEVs with the shortest dependency
  // chains are constructed first.
  DenseMap<const SCEV *, const SCEV *> RewriteMap;
  for (auto [Term, EnterIfTrue] : reverse(Terms)) {
    SmallVector<Value *, 8> Worklist;
    SmallPtrSet<Value *, 8> Visited;
    Worklist.push_back(Term);
    while (!Worklist.empty()) {
      Value *Cond = Worklist.pop_back_val();
      if (!Visited.insert(Cond).second)
        continue;

      if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
        auto Predicate =
            EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
        const auto *LHS = getSCEV(Cmp->getOperand(0));
        const auto *RHS = getSCEV(Cmp->getOperand(1));
        CollectCondition(Predicate, LHS, RHS, RewriteMap);
        continue;
      }

      Value *L, *R;
      if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
                      : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
        Worklist.push_back(L);
        Worklist.push_back(R);
      }
    }
  }

  if (RewriteMap.empty())
    return Expr;

  // Now that all rewrite information is collected, rewrite the collected
  // expressions with the information in the map. This applies information to
  // sub-expressions.
  if (ExprsToRewrite.size() > 1) {
    for (const SCEV *Expr : ExprsToRewrite) {
      const SCEV *RewriteTo = RewriteMap[Expr];
      RewriteMap.erase(Expr);
      SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
      RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
    }
  }

  SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
  return Rewriter.visit(Expr);
}