Home | History | Annotate | Line # | Download | only in Analysis
      1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // This file defines the classes used to represent and build scalar expressions.
     10 //
     11 //===----------------------------------------------------------------------===//
     12 
     13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
     14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
     15 
     16 #include "llvm/ADT/DenseMap.h"
     17 #include "llvm/ADT/FoldingSet.h"
     18 #include "llvm/ADT/SmallPtrSet.h"
     19 #include "llvm/ADT/SmallVector.h"
     20 #include "llvm/ADT/iterator_range.h"
     21 #include "llvm/Analysis/ScalarEvolution.h"
     22 #include "llvm/IR/Constants.h"
     23 #include "llvm/IR/Value.h"
     24 #include "llvm/IR/ValueHandle.h"
     25 #include "llvm/Support/Casting.h"
     26 #include "llvm/Support/ErrorHandling.h"
     27 #include <cassert>
     28 #include <cstddef>
     29 
     30 namespace llvm {
     31 
     32 class APInt;
     33 class Constant;
     34 class ConstantRange;
     35 class Loop;
     36 class Type;
     37 
     38   enum SCEVTypes : unsigned short {
     39     // These should be ordered in terms of increasing complexity to make the
     40     // folders simpler.
     41     scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr,
     42     scUDivExpr, scAddRecExpr, scUMaxExpr, scSMaxExpr, scUMinExpr, scSMinExpr,
     43     scPtrToInt, scUnknown, scCouldNotCompute
     44   };
     45 
     46   /// This class represents a constant integer value.
     47   class SCEVConstant : public SCEV {
     48     friend class ScalarEvolution;
     49 
     50     ConstantInt *V;
     51 
     52     SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) :
     53       SCEV(ID, scConstant, 1), V(v) {}
     54 
     55   public:
     56     ConstantInt *getValue() const { return V; }
     57     const APInt &getAPInt() const { return getValue()->getValue(); }
     58 
     59     Type *getType() const { return V->getType(); }
     60 
     61     /// Methods for support type inquiry through isa, cast, and dyn_cast:
     62     static bool classof(const SCEV *S) {
     63       return S->getSCEVType() == scConstant;
     64     }
     65   };
     66 
     67   inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
     68     APInt Size(16, 1);
     69     for (auto *Arg : Args)
     70       Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
     71     return (unsigned short)Size.getZExtValue();
     72   }
     73 
     74   /// This is the base class for unary cast operator classes.
     75   class SCEVCastExpr : public SCEV {
     76   protected:
     77     std::array<const SCEV *, 1> Operands;
     78     Type *Ty;
     79 
     80     SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op,
     81                  Type *ty);
     82 
     83   public:
     84     const SCEV *getOperand() const { return Operands[0]; }
     85     const SCEV *getOperand(unsigned i) const {
     86       assert(i == 0 && "Operand index out of range!");
     87       return Operands[0];
     88     }
     89     using op_iterator = std::array<const SCEV *, 1>::const_iterator;
     90     using op_range = iterator_range<op_iterator>;
     91 
     92     op_range operands() const {
     93       return make_range(Operands.begin(), Operands.end());
     94     }
     95     size_t getNumOperands() const { return 1; }
     96     Type *getType() const { return Ty; }
     97 
     98     /// Methods for support type inquiry through isa, cast, and dyn_cast:
     99     static bool classof(const SCEV *S) {
    100       return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
    101              S->getSCEVType() == scZeroExtend ||
    102              S->getSCEVType() == scSignExtend;
    103     }
    104   };
    105 
    106   /// This class represents a cast from a pointer to a pointer-sized integer
    107   /// value.
    108   class SCEVPtrToIntExpr : public SCEVCastExpr {
    109     friend class ScalarEvolution;
    110 
    111     SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
    112 
    113   public:
    114     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    115     static bool classof(const SCEV *S) {
    116       return S->getSCEVType() == scPtrToInt;
    117     }
    118   };
    119 
    120   /// This is the base class for unary integral cast operator classes.
    121   class SCEVIntegralCastExpr : public SCEVCastExpr {
    122   protected:
    123     SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
    124                          const SCEV *op, Type *ty);
    125 
    126   public:
    127     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    128     static bool classof(const SCEV *S) {
    129       return S->getSCEVType() == scTruncate ||
    130              S->getSCEVType() == scZeroExtend ||
    131              S->getSCEVType() == scSignExtend;
    132     }
    133   };
    134 
    135   /// This class represents a truncation of an integer value to a
    136   /// smaller integer value.
    137   class SCEVTruncateExpr : public SCEVIntegralCastExpr {
    138     friend class ScalarEvolution;
    139 
    140     SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
    141                      const SCEV *op, Type *ty);
    142 
    143   public:
    144     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    145     static bool classof(const SCEV *S) {
    146       return S->getSCEVType() == scTruncate;
    147     }
    148   };
    149 
    150   /// This class represents a zero extension of a small integer value
    151   /// to a larger integer value.
    152   class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
    153     friend class ScalarEvolution;
    154 
    155     SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
    156                        const SCEV *op, Type *ty);
    157 
    158   public:
    159     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    160     static bool classof(const SCEV *S) {
    161       return S->getSCEVType() == scZeroExtend;
    162     }
    163   };
    164 
    165   /// This class represents a sign extension of a small integer value
    166   /// to a larger integer value.
    167   class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
    168     friend class ScalarEvolution;
    169 
    170     SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
    171                        const SCEV *op, Type *ty);
    172 
    173   public:
    174     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    175     static bool classof(const SCEV *S) {
    176       return S->getSCEVType() == scSignExtend;
    177     }
    178   };
    179 
    180   /// This node is a base class providing common functionality for
    181   /// n'ary operators.
    182   class SCEVNAryExpr : public SCEV {
    183   protected:
    184     // Since SCEVs are immutable, ScalarEvolution allocates operand
    185     // arrays with its SCEVAllocator, so this class just needs a simple
    186     // pointer rather than a more elaborate vector-like data structure.
    187     // This also avoids the need for a non-trivial destructor.
    188     const SCEV *const *Operands;
    189     size_t NumOperands;
    190 
    191     SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
    192                  const SCEV *const *O, size_t N)
    193         : SCEV(ID, T, computeExpressionSize(makeArrayRef(O, N))), Operands(O),
    194           NumOperands(N) {}
    195 
    196   public:
    197     size_t getNumOperands() const { return NumOperands; }
    198 
    199     const SCEV *getOperand(unsigned i) const {
    200       assert(i < NumOperands && "Operand index out of range!");
    201       return Operands[i];
    202     }
    203 
    204     using op_iterator = const SCEV *const *;
    205     using op_range = iterator_range<op_iterator>;
    206 
    207     op_iterator op_begin() const { return Operands; }
    208     op_iterator op_end() const { return Operands + NumOperands; }
    209     op_range operands() const {
    210       return make_range(op_begin(), op_end());
    211     }
    212 
    213     Type *getType() const { return getOperand(0)->getType(); }
    214 
    215     NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
    216       return (NoWrapFlags)(SubclassData & Mask);
    217     }
    218 
    219     bool hasNoUnsignedWrap() const {
    220       return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
    221     }
    222 
    223     bool hasNoSignedWrap() const {
    224       return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
    225     }
    226 
    227     bool hasNoSelfWrap() const {
    228       return getNoWrapFlags(FlagNW) != FlagAnyWrap;
    229     }
    230 
    231     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    232     static bool classof(const SCEV *S) {
    233       return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
    234              S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
    235              S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
    236              S->getSCEVType() == scAddRecExpr;
    237     }
    238   };
    239 
    240   /// This node is the base class for n'ary commutative operators.
    241   class SCEVCommutativeExpr : public SCEVNAryExpr {
    242   protected:
    243     SCEVCommutativeExpr(const FoldingSetNodeIDRef ID,
    244                         enum SCEVTypes T, const SCEV *const *O, size_t N)
    245       : SCEVNAryExpr(ID, T, O, N) {}
    246 
    247   public:
    248     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    249     static bool classof(const SCEV *S) {
    250       return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
    251              S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
    252              S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
    253     }
    254 
    255     /// Set flags for a non-recurrence without clearing previously set flags.
    256     void setNoWrapFlags(NoWrapFlags Flags) {
    257       SubclassData |= Flags;
    258     }
    259   };
    260 
    261   /// This node represents an addition of some number of SCEVs.
    262   class SCEVAddExpr : public SCEVCommutativeExpr {
    263     friend class ScalarEvolution;
    264 
    265     Type *Ty;
    266 
    267     SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
    268         : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
    269       auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
    270         return Op->getType()->isPointerTy();
    271       });
    272       if (FirstPointerTypedOp != operands().end())
    273         Ty = (*FirstPointerTypedOp)->getType();
    274       else
    275         Ty = getOperand(0)->getType();
    276     }
    277 
    278   public:
    279     Type *getType() const { return Ty; }
    280 
    281     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    282     static bool classof(const SCEV *S) {
    283       return S->getSCEVType() == scAddExpr;
    284     }
    285   };
    286 
    287   /// This node represents multiplication of some number of SCEVs.
    288   class SCEVMulExpr : public SCEVCommutativeExpr {
    289     friend class ScalarEvolution;
    290 
    291     SCEVMulExpr(const FoldingSetNodeIDRef ID,
    292                 const SCEV *const *O, size_t N)
    293       : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
    294 
    295   public:
    296     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    297     static bool classof(const SCEV *S) {
    298       return S->getSCEVType() == scMulExpr;
    299     }
    300   };
    301 
    302   /// This class represents a binary unsigned division operation.
    303   class SCEVUDivExpr : public SCEV {
    304     friend class ScalarEvolution;
    305 
    306     std::array<const SCEV *, 2> Operands;
    307 
    308     SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
    309         : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) {
    310         Operands[0] = lhs;
    311         Operands[1] = rhs;
    312       }
    313 
    314   public:
    315     const SCEV *getLHS() const { return Operands[0]; }
    316     const SCEV *getRHS() const { return Operands[1]; }
    317     size_t getNumOperands() const { return 2; }
    318     const SCEV *getOperand(unsigned i) const {
    319       assert((i == 0 || i == 1) && "Operand index out of range!");
    320       return i == 0 ? getLHS() : getRHS();
    321     }
    322 
    323     using op_iterator = std::array<const SCEV *, 2>::const_iterator;
    324     using op_range = iterator_range<op_iterator>;
    325     op_range operands() const {
    326       return make_range(Operands.begin(), Operands.end());
    327     }
    328 
    329     Type *getType() const {
    330       // In most cases the types of LHS and RHS will be the same, but in some
    331       // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
    332       // depend on the type for correctness, but handling types carefully can
    333       // avoid extra casts in the SCEVExpander. The LHS is more likely to be
    334       // a pointer type than the RHS, so use the RHS' type here.
    335       return getRHS()->getType();
    336     }
    337 
    338     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    339     static bool classof(const SCEV *S) {
    340       return S->getSCEVType() == scUDivExpr;
    341     }
    342   };
    343 
    344   /// This node represents a polynomial recurrence on the trip count
    345   /// of the specified loop.  This is the primary focus of the
    346   /// ScalarEvolution framework; all the other SCEV subclasses are
    347   /// mostly just supporting infrastructure to allow SCEVAddRecExpr
    348   /// expressions to be created and analyzed.
    349   ///
    350   /// All operands of an AddRec are required to be loop invariant.
    351   ///
    352   class SCEVAddRecExpr : public SCEVNAryExpr {
    353     friend class ScalarEvolution;
    354 
    355     const Loop *L;
    356 
    357     SCEVAddRecExpr(const FoldingSetNodeIDRef ID,
    358                    const SCEV *const *O, size_t N, const Loop *l)
    359       : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
    360 
    361   public:
    362     const SCEV *getStart() const { return Operands[0]; }
    363     const Loop *getLoop() const { return L; }
    364 
    365     /// Constructs and returns the recurrence indicating how much this
    366     /// expression steps by.  If this is a polynomial of degree N, it
    367     /// returns a chrec of degree N-1.  We cannot determine whether
    368     /// the step recurrence has self-wraparound.
    369     const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
    370       if (isAffine()) return getOperand(1);
    371       return SE.getAddRecExpr(SmallVector<const SCEV *, 3>(op_begin()+1,
    372                                                            op_end()),
    373                               getLoop(), FlagAnyWrap);
    374     }
    375 
    376     /// Return true if this represents an expression A + B*x where A
    377     /// and B are loop invariant values.
    378     bool isAffine() const {
    379       // We know that the start value is invariant.  This expression is thus
    380       // affine iff the step is also invariant.
    381       return getNumOperands() == 2;
    382     }
    383 
    384     /// Return true if this represents an expression A + B*x + C*x^2
    385     /// where A, B and C are loop invariant values.  This corresponds
    386     /// to an addrec of the form {L,+,M,+,N}
    387     bool isQuadratic() const {
    388       return getNumOperands() == 3;
    389     }
    390 
    391     /// Set flags for a recurrence without clearing any previously set flags.
    392     /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
    393     /// to make it easier to propagate flags.
    394     void setNoWrapFlags(NoWrapFlags Flags) {
    395       if (Flags & (FlagNUW | FlagNSW))
    396         Flags = ScalarEvolution::setFlags(Flags, FlagNW);
    397       SubclassData |= Flags;
    398     }
    399 
    400     /// Return the value of this chain of recurrences at the specified
    401     /// iteration number.
    402     const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
    403 
    404     /// Return the number of iterations of this loop that produce
    405     /// values in the specified constant range.  Another way of
    406     /// looking at this is that it returns the first iteration number
    407     /// where the value is not in the condition, thus computing the
    408     /// exit count.  If the iteration count can't be computed, an
    409     /// instance of SCEVCouldNotCompute is returned.
    410     const SCEV *getNumIterationsInRange(const ConstantRange &Range,
    411                                         ScalarEvolution &SE) const;
    412 
    413     /// Return an expression representing the value of this expression
    414     /// one iteration of the loop ahead.
    415     const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
    416 
    417     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    418     static bool classof(const SCEV *S) {
    419       return S->getSCEVType() == scAddRecExpr;
    420     }
    421   };
    422 
    423   /// This node is the base class min/max selections.
    424   class SCEVMinMaxExpr : public SCEVCommutativeExpr {
    425     friend class ScalarEvolution;
    426 
    427     static bool isMinMaxType(enum SCEVTypes T) {
    428       return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
    429              T == scUMinExpr;
    430     }
    431 
    432   protected:
    433     /// Note: Constructing subclasses via this constructor is allowed
    434     SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
    435                    const SCEV *const *O, size_t N)
    436         : SCEVCommutativeExpr(ID, T, O, N) {
    437       assert(isMinMaxType(T));
    438       // Min and max never overflow
    439       setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
    440     }
    441 
    442   public:
    443     static bool classof(const SCEV *S) {
    444       return isMinMaxType(S->getSCEVType());
    445     }
    446 
    447     static enum SCEVTypes negate(enum SCEVTypes T) {
    448       switch (T) {
    449       case scSMaxExpr:
    450         return scSMinExpr;
    451       case scSMinExpr:
    452         return scSMaxExpr;
    453       case scUMaxExpr:
    454         return scUMinExpr;
    455       case scUMinExpr:
    456         return scUMaxExpr;
    457       default:
    458         llvm_unreachable("Not a min or max SCEV type!");
    459       }
    460     }
    461   };
    462 
    463   /// This class represents a signed maximum selection.
    464   class SCEVSMaxExpr : public SCEVMinMaxExpr {
    465     friend class ScalarEvolution;
    466 
    467     SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
    468         : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
    469 
    470   public:
    471     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    472     static bool classof(const SCEV *S) {
    473       return S->getSCEVType() == scSMaxExpr;
    474     }
    475   };
    476 
    477   /// This class represents an unsigned maximum selection.
    478   class SCEVUMaxExpr : public SCEVMinMaxExpr {
    479     friend class ScalarEvolution;
    480 
    481     SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
    482         : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
    483 
    484   public:
    485     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    486     static bool classof(const SCEV *S) {
    487       return S->getSCEVType() == scUMaxExpr;
    488     }
    489   };
    490 
    491   /// This class represents a signed minimum selection.
    492   class SCEVSMinExpr : public SCEVMinMaxExpr {
    493     friend class ScalarEvolution;
    494 
    495     SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
    496         : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
    497 
    498   public:
    499     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    500     static bool classof(const SCEV *S) {
    501       return S->getSCEVType() == scSMinExpr;
    502     }
    503   };
    504 
    505   /// This class represents an unsigned minimum selection.
    506   class SCEVUMinExpr : public SCEVMinMaxExpr {
    507     friend class ScalarEvolution;
    508 
    509     SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
    510         : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
    511 
    512   public:
    513     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    514     static bool classof(const SCEV *S) {
    515       return S->getSCEVType() == scUMinExpr;
    516     }
    517   };
    518 
    519   /// This means that we are dealing with an entirely unknown SCEV
    520   /// value, and only represent it as its LLVM Value.  This is the
    521   /// "bottom" value for the analysis.
    522   class SCEVUnknown final : public SCEV, private CallbackVH {
    523     friend class ScalarEvolution;
    524 
    525     /// The parent ScalarEvolution value. This is used to update the
    526     /// parent's maps when the value associated with a SCEVUnknown is
    527     /// deleted or RAUW'd.
    528     ScalarEvolution *SE;
    529 
    530     /// The next pointer in the linked list of all SCEVUnknown
    531     /// instances owned by a ScalarEvolution.
    532     SCEVUnknown *Next;
    533 
    534     SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V,
    535                 ScalarEvolution *se, SCEVUnknown *next) :
    536       SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
    537 
    538     // Implement CallbackVH.
    539     void deleted() override;
    540     void allUsesReplacedWith(Value *New) override;
    541 
    542   public:
    543     Value *getValue() const { return getValPtr(); }
    544 
    545     /// @{
    546     /// Test whether this is a special constant representing a type
    547     /// size, alignment, or field offset in a target-independent
    548     /// manner, and hasn't happened to have been folded with other
    549     /// operations into something unrecognizable. This is mainly only
    550     /// useful for pretty-printing and other situations where it isn't
    551     /// absolutely required for these to succeed.
    552     bool isSizeOf(Type *&AllocTy) const;
    553     bool isAlignOf(Type *&AllocTy) const;
    554     bool isOffsetOf(Type *&STy, Constant *&FieldNo) const;
    555     /// @}
    556 
    557     Type *getType() const { return getValPtr()->getType(); }
    558 
    559     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    560     static bool classof(const SCEV *S) {
    561       return S->getSCEVType() == scUnknown;
    562     }
    563   };
    564 
    565   /// This class defines a simple visitor class that may be used for
    566   /// various SCEV analysis purposes.
    567   template<typename SC, typename RetVal=void>
    568   struct SCEVVisitor {
    569     RetVal visit(const SCEV *S) {
    570       switch (S->getSCEVType()) {
    571       case scConstant:
    572         return ((SC*)this)->visitConstant((const SCEVConstant*)S);
    573       case scPtrToInt:
    574         return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
    575       case scTruncate:
    576         return ((SC*)this)->visitTruncateExpr((const SCEVTruncateExpr*)S);
    577       case scZeroExtend:
    578         return ((SC*)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr*)S);
    579       case scSignExtend:
    580         return ((SC*)this)->visitSignExtendExpr((const SCEVSignExtendExpr*)S);
    581       case scAddExpr:
    582         return ((SC*)this)->visitAddExpr((const SCEVAddExpr*)S);
    583       case scMulExpr:
    584         return ((SC*)this)->visitMulExpr((const SCEVMulExpr*)S);
    585       case scUDivExpr:
    586         return ((SC*)this)->visitUDivExpr((const SCEVUDivExpr*)S);
    587       case scAddRecExpr:
    588         return ((SC*)this)->visitAddRecExpr((const SCEVAddRecExpr*)S);
    589       case scSMaxExpr:
    590         return ((SC*)this)->visitSMaxExpr((const SCEVSMaxExpr*)S);
    591       case scUMaxExpr:
    592         return ((SC*)this)->visitUMaxExpr((const SCEVUMaxExpr*)S);
    593       case scSMinExpr:
    594         return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
    595       case scUMinExpr:
    596         return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
    597       case scUnknown:
    598         return ((SC*)this)->visitUnknown((const SCEVUnknown*)S);
    599       case scCouldNotCompute:
    600         return ((SC*)this)->visitCouldNotCompute((const SCEVCouldNotCompute*)S);
    601       }
    602       llvm_unreachable("Unknown SCEV kind!");
    603     }
    604 
    605     RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
    606       llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
    607     }
    608   };
    609 
    610   /// Visit all nodes in the expression tree using worklist traversal.
    611   ///
    612   /// Visitor implements:
    613   ///   // return true to follow this node.
    614   ///   bool follow(const SCEV *S);
    615   ///   // return true to terminate the search.
    616   ///   bool isDone();
    617   template<typename SV>
    618   class SCEVTraversal {
    619     SV &Visitor;
    620     SmallVector<const SCEV *, 8> Worklist;
    621     SmallPtrSet<const SCEV *, 8> Visited;
    622 
    623     void push(const SCEV *S) {
    624       if (Visited.insert(S).second && Visitor.follow(S))
    625         Worklist.push_back(S);
    626     }
    627 
    628   public:
    629     SCEVTraversal(SV& V): Visitor(V) {}
    630 
    631     void visitAll(const SCEV *Root) {
    632       push(Root);
    633       while (!Worklist.empty() && !Visitor.isDone()) {
    634         const SCEV *S = Worklist.pop_back_val();
    635 
    636         switch (S->getSCEVType()) {
    637         case scConstant:
    638         case scUnknown:
    639           continue;
    640         case scPtrToInt:
    641         case scTruncate:
    642         case scZeroExtend:
    643         case scSignExtend:
    644           push(cast<SCEVCastExpr>(S)->getOperand());
    645           continue;
    646         case scAddExpr:
    647         case scMulExpr:
    648         case scSMaxExpr:
    649         case scUMaxExpr:
    650         case scSMinExpr:
    651         case scUMinExpr:
    652         case scAddRecExpr:
    653           for (const auto *Op : cast<SCEVNAryExpr>(S)->operands())
    654             push(Op);
    655           continue;
    656         case scUDivExpr: {
    657           const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
    658           push(UDiv->getLHS());
    659           push(UDiv->getRHS());
    660           continue;
    661         }
    662         case scCouldNotCompute:
    663           llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
    664         }
    665         llvm_unreachable("Unknown SCEV kind!");
    666       }
    667     }
    668   };
    669 
    670   /// Use SCEVTraversal to visit all nodes in the given expression tree.
    671   template<typename SV>
    672   void visitAll(const SCEV *Root, SV& Visitor) {
    673     SCEVTraversal<SV> T(Visitor);
    674     T.visitAll(Root);
    675   }
    676 
    677   /// Return true if any node in \p Root satisfies the predicate \p Pred.
    678   template <typename PredTy>
    679   bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
    680     struct FindClosure {
    681       bool Found = false;
    682       PredTy Pred;
    683 
    684       FindClosure(PredTy Pred) : Pred(Pred) {}
    685 
    686       bool follow(const SCEV *S) {
    687         if (!Pred(S))
    688           return true;
    689 
    690         Found = true;
    691         return false;
    692       }
    693 
    694       bool isDone() const { return Found; }
    695     };
    696 
    697     FindClosure FC(Pred);
    698     visitAll(Root, FC);
    699     return FC.Found;
    700   }
    701 
    702   /// This visitor recursively visits a SCEV expression and re-writes it.
    703   /// The result from each visit is cached, so it will return the same
    704   /// SCEV for the same input.
    705   template<typename SC>
    706   class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
    707   protected:
    708     ScalarEvolution &SE;
    709     // Memoize the result of each visit so that we only compute once for
    710     // the same input SCEV. This is to avoid redundant computations when
    711     // a SCEV is referenced by multiple SCEVs. Without memoization, this
    712     // visit algorithm would have exponential time complexity in the worst
    713     // case, causing the compiler to hang on certain tests.
    714     DenseMap<const SCEV *, const SCEV *> RewriteResults;
    715 
    716   public:
    717     SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
    718 
    719     const SCEV *visit(const SCEV *S) {
    720       auto It = RewriteResults.find(S);
    721       if (It != RewriteResults.end())
    722         return It->second;
    723       auto* Visited = SCEVVisitor<SC, const SCEV *>::visit(S);
    724       auto Result = RewriteResults.try_emplace(S, Visited);
    725       assert(Result.second && "Should insert a new entry");
    726       return Result.first->second;
    727     }
    728 
    729     const SCEV *visitConstant(const SCEVConstant *Constant) {
    730       return Constant;
    731     }
    732 
    733     const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
    734       const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
    735       return Operand == Expr->getOperand()
    736                  ? Expr
    737                  : SE.getPtrToIntExpr(Operand, Expr->getType());
    738     }
    739 
    740     const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
    741       const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
    742       return Operand == Expr->getOperand()
    743                  ? Expr
    744                  : SE.getTruncateExpr(Operand, Expr->getType());
    745     }
    746 
    747     const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
    748       const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
    749       return Operand == Expr->getOperand()
    750                  ? Expr
    751                  : SE.getZeroExtendExpr(Operand, Expr->getType());
    752     }
    753 
    754     const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
    755       const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
    756       return Operand == Expr->getOperand()
    757                  ? Expr
    758                  : SE.getSignExtendExpr(Operand, Expr->getType());
    759     }
    760 
    761     const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
    762       SmallVector<const SCEV *, 2> Operands;
    763       bool Changed = false;
    764       for (auto *Op : Expr->operands()) {
    765         Operands.push_back(((SC*)this)->visit(Op));
    766         Changed |= Op != Operands.back();
    767       }
    768       return !Changed ? Expr : SE.getAddExpr(Operands);
    769     }
    770 
    771     const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
    772       SmallVector<const SCEV *, 2> Operands;
    773       bool Changed = false;
    774       for (auto *Op : Expr->operands()) {
    775         Operands.push_back(((SC*)this)->visit(Op));
    776         Changed |= Op != Operands.back();
    777       }
    778       return !Changed ? Expr : SE.getMulExpr(Operands);
    779     }
    780 
    781     const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
    782       auto *LHS = ((SC *)this)->visit(Expr->getLHS());
    783       auto *RHS = ((SC *)this)->visit(Expr->getRHS());
    784       bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
    785       return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
    786     }
    787 
    788     const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
    789       SmallVector<const SCEV *, 2> Operands;
    790       bool Changed = false;
    791       for (auto *Op : Expr->operands()) {
    792         Operands.push_back(((SC*)this)->visit(Op));
    793         Changed |= Op != Operands.back();
    794       }
    795       return !Changed ? Expr
    796                       : SE.getAddRecExpr(Operands, Expr->getLoop(),
    797                                          Expr->getNoWrapFlags());
    798     }
    799 
    800     const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
    801       SmallVector<const SCEV *, 2> Operands;
    802       bool Changed = false;
    803       for (auto *Op : Expr->operands()) {
    804         Operands.push_back(((SC *)this)->visit(Op));
    805         Changed |= Op != Operands.back();
    806       }
    807       return !Changed ? Expr : SE.getSMaxExpr(Operands);
    808     }
    809 
    810     const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
    811       SmallVector<const SCEV *, 2> Operands;
    812       bool Changed = false;
    813       for (auto *Op : Expr->operands()) {
    814         Operands.push_back(((SC*)this)->visit(Op));
    815         Changed |= Op != Operands.back();
    816       }
    817       return !Changed ? Expr : SE.getUMaxExpr(Operands);
    818     }
    819 
    820     const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
    821       SmallVector<const SCEV *, 2> Operands;
    822       bool Changed = false;
    823       for (auto *Op : Expr->operands()) {
    824         Operands.push_back(((SC *)this)->visit(Op));
    825         Changed |= Op != Operands.back();
    826       }
    827       return !Changed ? Expr : SE.getSMinExpr(Operands);
    828     }
    829 
    830     const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
    831       SmallVector<const SCEV *, 2> Operands;
    832       bool Changed = false;
    833       for (auto *Op : Expr->operands()) {
    834         Operands.push_back(((SC *)this)->visit(Op));
    835         Changed |= Op != Operands.back();
    836       }
    837       return !Changed ? Expr : SE.getUMinExpr(Operands);
    838     }
    839 
    840     const SCEV *visitUnknown(const SCEVUnknown *Expr) {
    841       return Expr;
    842     }
    843 
    844     const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
    845       return Expr;
    846     }
    847   };
    848 
    849   using ValueToValueMap = DenseMap<const Value *, Value *>;
    850   using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
    851 
    852   /// The SCEVParameterRewriter takes a scalar evolution expression and updates
    853   /// the SCEVUnknown components following the Map (Value -> SCEV).
    854   class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
    855   public:
    856     static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
    857                                ValueToSCEVMapTy &Map) {
    858       SCEVParameterRewriter Rewriter(SE, Map);
    859       return Rewriter.visit(Scev);
    860     }
    861 
    862     SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
    863         : SCEVRewriteVisitor(SE), Map(M) {}
    864 
    865     const SCEV *visitUnknown(const SCEVUnknown *Expr) {
    866       auto I = Map.find(Expr->getValue());
    867       if (I == Map.end())
    868         return Expr;
    869       return I->second;
    870     }
    871 
    872   private:
    873     ValueToSCEVMapTy &Map;
    874   };
    875 
    876   using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
    877 
    878   /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
    879   /// the Map (Loop -> SCEV) to all AddRecExprs.
    880   class SCEVLoopAddRecRewriter
    881       : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
    882   public:
    883     SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
    884         : SCEVRewriteVisitor(SE), Map(M) {}
    885 
    886     static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
    887                                ScalarEvolution &SE) {
    888       SCEVLoopAddRecRewriter Rewriter(SE, Map);
    889       return Rewriter.visit(Scev);
    890     }
    891 
    892     const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
    893       SmallVector<const SCEV *, 2> Operands;
    894       for (const SCEV *Op : Expr->operands())
    895         Operands.push_back(visit(Op));
    896 
    897       const Loop *L = Expr->getLoop();
    898       const SCEV *Res = SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
    899 
    900       if (0 == Map.count(L))
    901         return Res;
    902 
    903       const SCEVAddRecExpr *Rec = cast<SCEVAddRecExpr>(Res);
    904       return Rec->evaluateAtIteration(Map[L], SE);
    905     }
    906 
    907   private:
    908     LoopToScevMapT &Map;
    909   };
    910 
    911 } // end namespace llvm
    912 
    913 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
    914