//===- SMTAPI.h -------------------------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines a SMT generic Solver API, which will be the base class // for every SMT solver specific class. // //===----------------------------------------------------------------------===// #ifndef LLVM_SUPPORT_SMTAPI_H #define LLVM_SUPPORT_SMTAPI_H #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/FoldingSet.h" #include "llvm/Support/raw_ostream.h" #include namespace llvm { /// Generic base class for SMT sorts class SMTSort { public: SMTSort() = default; virtual ~SMTSort() = default; /// Returns true if the sort is a bitvector, calls isBitvectorSortImpl(). virtual bool isBitvectorSort() const { return isBitvectorSortImpl(); } /// Returns true if the sort is a floating-point, calls isFloatSortImpl(). virtual bool isFloatSort() const { return isFloatSortImpl(); } /// Returns true if the sort is a boolean, calls isBooleanSortImpl(). virtual bool isBooleanSort() const { return isBooleanSortImpl(); } /// Returns the bitvector size, fails if the sort is not a bitvector /// Calls getBitvectorSortSizeImpl(). virtual unsigned getBitvectorSortSize() const { assert(isBitvectorSort() && "Not a bitvector sort!"); unsigned Size = getBitvectorSortSizeImpl(); assert(Size && "Size is zero!"); return Size; }; /// Returns the floating-point size, fails if the sort is not a floating-point /// Calls getFloatSortSizeImpl(). virtual unsigned getFloatSortSize() const { assert(isFloatSort() && "Not a floating-point sort!"); unsigned Size = getFloatSortSizeImpl(); assert(Size && "Size is zero!"); return Size; }; virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0; bool operator<(const SMTSort &Other) const { llvm::FoldingSetNodeID ID1, ID2; Profile(ID1); Other.Profile(ID2); return ID1 < ID2; } friend bool operator==(SMTSort const &LHS, SMTSort const &RHS) { return LHS.equal_to(RHS); } virtual void print(raw_ostream &OS) const = 0; LLVM_DUMP_METHOD void dump() const; protected: /// Query the SMT solver and returns true if two sorts are equal (same kind /// and bit width). This does not check if the two sorts are the same objects. virtual bool equal_to(SMTSort const &other) const = 0; /// Query the SMT solver and checks if a sort is bitvector. virtual bool isBitvectorSortImpl() const = 0; /// Query the SMT solver and checks if a sort is floating-point. virtual bool isFloatSortImpl() const = 0; /// Query the SMT solver and checks if a sort is boolean. virtual bool isBooleanSortImpl() const = 0; /// Query the SMT solver and returns the sort bit width. virtual unsigned getBitvectorSortSizeImpl() const = 0; /// Query the SMT solver and returns the sort bit width. virtual unsigned getFloatSortSizeImpl() const = 0; }; /// Shared pointer for SMTSorts, used by SMTSolver API. using SMTSortRef = const SMTSort *; /// Generic base class for SMT exprs class SMTExpr { public: SMTExpr() = default; virtual ~SMTExpr() = default; bool operator<(const SMTExpr &Other) const { llvm::FoldingSetNodeID ID1, ID2; Profile(ID1); Other.Profile(ID2); return ID1 < ID2; } virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0; friend bool operator==(SMTExpr const &LHS, SMTExpr const &RHS) { return LHS.equal_to(RHS); } virtual void print(raw_ostream &OS) const = 0; LLVM_DUMP_METHOD void dump() const; protected: /// Query the SMT solver and returns true if two sorts are equal (same kind /// and bit width). This does not check if the two sorts are the same objects. virtual bool equal_to(SMTExpr const &other) const = 0; }; /// Shared pointer for SMTExprs, used by SMTSolver API. using SMTExprRef = const SMTExpr *; /// Generic base class for SMT Solvers /// /// This class is responsible for wrapping all sorts and expression generation, /// through the mk* methods. It also provides methods to create SMT expressions /// straight from clang's AST, through the from* methods. class SMTSolver { public: SMTSolver() = default; virtual ~SMTSolver() = default; LLVM_DUMP_METHOD void dump() const; // Returns an appropriate floating-point sort for the given bitwidth. SMTSortRef getFloatSort(unsigned BitWidth) { switch (BitWidth) { case 16: return getFloat16Sort(); case 32: return getFloat32Sort(); case 64: return getFloat64Sort(); case 128: return getFloat128Sort(); default:; } llvm_unreachable("Unsupported floating-point bitwidth!"); } // Returns a boolean sort. virtual SMTSortRef getBoolSort() = 0; // Returns an appropriate bitvector sort for the given bitwidth. virtual SMTSortRef getBitvectorSort(const unsigned BitWidth) = 0; // Returns a floating-point sort of width 16 virtual SMTSortRef getFloat16Sort() = 0; // Returns a floating-point sort of width 32 virtual SMTSortRef getFloat32Sort() = 0; // Returns a floating-point sort of width 64 virtual SMTSortRef getFloat64Sort() = 0; // Returns a floating-point sort of width 128 virtual SMTSortRef getFloat128Sort() = 0; // Returns an appropriate sort for the given AST. virtual SMTSortRef getSort(const SMTExprRef &AST) = 0; /// Given a constraint, adds it to the solver virtual void addConstraint(const SMTExprRef &Exp) const = 0; /// Creates a bitvector addition operation virtual SMTExprRef mkBVAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector subtraction operation virtual SMTExprRef mkBVSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector multiplication operation virtual SMTExprRef mkBVMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector signed modulus operation virtual SMTExprRef mkBVSRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector unsigned modulus operation virtual SMTExprRef mkBVURem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector signed division operation virtual SMTExprRef mkBVSDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector unsigned division operation virtual SMTExprRef mkBVUDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector logical shift left operation virtual SMTExprRef mkBVShl(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector arithmetic shift right operation virtual SMTExprRef mkBVAshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector logical shift right operation virtual SMTExprRef mkBVLshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector negation operation virtual SMTExprRef mkBVNeg(const SMTExprRef &Exp) = 0; /// Creates a bitvector not operation virtual SMTExprRef mkBVNot(const SMTExprRef &Exp) = 0; /// Creates a bitvector xor operation virtual SMTExprRef mkBVXor(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector or operation virtual SMTExprRef mkBVOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector and operation virtual SMTExprRef mkBVAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector unsigned less-than operation virtual SMTExprRef mkBVUlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector signed less-than operation virtual SMTExprRef mkBVSlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector unsigned greater-than operation virtual SMTExprRef mkBVUgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector signed greater-than operation virtual SMTExprRef mkBVSgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector unsigned less-equal-than operation virtual SMTExprRef mkBVUle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector signed less-equal-than operation virtual SMTExprRef mkBVSle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector unsigned greater-equal-than operation virtual SMTExprRef mkBVUge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a bitvector signed greater-equal-than operation virtual SMTExprRef mkBVSge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a boolean not operation virtual SMTExprRef mkNot(const SMTExprRef &Exp) = 0; /// Creates a boolean equality operation virtual SMTExprRef mkEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a boolean and operation virtual SMTExprRef mkAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a boolean or operation virtual SMTExprRef mkOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a boolean ite operation virtual SMTExprRef mkIte(const SMTExprRef &Cond, const SMTExprRef &T, const SMTExprRef &F) = 0; /// Creates a bitvector sign extension operation virtual SMTExprRef mkBVSignExt(unsigned i, const SMTExprRef &Exp) = 0; /// Creates a bitvector zero extension operation virtual SMTExprRef mkBVZeroExt(unsigned i, const SMTExprRef &Exp) = 0; /// Creates a bitvector extract operation virtual SMTExprRef mkBVExtract(unsigned High, unsigned Low, const SMTExprRef &Exp) = 0; /// Creates a bitvector concat operation virtual SMTExprRef mkBVConcat(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a predicate that checks for overflow in a bitvector addition /// operation virtual SMTExprRef mkBVAddNoOverflow(const SMTExprRef &LHS, const SMTExprRef &RHS, bool isSigned) = 0; /// Creates a predicate that checks for underflow in a signed bitvector /// addition operation virtual SMTExprRef mkBVAddNoUnderflow(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a predicate that checks for overflow in a signed bitvector /// subtraction operation virtual SMTExprRef mkBVSubNoOverflow(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a predicate that checks for underflow in a bitvector subtraction /// operation virtual SMTExprRef mkBVSubNoUnderflow(const SMTExprRef &LHS, const SMTExprRef &RHS, bool isSigned) = 0; /// Creates a predicate that checks for overflow in a signed bitvector /// division/modulus operation virtual SMTExprRef mkBVSDivNoOverflow(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a predicate that checks for overflow in a bitvector negation /// operation virtual SMTExprRef mkBVNegNoOverflow(const SMTExprRef &Exp) = 0; /// Creates a predicate that checks for overflow in a bitvector multiplication /// operation virtual SMTExprRef mkBVMulNoOverflow(const SMTExprRef &LHS, const SMTExprRef &RHS, bool isSigned) = 0; /// Creates a predicate that checks for underflow in a signed bitvector /// multiplication operation virtual SMTExprRef mkBVMulNoUnderflow(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a floating-point negation operation virtual SMTExprRef mkFPNeg(const SMTExprRef &Exp) = 0; /// Creates a floating-point isInfinite operation virtual SMTExprRef mkFPIsInfinite(const SMTExprRef &Exp) = 0; /// Creates a floating-point isNaN operation virtual SMTExprRef mkFPIsNaN(const SMTExprRef &Exp) = 0; /// Creates a floating-point isNormal operation virtual SMTExprRef mkFPIsNormal(const SMTExprRef &Exp) = 0; /// Creates a floating-point isZero operation virtual SMTExprRef mkFPIsZero(const SMTExprRef &Exp) = 0; /// Creates a floating-point multiplication operation virtual SMTExprRef mkFPMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a floating-point division operation virtual SMTExprRef mkFPDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a floating-point remainder operation virtual SMTExprRef mkFPRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a floating-point addition operation virtual SMTExprRef mkFPAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a floating-point subtraction operation virtual SMTExprRef mkFPSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a floating-point less-than operation virtual SMTExprRef mkFPLt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a floating-point greater-than operation virtual SMTExprRef mkFPGt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a floating-point less-than-or-equal operation virtual SMTExprRef mkFPLe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a floating-point greater-than-or-equal operation virtual SMTExprRef mkFPGe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a floating-point equality operation virtual SMTExprRef mkFPEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; /// Creates a floating-point conversion from floatint-point to floating-point /// operation virtual SMTExprRef mkFPtoFP(const SMTExprRef &From, const SMTSortRef &To) = 0; /// Creates a floating-point conversion from signed bitvector to /// floatint-point operation virtual SMTExprRef mkSBVtoFP(const SMTExprRef &From, const SMTSortRef &To) = 0; /// Creates a floating-point conversion from unsigned bitvector to /// floatint-point operation virtual SMTExprRef mkUBVtoFP(const SMTExprRef &From, const SMTSortRef &To) = 0; /// Creates a floating-point conversion from floatint-point to signed /// bitvector operation virtual SMTExprRef mkFPtoSBV(const SMTExprRef &From, unsigned ToWidth) = 0; /// Creates a floating-point conversion from floatint-point to unsigned /// bitvector operation virtual SMTExprRef mkFPtoUBV(const SMTExprRef &From, unsigned ToWidth) = 0; /// Creates a new symbol, given a name and a sort virtual SMTExprRef mkSymbol(const char *Name, SMTSortRef Sort) = 0; // Returns an appropriate floating-point rounding mode. virtual SMTExprRef getFloatRoundingMode() = 0; // If the a model is available, returns the value of a given bitvector symbol virtual llvm::APSInt getBitvector(const SMTExprRef &Exp, unsigned BitWidth, bool isUnsigned) = 0; // If the a model is available, returns the value of a given boolean symbol virtual bool getBoolean(const SMTExprRef &Exp) = 0; /// Constructs an SMTExprRef from a boolean. virtual SMTExprRef mkBoolean(const bool b) = 0; /// Constructs an SMTExprRef from a finite APFloat. virtual SMTExprRef mkFloat(const llvm::APFloat Float) = 0; /// Constructs an SMTExprRef from an APSInt and its bit width virtual SMTExprRef mkBitvector(const llvm::APSInt Int, unsigned BitWidth) = 0; /// Given an expression, extract the value of this operand in the model. virtual bool getInterpretation(const SMTExprRef &Exp, llvm::APSInt &Int) = 0; /// Given an expression extract the value of this operand in the model. virtual bool getInterpretation(const SMTExprRef &Exp, llvm::APFloat &Float) = 0; /// Check if the constraints are satisfiable virtual std::optional check() const = 0; /// Push the current solver state virtual void push() = 0; /// Pop the previous solver state virtual void pop(unsigned NumStates = 1) = 0; /// Reset the solver and remove all constraints. virtual void reset() = 0; /// Checks if the solver supports floating-points. virtual bool isFPSupported() = 0; virtual void print(raw_ostream &OS) const = 0; }; /// Shared pointer for SMTSolvers. using SMTSolverRef = std::shared_ptr; /// Convenience method to create and Z3Solver object SMTSolverRef CreateZ3Solver(); } // namespace llvm #endif