Home | History | Annotate | Line # | Download | only in Support
      1 //===- SMTAPI.h -------------------------------------------------*- C++ -*-===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 //  This file defines a generic SMT solver API, which serves as the base class
     10 //  for every solver-specific SMT class.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #ifndef LLVM_SUPPORT_SMTAPI_H
     15 #define LLVM_SUPPORT_SMTAPI_H
     16 
     17 #include "llvm/ADT/APFloat.h"
     18 #include "llvm/ADT/APSInt.h"
     19 #include "llvm/ADT/FoldingSet.h"
     20 #include "llvm/Support/raw_ostream.h"
     21 #include <memory>
     22 
     23 namespace llvm {
     24 
     25 /// Generic base class for SMT sorts
     26 class SMTSort {
     27 public:
     28   SMTSort() = default;
     29   virtual ~SMTSort() = default;
     30 
     31   /// Returns true if the sort is a bitvector, calls isBitvectorSortImpl().
     32   virtual bool isBitvectorSort() const { return isBitvectorSortImpl(); }
     33 
     34   /// Returns true if the sort is a floating-point, calls isFloatSortImpl().
     35   virtual bool isFloatSort() const { return isFloatSortImpl(); }
     36 
     37   /// Returns true if the sort is a boolean, calls isBooleanSortImpl().
     38   virtual bool isBooleanSort() const { return isBooleanSortImpl(); }
     39 
     40   /// Returns the bitvector size, fails if the sort is not a bitvector
     41   /// Calls getBitvectorSortSizeImpl().
     42   virtual unsigned getBitvectorSortSize() const {
     43     assert(isBitvectorSort() && "Not a bitvector sort!");
     44     unsigned Size = getBitvectorSortSizeImpl();
     45     assert(Size && "Size is zero!");
     46     return Size;
     47   };
     48 
     49   /// Returns the floating-point size, fails if the sort is not a floating-point
     50   /// Calls getFloatSortSizeImpl().
     51   virtual unsigned getFloatSortSize() const {
     52     assert(isFloatSort() && "Not a floating-point sort!");
     53     unsigned Size = getFloatSortSizeImpl();
     54     assert(Size && "Size is zero!");
     55     return Size;
     56   };
     57 
     58   virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0;
     59 
     60   bool operator<(const SMTSort &Other) const {
     61     llvm::FoldingSetNodeID ID1, ID2;
     62     Profile(ID1);
     63     Other.Profile(ID2);
     64     return ID1 < ID2;
     65   }
     66 
     67   friend bool operator==(SMTSort const &LHS, SMTSort const &RHS) {
     68     return LHS.equal_to(RHS);
     69   }
     70 
     71   virtual void print(raw_ostream &OS) const = 0;
     72 
     73   LLVM_DUMP_METHOD void dump() const;
     74 
     75 protected:
     76   /// Query the SMT solver and returns true if two sorts are equal (same kind
     77   /// and bit width). This does not check if the two sorts are the same objects.
     78   virtual bool equal_to(SMTSort const &other) const = 0;
     79 
     80   /// Query the SMT solver and checks if a sort is bitvector.
     81   virtual bool isBitvectorSortImpl() const = 0;
     82 
     83   /// Query the SMT solver and checks if a sort is floating-point.
     84   virtual bool isFloatSortImpl() const = 0;
     85 
     86   /// Query the SMT solver and checks if a sort is boolean.
     87   virtual bool isBooleanSortImpl() const = 0;
     88 
     89   /// Query the SMT solver and returns the sort bit width.
     90   virtual unsigned getBitvectorSortSizeImpl() const = 0;
     91 
     92   /// Query the SMT solver and returns the sort bit width.
     93   virtual unsigned getFloatSortSizeImpl() const = 0;
     94 };
     95 
     96 /// Raw pointer to a const SMTSort (not a smart pointer), used by SMTSolver API.
     97 using SMTSortRef = const SMTSort *;
     98 
     99 /// Generic base class for SMT exprs
    100 class SMTExpr {
    101 public:
    102   SMTExpr() = default;
    103   virtual ~SMTExpr() = default;
    104 
    105   bool operator<(const SMTExpr &Other) const {
    106     llvm::FoldingSetNodeID ID1, ID2;
    107     Profile(ID1);
    108     Other.Profile(ID2);
    109     return ID1 < ID2;
    110   }
    111 
    112   virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0;
    113 
    114   friend bool operator==(SMTExpr const &LHS, SMTExpr const &RHS) {
    115     return LHS.equal_to(RHS);
    116   }
    117 
    118   virtual void print(raw_ostream &OS) const = 0;
    119 
    120   LLVM_DUMP_METHOD void dump() const;
    121 
    122 protected:
    123   /// Query the SMT solver and returns true if two sorts are equal (same kind
    124   /// and bit width). This does not check if the two sorts are the same objects.
    125   virtual bool equal_to(SMTExpr const &other) const = 0;
    126 };
    127 
    128 /// Raw pointer to a const SMTExpr (not a smart pointer), used by SMTSolver API.
    129 using SMTExprRef = const SMTExpr *;
    130 
    131 /// Generic base class for SMT Solvers
    132 ///
    133 /// This class is responsible for wrapping all sorts and expression generation,
    134 /// through the mk* methods. It also provides methods to create SMT expressions
    135 /// straight from clang's AST, through the from* methods.
    136 class SMTSolver {
    137 public:
    138   SMTSolver() = default;
    139   virtual ~SMTSolver() = default;
    140 
    141   LLVM_DUMP_METHOD void dump() const;
    142 
    143   // Returns an appropriate floating-point sort for the given bitwidth.
    144   SMTSortRef getFloatSort(unsigned BitWidth) {
    145     switch (BitWidth) {
    146     case 16:
    147       return getFloat16Sort();
    148     case 32:
    149       return getFloat32Sort();
    150     case 64:
    151       return getFloat64Sort();
    152     case 128:
    153       return getFloat128Sort();
    154     default:;
    155     }
    156     llvm_unreachable("Unsupported floating-point bitwidth!");
    157   }
    158 
    159   // Returns a boolean sort.
    160   virtual SMTSortRef getBoolSort() = 0;
    161 
    162   // Returns an appropriate bitvector sort for the given bitwidth.
    163   virtual SMTSortRef getBitvectorSort(const unsigned BitWidth) = 0;
    164 
    165   // Returns a floating-point sort of width 16
    166   virtual SMTSortRef getFloat16Sort() = 0;
    167 
    168   // Returns a floating-point sort of width 32
    169   virtual SMTSortRef getFloat32Sort() = 0;
    170 
    171   // Returns a floating-point sort of width 64
    172   virtual SMTSortRef getFloat64Sort() = 0;
    173 
    174   // Returns a floating-point sort of width 128
    175   virtual SMTSortRef getFloat128Sort() = 0;
    176 
    177   // Returns an appropriate sort for the given AST.
    178   virtual SMTSortRef getSort(const SMTExprRef &AST) = 0;
    179 
    180   /// Given a constraint, adds it to the solver
    181   virtual void addConstraint(const SMTExprRef &Exp) const = 0;
    182 
    183   /// Creates a bitvector addition operation
    184   virtual SMTExprRef mkBVAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    185 
    186   /// Creates a bitvector subtraction operation
    187   virtual SMTExprRef mkBVSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    188 
    189   /// Creates a bitvector multiplication operation
    190   virtual SMTExprRef mkBVMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    191 
    192   /// Creates a bitvector signed modulus operation
    193   virtual SMTExprRef mkBVSRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    194 
    195   /// Creates a bitvector unsigned modulus operation
    196   virtual SMTExprRef mkBVURem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    197 
    198   /// Creates a bitvector signed division operation
    199   virtual SMTExprRef mkBVSDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    200 
    201   /// Creates a bitvector unsigned division operation
    202   virtual SMTExprRef mkBVUDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    203 
    204   /// Creates a bitvector logical shift left operation
    205   virtual SMTExprRef mkBVShl(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    206 
    207   /// Creates a bitvector arithmetic shift right operation
    208   virtual SMTExprRef mkBVAshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    209 
    210   /// Creates a bitvector logical shift right operation
    211   virtual SMTExprRef mkBVLshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    212 
    213   /// Creates a bitvector negation operation
    214   virtual SMTExprRef mkBVNeg(const SMTExprRef &Exp) = 0;
    215 
    216   /// Creates a bitvector not operation
    217   virtual SMTExprRef mkBVNot(const SMTExprRef &Exp) = 0;
    218 
    219   /// Creates a bitvector xor operation
    220   virtual SMTExprRef mkBVXor(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    221 
    222   /// Creates a bitvector or operation
    223   virtual SMTExprRef mkBVOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    224 
    225   /// Creates a bitvector and operation
    226   virtual SMTExprRef mkBVAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    227 
    228   /// Creates a bitvector unsigned less-than operation
    229   virtual SMTExprRef mkBVUlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    230 
    231   /// Creates a bitvector signed less-than operation
    232   virtual SMTExprRef mkBVSlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    233 
    234   /// Creates a bitvector unsigned greater-than operation
    235   virtual SMTExprRef mkBVUgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    236 
    237   /// Creates a bitvector signed greater-than operation
    238   virtual SMTExprRef mkBVSgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    239 
    240   /// Creates a bitvector unsigned less-equal-than operation
    241   virtual SMTExprRef mkBVUle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    242 
    243   /// Creates a bitvector signed less-equal-than operation
    244   virtual SMTExprRef mkBVSle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    245 
    246   /// Creates a bitvector unsigned greater-equal-than operation
    247   virtual SMTExprRef mkBVUge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    248 
    249   /// Creates a bitvector signed greater-equal-than operation
    250   virtual SMTExprRef mkBVSge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    251 
    252   /// Creates a boolean not operation
    253   virtual SMTExprRef mkNot(const SMTExprRef &Exp) = 0;
    254 
    255   /// Creates a boolean equality operation
    256   virtual SMTExprRef mkEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    257 
    258   /// Creates a boolean and operation
    259   virtual SMTExprRef mkAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    260 
    261   /// Creates a boolean or operation
    262   virtual SMTExprRef mkOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    263 
    264   /// Creates a boolean ite operation
    265   virtual SMTExprRef mkIte(const SMTExprRef &Cond, const SMTExprRef &T,
    266                            const SMTExprRef &F) = 0;
    267 
    268   /// Creates a bitvector sign extension operation
    269   virtual SMTExprRef mkBVSignExt(unsigned i, const SMTExprRef &Exp) = 0;
    270 
    271   /// Creates a bitvector zero extension operation
    272   virtual SMTExprRef mkBVZeroExt(unsigned i, const SMTExprRef &Exp) = 0;
    273 
    274   /// Creates a bitvector extract operation
    275   virtual SMTExprRef mkBVExtract(unsigned High, unsigned Low,
    276                                  const SMTExprRef &Exp) = 0;
    277 
    278   /// Creates a bitvector concat operation
    279   virtual SMTExprRef mkBVConcat(const SMTExprRef &LHS,
    280                                 const SMTExprRef &RHS) = 0;
    281 
    282   /// Creates a predicate that checks for overflow in a bitvector addition
    283   /// operation
    284   virtual SMTExprRef mkBVAddNoOverflow(const SMTExprRef &LHS,
    285                                        const SMTExprRef &RHS,
    286                                        bool isSigned) = 0;
    287 
    288   /// Creates a predicate that checks for underflow in a signed bitvector
    289   /// addition operation
    290   virtual SMTExprRef mkBVAddNoUnderflow(const SMTExprRef &LHS,
    291                                         const SMTExprRef &RHS) = 0;
    292 
    293   /// Creates a predicate that checks for overflow in a signed bitvector
    294   /// subtraction operation
    295   virtual SMTExprRef mkBVSubNoOverflow(const SMTExprRef &LHS,
    296                                        const SMTExprRef &RHS) = 0;
    297 
    298   /// Creates a predicate that checks for underflow in a bitvector subtraction
    299   /// operation
    300   virtual SMTExprRef mkBVSubNoUnderflow(const SMTExprRef &LHS,
    301                                         const SMTExprRef &RHS,
    302                                         bool isSigned) = 0;
    303 
    304   /// Creates a predicate that checks for overflow in a signed bitvector
    305   /// division/modulus operation
    306   virtual SMTExprRef mkBVSDivNoOverflow(const SMTExprRef &LHS,
    307                                         const SMTExprRef &RHS) = 0;
    308 
    309   /// Creates a predicate that checks for overflow in a bitvector negation
    310   /// operation
    311   virtual SMTExprRef mkBVNegNoOverflow(const SMTExprRef &Exp) = 0;
    312 
    313   /// Creates a predicate that checks for overflow in a bitvector multiplication
    314   /// operation
    315   virtual SMTExprRef mkBVMulNoOverflow(const SMTExprRef &LHS,
    316                                        const SMTExprRef &RHS,
    317                                        bool isSigned) = 0;
    318 
    319   /// Creates a predicate that checks for underflow in a signed bitvector
    320   /// multiplication operation
    321   virtual SMTExprRef mkBVMulNoUnderflow(const SMTExprRef &LHS,
    322                                         const SMTExprRef &RHS) = 0;
    323 
    324   /// Creates a floating-point negation operation
    325   virtual SMTExprRef mkFPNeg(const SMTExprRef &Exp) = 0;
    326 
    327   /// Creates a floating-point isInfinite operation
    328   virtual SMTExprRef mkFPIsInfinite(const SMTExprRef &Exp) = 0;
    329 
    330   /// Creates a floating-point isNaN operation
    331   virtual SMTExprRef mkFPIsNaN(const SMTExprRef &Exp) = 0;
    332 
    333   /// Creates a floating-point isNormal operation
    334   virtual SMTExprRef mkFPIsNormal(const SMTExprRef &Exp) = 0;
    335 
    336   /// Creates a floating-point isZero operation
    337   virtual SMTExprRef mkFPIsZero(const SMTExprRef &Exp) = 0;
    338 
    339   /// Creates a floating-point multiplication operation
    340   virtual SMTExprRef mkFPMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    341 
    342   /// Creates a floating-point division operation
    343   virtual SMTExprRef mkFPDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    344 
    345   /// Creates a floating-point remainder operation
    346   virtual SMTExprRef mkFPRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    347 
    348   /// Creates a floating-point addition operation
    349   virtual SMTExprRef mkFPAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    350 
    351   /// Creates a floating-point subtraction operation
    352   virtual SMTExprRef mkFPSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    353 
    354   /// Creates a floating-point less-than operation
    355   virtual SMTExprRef mkFPLt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    356 
    357   /// Creates a floating-point greater-than operation
    358   virtual SMTExprRef mkFPGt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    359 
    360   /// Creates a floating-point less-than-or-equal operation
    361   virtual SMTExprRef mkFPLe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    362 
    363   /// Creates a floating-point greater-than-or-equal operation
    364   virtual SMTExprRef mkFPGe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
    365 
    366   /// Creates a floating-point equality operation
    367   virtual SMTExprRef mkFPEqual(const SMTExprRef &LHS,
    368                                const SMTExprRef &RHS) = 0;
    369 
    370   /// Creates a floating-point conversion from floatint-point to floating-point
    371   /// operation
    372   virtual SMTExprRef mkFPtoFP(const SMTExprRef &From, const SMTSortRef &To) = 0;
    373 
    374   /// Creates a floating-point conversion from signed bitvector to
    375   /// floatint-point operation
    376   virtual SMTExprRef mkSBVtoFP(const SMTExprRef &From,
    377                                const SMTSortRef &To) = 0;
    378 
    379   /// Creates a floating-point conversion from unsigned bitvector to
    380   /// floatint-point operation
    381   virtual SMTExprRef mkUBVtoFP(const SMTExprRef &From,
    382                                const SMTSortRef &To) = 0;
    383 
    384   /// Creates a floating-point conversion from floatint-point to signed
    385   /// bitvector operation
    386   virtual SMTExprRef mkFPtoSBV(const SMTExprRef &From, unsigned ToWidth) = 0;
    387 
    388   /// Creates a floating-point conversion from floatint-point to unsigned
    389   /// bitvector operation
    390   virtual SMTExprRef mkFPtoUBV(const SMTExprRef &From, unsigned ToWidth) = 0;
    391 
    392   /// Creates a new symbol, given a name and a sort
    393   virtual SMTExprRef mkSymbol(const char *Name, SMTSortRef Sort) = 0;
    394 
    395   // Returns an appropriate floating-point rounding mode.
    396   virtual SMTExprRef getFloatRoundingMode() = 0;
    397 
    398   // If the a model is available, returns the value of a given bitvector symbol
    399   virtual llvm::APSInt getBitvector(const SMTExprRef &Exp, unsigned BitWidth,
    400                                     bool isUnsigned) = 0;
    401 
    402   // If the a model is available, returns the value of a given boolean symbol
    403   virtual bool getBoolean(const SMTExprRef &Exp) = 0;
    404 
    405   /// Constructs an SMTExprRef from a boolean.
    406   virtual SMTExprRef mkBoolean(const bool b) = 0;
    407 
    408   /// Constructs an SMTExprRef from a finite APFloat.
    409   virtual SMTExprRef mkFloat(const llvm::APFloat Float) = 0;
    410 
    411   /// Constructs an SMTExprRef from an APSInt and its bit width
    412   virtual SMTExprRef mkBitvector(const llvm::APSInt Int, unsigned BitWidth) = 0;
    413 
    414   /// Given an expression, extract the value of this operand in the model.
    415   virtual bool getInterpretation(const SMTExprRef &Exp, llvm::APSInt &Int) = 0;
    416 
    417   /// Given an expression extract the value of this operand in the model.
    418   virtual bool getInterpretation(const SMTExprRef &Exp,
    419                                  llvm::APFloat &Float) = 0;
    420 
    421   /// Check if the constraints are satisfiable
    422   virtual Optional<bool> check() const = 0;
    423 
    424   /// Push the current solver state
    425   virtual void push() = 0;
    426 
    427   /// Pop the previous solver state
    428   virtual void pop(unsigned NumStates = 1) = 0;
    429 
    430   /// Reset the solver and remove all constraints.
    431   virtual void reset() = 0;
    432 
    433   /// Checks if the solver supports floating-points.
    434   virtual bool isFPSupported() = 0;
    435 
    436   virtual void print(raw_ostream &OS) const = 0;
    437 };
    438 
    439 /// Shared-ownership pointer (std::shared_ptr) to an SMTSolver.
    440 using SMTSolverRef = std::shared_ptr<SMTSolver>;
    441 
    442 /// Convenience function to create a Z3Solver object, returned as an SMTSolverRef.
    443 SMTSolverRef CreateZ3Solver();
    444 
    445 } // namespace llvm
    446 
    447 #endif
    448