1 //== SMTSolver.h ------------------------------------------------*- C++ -*--==//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // This file defines a SMT generic Solver API, which will be the base class
11 // for every SMT solver specific class.
13 //===----------------------------------------------------------------------===//
15 #ifndef LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTSOLVER_H
16 #define LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTSOLVER_H
18 #include "clang/AST/Expr.h"
19 #include "clang/StaticAnalyzer/Core/PathSensitive/APSIntType.h"
20 #include "clang/StaticAnalyzer/Core/PathSensitive/ConstraintManager.h"
21 #include "clang/StaticAnalyzer/Core/PathSensitive/SMTExpr.h"
22 #include "clang/StaticAnalyzer/Core/PathSensitive/SMTSort.h"
23 #include "clang/StaticAnalyzer/Core/PathSensitive/SymbolManager.h"
28 /// Generic base class for SMT Solvers
30 /// This class is responsible for wrapping all sorts and expression generation,
31 /// through the mk* methods. It also provides methods to create SMT expressions
32 /// straight from clang's AST, through the from* methods.
SMTSolver() = default;
// Polymorphic base: the destructor is virtual so concrete solvers can be
// destroyed through an SMTSolver pointer (CreateZ3Solver() returns one).
virtual ~SMTSolver() = default;

// Debugging aid: prints this solver to llvm::errs() using print().
LLVM_DUMP_METHOD void dump() const { print(llvm::errs()); }
40 // Returns an appropriate floating-point sort for the given bitwidth.
41 SMTSortRef getFloatSort(unsigned BitWidth) {
44 return getFloat16Sort();
46 return getFloat32Sort();
48 return getFloat64Sort();
50 return getFloat128Sort();
53 llvm_unreachable("Unsupported floating-point bitwidth!");
56 // Returns an appropriate sort, given a QualType and it's bit width.
57 SMTSortRef mkSort(const QualType &Ty, unsigned BitWidth) {
58 if (Ty->isBooleanType())
61 if (Ty->isRealFloatingType())
62 return getFloatSort(BitWidth);
64 return getBitvectorSort(BitWidth);
67 /// Constructs an SMTExprRef from an unary operator.
68 SMTExprRef fromUnOp(const UnaryOperator::Opcode Op, const SMTExprRef &Exp) {
81 llvm_unreachable("Unimplemented opcode");
84 /// Constructs an SMTExprRef from a floating-point unary operator.
85 SMTExprRef fromFloatUnOp(const UnaryOperator::Opcode Op,
86 const SMTExprRef &Exp) {
92 return fromUnOp(Op, Exp);
96 llvm_unreachable("Unimplemented opcode");
99 /// Construct an SMTExprRef from a n-ary binary operator.
100 SMTExprRef fromNBinOp(const BinaryOperator::Opcode Op,
101 const std::vector<SMTExprRef> &ASTs) {
102 assert(!ASTs.empty());
104 if (Op != BO_LAnd && Op != BO_LOr)
105 llvm_unreachable("Unimplemented opcode");
107 SMTExprRef res = ASTs.front();
108 for (std::size_t i = 1; i < ASTs.size(); ++i)
109 res = (Op == BO_LAnd) ? mkAnd(res, ASTs[i]) : mkOr(res, ASTs[i]);
113 /// Construct an SMTExprRef from a binary operator.
114 SMTExprRef fromBinOp(const SMTExprRef &LHS, const BinaryOperator::Opcode Op,
115 const SMTExprRef &RHS, bool isSigned) {
116 assert(*getSort(LHS) == *getSort(RHS) && "AST's must have the same sort!");
119 // Multiplicative operators
121 return mkBVMul(LHS, RHS);
124 return isSigned ? mkBVSDiv(LHS, RHS) : mkBVUDiv(LHS, RHS);
127 return isSigned ? mkBVSRem(LHS, RHS) : mkBVURem(LHS, RHS);
129 // Additive operators
131 return mkBVAdd(LHS, RHS);
134 return mkBVSub(LHS, RHS);
136 // Bitwise shift operators
138 return mkBVShl(LHS, RHS);
141 return isSigned ? mkBVAshr(LHS, RHS) : mkBVLshr(LHS, RHS);
143 // Relational operators
145 return isSigned ? mkBVSlt(LHS, RHS) : mkBVUlt(LHS, RHS);
148 return isSigned ? mkBVSgt(LHS, RHS) : mkBVUgt(LHS, RHS);
151 return isSigned ? mkBVSle(LHS, RHS) : mkBVUle(LHS, RHS);
154 return isSigned ? mkBVSge(LHS, RHS) : mkBVUge(LHS, RHS);
156 // Equality operators
158 return mkEqual(LHS, RHS);
161 return fromUnOp(UO_LNot, fromBinOp(LHS, BO_EQ, RHS, isSigned));
165 return mkBVAnd(LHS, RHS);
168 return mkBVXor(LHS, RHS);
171 return mkBVOr(LHS, RHS);
175 return mkAnd(LHS, RHS);
178 return mkOr(LHS, RHS);
182 llvm_unreachable("Unimplemented opcode");
185 /// Construct an SMTExprRef from a special floating-point binary operator.
186 SMTExprRef fromFloatSpecialBinOp(const SMTExprRef &LHS,
187 const BinaryOperator::Opcode Op,
188 const llvm::APFloat::fltCategory &RHS) {
190 // Equality operators
193 case llvm::APFloat::fcInfinity:
194 return mkFPIsInfinite(LHS);
196 case llvm::APFloat::fcNaN:
197 return mkFPIsNaN(LHS);
199 case llvm::APFloat::fcNormal:
200 return mkFPIsNormal(LHS);
202 case llvm::APFloat::fcZero:
203 return mkFPIsZero(LHS);
208 return fromFloatUnOp(UO_LNot, fromFloatSpecialBinOp(LHS, BO_EQ, RHS));
213 llvm_unreachable("Unimplemented opcode");
216 /// Construct an SMTExprRef from a floating-point binary operator.
217 SMTExprRef fromFloatBinOp(const SMTExprRef &LHS,
218 const BinaryOperator::Opcode Op,
219 const SMTExprRef &RHS) {
220 assert(*getSort(LHS) == *getSort(RHS) && "AST's must have the same sort!");
223 // Multiplicative operators
225 return mkFPMul(LHS, RHS);
228 return mkFPDiv(LHS, RHS);
231 return mkFPRem(LHS, RHS);
233 // Additive operators
235 return mkFPAdd(LHS, RHS);
238 return mkFPSub(LHS, RHS);
240 // Relational operators
242 return mkFPLt(LHS, RHS);
245 return mkFPGt(LHS, RHS);
248 return mkFPLe(LHS, RHS);
251 return mkFPGe(LHS, RHS);
253 // Equality operators
255 return mkFPEqual(LHS, RHS);
258 return fromFloatUnOp(UO_LNot, fromFloatBinOp(LHS, BO_EQ, RHS));
263 return fromBinOp(LHS, Op, RHS, false);
268 llvm_unreachable("Unimplemented opcode");
271 /// Construct an SMTExprRef from a QualType FromTy to a QualType ToTy, and
272 /// their bit widths.
273 SMTExprRef fromCast(const SMTExprRef &Exp, QualType ToTy, uint64_t ToBitWidth,
274 QualType FromTy, uint64_t FromBitWidth) {
275 if ((FromTy->isIntegralOrEnumerationType() &&
276 ToTy->isIntegralOrEnumerationType()) ||
277 (FromTy->isAnyPointerType() ^ ToTy->isAnyPointerType()) ||
278 (FromTy->isBlockPointerType() ^ ToTy->isBlockPointerType()) ||
279 (FromTy->isReferenceType() ^ ToTy->isReferenceType())) {
281 if (FromTy->isBooleanType()) {
282 assert(ToBitWidth > 0 && "BitWidth must be positive!");
283 return mkIte(Exp, mkBitvector(llvm::APSInt("1"), ToBitWidth),
284 mkBitvector(llvm::APSInt("0"), ToBitWidth));
287 if (ToBitWidth > FromBitWidth)
288 return FromTy->isSignedIntegerOrEnumerationType()
289 ? mkBVSignExt(ToBitWidth - FromBitWidth, Exp)
290 : mkBVZeroExt(ToBitWidth - FromBitWidth, Exp);
292 if (ToBitWidth < FromBitWidth)
293 return mkBVExtract(ToBitWidth - 1, 0, Exp);
295 // Both are bitvectors with the same width, ignore the type cast
299 if (FromTy->isRealFloatingType() && ToTy->isRealFloatingType()) {
300 if (ToBitWidth != FromBitWidth)
301 return mkFPtoFP(Exp, getFloatSort(ToBitWidth));
306 if (FromTy->isIntegralOrEnumerationType() && ToTy->isRealFloatingType()) {
307 SMTSortRef Sort = getFloatSort(ToBitWidth);
308 return FromTy->isSignedIntegerOrEnumerationType() ? mkFPtoSBV(Exp, Sort)
309 : mkFPtoUBV(Exp, Sort);
312 if (FromTy->isRealFloatingType() && ToTy->isIntegralOrEnumerationType())
313 return ToTy->isSignedIntegerOrEnumerationType()
314 ? mkSBVtoFP(Exp, ToBitWidth)
315 : mkUBVtoFP(Exp, ToBitWidth);
317 llvm_unreachable("Unsupported explicit type cast!");
320 // Callback function for doCast parameter on APSInt type.
321 llvm::APSInt castAPSInt(const llvm::APSInt &V, QualType ToTy,
322 uint64_t ToWidth, QualType FromTy,
323 uint64_t FromWidth) {
324 APSIntType TargetType(ToWidth, !ToTy->isSignedIntegerOrEnumerationType());
325 return TargetType.convert(V);
328 // Generate an SMTExprRef that represents the given symbolic expression.
329 // Sets the hasComparison parameter if the expression has a comparison
331 // Sets the RetTy parameter to the final return type after promotions and
333 SMTExprRef getExpr(ASTContext &Ctx, SymbolRef Sym, QualType *RetTy = nullptr,
334 bool *hasComparison = nullptr) {
336 *hasComparison = false;
339 return getSymExpr(Ctx, Sym, RetTy, hasComparison);
342 // Generate an SMTExprRef that compares the expression to zero.
343 SMTExprRef getZeroExpr(ASTContext &Ctx, const SMTExprRef &Exp, QualType Ty,
346 if (Ty->isRealFloatingType()) {
348 llvm::APFloat::getZero(Ctx.getFloatTypeSemantics(Ty));
349 return fromFloatBinOp(Exp, Assumption ? BO_EQ : BO_NE, fromAPFloat(Zero));
352 if (Ty->isIntegralOrEnumerationType() || Ty->isAnyPointerType() ||
353 Ty->isBlockPointerType() || Ty->isReferenceType()) {
355 // Skip explicit comparison for boolean types
356 bool isSigned = Ty->isSignedIntegerOrEnumerationType();
357 if (Ty->isBooleanType())
358 return Assumption ? fromUnOp(UO_LNot, Exp) : Exp;
360 return fromBinOp(Exp, Assumption ? BO_EQ : BO_NE,
361 fromInt("0", Ctx.getTypeSize(Ty)), isSigned);
364 llvm_unreachable("Unsupported type for zero value!");
367 // Recursive implementation to unpack and generate symbolic expression.
368 // Sets the hasComparison and RetTy parameters. See getExpr().
369 SMTExprRef getSymExpr(ASTContext &Ctx, SymbolRef Sym, QualType *RetTy,
370 bool *hasComparison) {
371 if (const SymbolData *SD = dyn_cast<SymbolData>(Sym)) {
373 *RetTy = Sym->getType();
375 return fromData(SD->getSymbolID(), Sym->getType(),
376 Ctx.getTypeSize(Sym->getType()));
379 if (const SymbolCast *SC = dyn_cast<SymbolCast>(Sym)) {
381 *RetTy = Sym->getType();
385 getSymExpr(Ctx, SC->getOperand(), &FromTy, hasComparison);
386 // Casting an expression with a comparison invalidates it. Note that this
387 // must occur after the recursive call above.
388 // e.g. (signed char) (x > 0)
390 *hasComparison = false;
391 return getCastExpr(Ctx, Exp, FromTy, Sym->getType());
394 if (const BinarySymExpr *BSE = dyn_cast<BinarySymExpr>(Sym)) {
395 SMTExprRef Exp = getSymBinExpr(Ctx, BSE, hasComparison, RetTy);
396 // Set the hasComparison parameter, in post-order traversal order.
398 *hasComparison = BinaryOperator::isComparisonOp(BSE->getOpcode());
402 llvm_unreachable("Unsupported SymbolRef type!");
405 // Wrapper to generate SMTExprRef from SymbolCast data.
406 SMTExprRef getCastExpr(ASTContext &Ctx, const SMTExprRef &Exp,
407 QualType FromTy, QualType ToTy) {
408 return fromCast(Exp, ToTy, Ctx.getTypeSize(ToTy), FromTy,
409 Ctx.getTypeSize(FromTy));
412 // Wrapper to generate SMTExprRef from BinarySymExpr.
413 // Sets the hasComparison and RetTy parameters. See getSMTExprRef().
414 SMTExprRef getSymBinExpr(ASTContext &Ctx, const BinarySymExpr *BSE,
415 bool *hasComparison, QualType *RetTy) {
417 BinaryOperator::Opcode Op = BSE->getOpcode();
419 if (const SymIntExpr *SIE = dyn_cast<SymIntExpr>(BSE)) {
420 SMTExprRef LHS = getSymExpr(Ctx, SIE->getLHS(), <y, hasComparison);
421 llvm::APSInt NewRInt;
422 std::tie(NewRInt, RTy) = fixAPSInt(Ctx, SIE->getRHS());
423 SMTExprRef RHS = fromAPSInt(NewRInt);
424 return getBinExpr(Ctx, LHS, LTy, Op, RHS, RTy, RetTy);
427 if (const IntSymExpr *ISE = dyn_cast<IntSymExpr>(BSE)) {
428 llvm::APSInt NewLInt;
429 std::tie(NewLInt, LTy) = fixAPSInt(Ctx, ISE->getLHS());
430 SMTExprRef LHS = fromAPSInt(NewLInt);
431 SMTExprRef RHS = getSymExpr(Ctx, ISE->getRHS(), &RTy, hasComparison);
432 return getBinExpr(Ctx, LHS, LTy, Op, RHS, RTy, RetTy);
435 if (const SymSymExpr *SSM = dyn_cast<SymSymExpr>(BSE)) {
436 SMTExprRef LHS = getSymExpr(Ctx, SSM->getLHS(), <y, hasComparison);
437 SMTExprRef RHS = getSymExpr(Ctx, SSM->getRHS(), &RTy, hasComparison);
438 return getBinExpr(Ctx, LHS, LTy, Op, RHS, RTy, RetTy);
441 llvm_unreachable("Unsupported BinarySymExpr type!");
444 // Wrapper to generate SMTExprRef from unpacked binary symbolic expression.
445 // Sets the RetTy parameter. See getSMTExprRef().
446 SMTExprRef getBinExpr(ASTContext &Ctx, const SMTExprRef &LHS, QualType LTy,
447 BinaryOperator::Opcode Op, const SMTExprRef &RHS,
448 QualType RTy, QualType *RetTy) {
449 SMTExprRef NewLHS = LHS;
450 SMTExprRef NewRHS = RHS;
451 doTypeConversion(Ctx, NewLHS, NewRHS, LTy, RTy);
453 // Update the return type parameter if the output type has changed.
455 // A boolean result can be represented as an integer type in C/C++, but at
456 // this point we only care about the SMT sorts. Set it as a boolean type
457 // to avoid subsequent SMT errors.
458 if (BinaryOperator::isComparisonOp(Op) ||
459 BinaryOperator::isLogicalOp(Op)) {
465 // If the two operands are pointers and the operation is a subtraction,
466 // the result is of type ptrdiff_t, which is signed
467 if (LTy->isAnyPointerType() && RTy->isAnyPointerType() && Op == BO_Sub) {
468 *RetTy = Ctx.getPointerDiffType();
472 return LTy->isRealFloatingType()
473 ? fromFloatBinOp(NewLHS, Op, NewRHS)
474 : fromBinOp(NewLHS, Op, NewRHS,
475 LTy->isSignedIntegerOrEnumerationType());
478 // Wrapper to generate SMTExprRef from a range. If From == To, an equality
479 // will be created instead.
480 SMTExprRef getRangeExpr(ASTContext &Ctx, SymbolRef Sym,
481 const llvm::APSInt &From, const llvm::APSInt &To,
483 // Convert lower bound
485 llvm::APSInt NewFromInt;
486 std::tie(NewFromInt, FromTy) = fixAPSInt(Ctx, From);
487 SMTExprRef FromExp = fromAPSInt(NewFromInt);
491 SMTExprRef Exp = getExpr(Ctx, Sym, &SymTy);
493 // Construct single (in)equality
495 return getBinExpr(Ctx, Exp, SymTy, InRange ? BO_EQ : BO_NE, FromExp,
496 FromTy, /*RetTy=*/nullptr);
499 llvm::APSInt NewToInt;
500 std::tie(NewToInt, ToTy) = fixAPSInt(Ctx, To);
501 SMTExprRef ToExp = fromAPSInt(NewToInt);
502 assert(FromTy == ToTy && "Range values have different types!");
504 // Construct two (in)equalities, and a logical and/or
505 SMTExprRef LHS = getBinExpr(Ctx, Exp, SymTy, InRange ? BO_GE : BO_LT,
506 FromExp, FromTy, /*RetTy=*/nullptr);
508 getBinExpr(Ctx, Exp, SymTy, InRange ? BO_LE : BO_GT, ToExp, ToTy,
511 return fromBinOp(LHS, InRange ? BO_LAnd : BO_LOr, RHS,
512 SymTy->isSignedIntegerOrEnumerationType());
515 // Recover the QualType of an APSInt.
516 // TODO: Refactor to put elsewhere
517 QualType getAPSIntType(ASTContext &Ctx, const llvm::APSInt &Int) {
518 return Ctx.getIntTypeForBitwidth(Int.getBitWidth(), Int.isSigned());
521 // Get the QualTy for the input APSInt, and fix it if it has a bitwidth of 1.
522 std::pair<llvm::APSInt, QualType> fixAPSInt(ASTContext &Ctx,
523 const llvm::APSInt &Int) {
526 // FIXME: This should be a cast from a 1-bit integer type to a boolean type,
527 // but the former is not available in Clang. Instead, extend the APSInt
529 if (Int.getBitWidth() == 1 && getAPSIntType(Ctx, Int).isNull()) {
530 NewInt = Int.extend(Ctx.getTypeSize(Ctx.BoolTy));
534 return std::make_pair(NewInt, getAPSIntType(Ctx, NewInt));
// Perform implicit type conversion on binary symbolic expressions.
// May modify all input parameters.
// TODO: Refactor to use built-in conversion functions
//
// The operand types decide which conversion is applied:
//   - both integral arithmetic types: doIntTypeConversion();
//   - either one a real floating type: doFloatTypeConversion();
//   - otherwise, pointer, block-pointer and reference operands are
//     reconciled here.
// In each handled case LTy and RTy are updated to the converted types.
// NOTE(review): if no rule applies and the canonical types differ, LHS, RHS,
// LTy and RTy are left unchanged (see the trailing TODO).
void doTypeConversion(ASTContext &Ctx, SMTExprRef &LHS, SMTExprRef &RHS,
                      QualType &LTy, QualType &RTy) {
  assert(!LTy.isNull() && !RTy.isNull() && "Input type is null!");

  // Perform type conversion
  if ((LTy->isIntegralOrEnumerationType() &&
       RTy->isIntegralOrEnumerationType()) &&
      (LTy->isArithmeticType() && RTy->isArithmeticType())) {
    doIntTypeConversion<SMTExprRef, &SMTSolver::fromCast>(Ctx, LHS, LTy, RHS,
                                                          RTy);
    return;
  }

  if (LTy->isRealFloatingType() || RTy->isRealFloatingType()) {
    doFloatTypeConversion<SMTExprRef, &SMTSolver::fromCast>(Ctx, LHS, LTy,
                                                            RHS, RTy);
    return;
  }

  if ((LTy->isAnyPointerType() || RTy->isAnyPointerType()) ||
      (LTy->isBlockPointerType() || RTy->isBlockPointerType()) ||
      (LTy->isReferenceType() || RTy->isReferenceType())) {
    // TODO: Refactor to Sema::FindCompositePointerType(), and
    // Sema::CheckCompareOperands().

    uint64_t LBitWidth = Ctx.getTypeSize(LTy);
    uint64_t RBitWidth = Ctx.getTypeSize(RTy);

    // Cast the non-pointer type to the pointer type.
    // TODO: Be more strict about this.
    if ((LTy->isAnyPointerType() ^ RTy->isAnyPointerType()) ||
        (LTy->isBlockPointerType() ^ RTy->isBlockPointerType()) ||
        (LTy->isReferenceType() ^ RTy->isReferenceType())) {
      // Cast LHS to RTy when LTy is nullptr_t, a block pointer or a
      // reference. Otherwise cast RHS to LTy.
      if (LTy->isNullPtrType() || LTy->isBlockPointerType() ||
          LTy->isReferenceType()) {
        LHS = fromCast(LHS, RTy, RBitWidth, LTy, LBitWidth);
        LTy = RTy;
      } else {
        RHS = fromCast(RHS, LTy, LBitWidth, RTy, RBitWidth);
        RTy = LTy;
      }
    }

    // Cast the void pointer type to the non-void pointer type.
    // For void types, this assumes that the casted value is equal to the
    // value of the original pointer, and does not account for alignment
    // requirements.
    // Only the type is relabelled here; no fromCast() is emitted.
    if (LTy->isVoidPointerType() ^ RTy->isVoidPointerType()) {
      assert((Ctx.getTypeSize(LTy) == Ctx.getTypeSize(RTy)) &&
             "Pointer types have different bitwidths!");
      if (RTy->isVoidPointerType())
        RTy = LTy;
      else
        LTy = RTy;
    }

    if (LTy == RTy)
      return;
  }

  // Fallback: for the solver, assume that these types don't really matter
  if ((LTy.getCanonicalType() == RTy.getCanonicalType()) ||
      (LTy->isObjCObjectPointerType() && RTy->isObjCObjectPointerType())) {
    LTy = RTy;
    return;
  }

  // TODO: Refine behavior for invalid type casts
}
610 // Perform implicit integer type conversion.
611 // May modify all input parameters.
612 // TODO: Refactor to use Sema::handleIntegerConversion()
613 template <typename T, T (SMTSolver::*doCast)(const T &, QualType, uint64_t,
615 void doIntTypeConversion(ASTContext &Ctx, T &LHS, QualType <y, T &RHS,
618 uint64_t LBitWidth = Ctx.getTypeSize(LTy);
619 uint64_t RBitWidth = Ctx.getTypeSize(RTy);
621 assert(!LTy.isNull() && !RTy.isNull() && "Input type is null!");
622 // Always perform integer promotion before checking type equality.
623 // Otherwise, e.g. (bool) a + (bool) b could trigger a backend assertion
624 if (LTy->isPromotableIntegerType()) {
625 QualType NewTy = Ctx.getPromotedIntegerType(LTy);
626 uint64_t NewBitWidth = Ctx.getTypeSize(NewTy);
627 LHS = (this->*doCast)(LHS, NewTy, NewBitWidth, LTy, LBitWidth);
629 LBitWidth = NewBitWidth;
631 if (RTy->isPromotableIntegerType()) {
632 QualType NewTy = Ctx.getPromotedIntegerType(RTy);
633 uint64_t NewBitWidth = Ctx.getTypeSize(NewTy);
634 RHS = (this->*doCast)(RHS, NewTy, NewBitWidth, RTy, RBitWidth);
636 RBitWidth = NewBitWidth;
642 // Perform integer type conversion
643 // Note: Safe to skip updating bitwidth because this must terminate
644 bool isLSignedTy = LTy->isSignedIntegerOrEnumerationType();
645 bool isRSignedTy = RTy->isSignedIntegerOrEnumerationType();
647 int order = Ctx.getIntegerTypeOrder(LTy, RTy);
648 if (isLSignedTy == isRSignedTy) {
649 // Same signedness; use the higher-ranked type
651 RHS = (this->*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth);
654 LHS = (this->*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth);
657 } else if (order != (isLSignedTy ? 1 : -1)) {
658 // The unsigned type has greater than or equal rank to the
659 // signed type, so use the unsigned type
661 RHS = (this->*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth);
664 LHS = (this->*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth);
667 } else if (LBitWidth != RBitWidth) {
668 // The two types are different widths; if we are here, that
669 // means the signed type is larger than the unsigned type, so
670 // use the signed type.
672 RHS = (this->*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth);
675 LHS = (this->*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth);
679 // The signed type is higher-ranked than the unsigned type,
680 // but isn't actually any bigger (like unsigned int and long
681 // on most 32-bit systems). Use the unsigned type corresponding
682 // to the signed type.
684 Ctx.getCorrespondingUnsignedType(isLSignedTy ? LTy : RTy);
685 RHS = (this->*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth);
687 LHS = (this->*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth);
692 // Perform implicit floating-point type conversion.
693 // May modify all input parameters.
694 // TODO: Refactor to use Sema::handleFloatConversion()
695 template <typename T, T (SMTSolver::*doCast)(const T &, QualType, uint64_t,
697 void doFloatTypeConversion(ASTContext &Ctx, T &LHS, QualType <y, T &RHS,
700 uint64_t LBitWidth = Ctx.getTypeSize(LTy);
701 uint64_t RBitWidth = Ctx.getTypeSize(RTy);
703 // Perform float-point type promotion
704 if (!LTy->isRealFloatingType()) {
705 LHS = (this->*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth);
707 LBitWidth = RBitWidth;
709 if (!RTy->isRealFloatingType()) {
710 RHS = (this->*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth);
712 RBitWidth = LBitWidth;
718 // If we have two real floating types, convert the smaller operand to the
720 // Note: Safe to skip updating bitwidth because this must terminate
721 int order = Ctx.getFloatingTypeOrder(LTy, RTy);
723 RHS = (this->*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth);
725 } else if (order == 0) {
726 LHS = (this->*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth);
729 llvm_unreachable("Unsupported floating-point type cast!");
// Returns the boolean sort.
virtual SMTSortRef getBoolSort() = 0;

// Returns a bitvector sort of the given bit width.
virtual SMTSortRef getBitvectorSort(const unsigned BitWidth) = 0;

// Returns a floating-point sort of width 16.
virtual SMTSortRef getFloat16Sort() = 0;

// Returns a floating-point sort of width 32.
virtual SMTSortRef getFloat32Sort() = 0;

// Returns a floating-point sort of width 64.
virtual SMTSortRef getFloat64Sort() = 0;

// Returns a floating-point sort of width 128.
virtual SMTSortRef getFloat128Sort() = 0;

// Returns the sort of the given expression. fromBinOp() and
// fromFloatBinOp() use it to assert that both operands share a sort.
virtual SMTSortRef getSort(const SMTExprRef &AST) = 0;

// Returns a new SMTExprRef for the given SMTExpr.
virtual SMTExprRef newExprRef(const SMTExpr &E) const = 0;

/// Adds the constraint Exp to the solver.
virtual void addConstraint(const SMTExprRef &Exp) const = 0;
/// Creates a bitvector addition operation
virtual SMTExprRef mkBVAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector subtraction operation
virtual SMTExprRef mkBVSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector multiplication operation
virtual SMTExprRef mkBVMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector signed remainder operation (fromBinOp() uses it for
/// BO_Rem on signed operands)
virtual SMTExprRef mkBVSRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector unsigned remainder operation (fromBinOp() uses it for
/// BO_Rem on unsigned operands)
virtual SMTExprRef mkBVURem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector signed division operation
virtual SMTExprRef mkBVSDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector unsigned division operation
virtual SMTExprRef mkBVUDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector shift left operation
virtual SMTExprRef mkBVShl(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector arithmetic shift right operation
virtual SMTExprRef mkBVAshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector logical shift right operation
virtual SMTExprRef mkBVLshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector negation operation
virtual SMTExprRef mkBVNeg(const SMTExprRef &Exp) = 0;

/// Creates a bitvector not (bitwise complement) operation
virtual SMTExprRef mkBVNot(const SMTExprRef &Exp) = 0;

/// Creates a bitvector xor operation
virtual SMTExprRef mkBVXor(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector or operation
virtual SMTExprRef mkBVOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector and operation
virtual SMTExprRef mkBVAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector unsigned less-than operation
virtual SMTExprRef mkBVUlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector signed less-than operation
virtual SMTExprRef mkBVSlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector unsigned greater-than operation
virtual SMTExprRef mkBVUgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector signed greater-than operation
virtual SMTExprRef mkBVSgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector unsigned less-than-or-equal operation
virtual SMTExprRef mkBVUle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector signed less-than-or-equal operation
virtual SMTExprRef mkBVSle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector unsigned greater-than-or-equal operation
virtual SMTExprRef mkBVUge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a bitvector signed greater-than-or-equal operation
virtual SMTExprRef mkBVSge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a boolean not operation
virtual SMTExprRef mkNot(const SMTExprRef &Exp) = 0;

/// Creates a boolean equality operation (fromBinOp() uses it for BO_EQ)
virtual SMTExprRef mkEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a boolean and operation
virtual SMTExprRef mkAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a boolean or operation
virtual SMTExprRef mkOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates an if-then-else operation that yields T when the boolean Cond
/// holds and F otherwise (fromCast() uses bitvector T and F)
virtual SMTExprRef mkIte(const SMTExprRef &Cond, const SMTExprRef &T,
                         const SMTExprRef &F) = 0;

/// Creates a bitvector sign extension operation that adds i bits
/// (fromCast() passes the width difference)
virtual SMTExprRef mkBVSignExt(unsigned i, const SMTExprRef &Exp) = 0;

/// Creates a bitvector zero extension operation that adds i bits
/// (fromCast() passes the width difference)
virtual SMTExprRef mkBVZeroExt(unsigned i, const SMTExprRef &Exp) = 0;

/// Creates a bitvector extract operation for bits High..Low of Exp
/// (fromCast() passes ToBitWidth - 1 and 0 to keep the low bits)
virtual SMTExprRef mkBVExtract(unsigned High, unsigned Low,
                               const SMTExprRef &Exp) = 0;

/// Creates a bitvector concat operation
virtual SMTExprRef mkBVConcat(const SMTExprRef &LHS,
                              const SMTExprRef &RHS) = 0;
/// Creates a floating-point negation operation
virtual SMTExprRef mkFPNeg(const SMTExprRef &Exp) = 0;

/// Creates a floating-point isInfinite operation
virtual SMTExprRef mkFPIsInfinite(const SMTExprRef &Exp) = 0;

/// Creates a floating-point isNaN operation
virtual SMTExprRef mkFPIsNaN(const SMTExprRef &Exp) = 0;

/// Creates a floating-point isNormal operation
virtual SMTExprRef mkFPIsNormal(const SMTExprRef &Exp) = 0;

/// Creates a floating-point isZero operation
virtual SMTExprRef mkFPIsZero(const SMTExprRef &Exp) = 0;

/// Creates a floating-point multiplication operation
virtual SMTExprRef mkFPMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a floating-point division operation
virtual SMTExprRef mkFPDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a floating-point remainder operation
virtual SMTExprRef mkFPRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a floating-point addition operation
virtual SMTExprRef mkFPAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a floating-point subtraction operation
virtual SMTExprRef mkFPSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a floating-point less-than operation
virtual SMTExprRef mkFPLt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a floating-point greater-than operation
virtual SMTExprRef mkFPGt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a floating-point less-than-or-equal operation
virtual SMTExprRef mkFPLe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a floating-point greater-than-or-equal operation
virtual SMTExprRef mkFPGe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

/// Creates a floating-point equality operation
virtual SMTExprRef mkFPEqual(const SMTExprRef &LHS,
                             const SMTExprRef &RHS) = 0;

/// Creates a conversion of the floating-point value From to the
/// floating-point sort To
virtual SMTExprRef mkFPtoFP(const SMTExprRef &From, const SMTSortRef &To) = 0;

/// Creates a conversion between floating-point and signed bitvector values,
/// producing a value of sort To.
/// NOTE(review): the original docs describe this as floating-point -> signed
/// bitvector, but fromCast() calls it for signed-integer -> floating-point
/// casts, passing a floating-point sort as To. Confirm the direction
/// against the solver implementation.
virtual SMTExprRef mkFPtoSBV(const SMTExprRef &From,
                             const SMTSortRef &To) = 0;

/// Creates a conversion between floating-point and unsigned bitvector
/// values, producing a value of sort To.
/// NOTE(review): fromCast() calls this for unsigned-integer ->
/// floating-point casts, passing a floating-point sort as To. Confirm the
/// direction against the solver implementation.
virtual SMTExprRef mkFPtoUBV(const SMTExprRef &From,
                             const SMTSortRef &To) = 0;

/// Creates a conversion between signed bitvector and floating-point values,
/// producing a ToWidth-bit result.
/// NOTE(review): the original docs describe this as signed bitvector ->
/// floating-point, but fromCast() calls it for floating-point ->
/// signed-integer casts, passing the integer width as ToWidth. Confirm the
/// direction against the solver implementation.
virtual SMTExprRef mkSBVtoFP(const SMTExprRef &From, unsigned ToWidth) = 0;

/// Creates a conversion between unsigned bitvector and floating-point
/// values, producing a ToWidth-bit result.
/// NOTE(review): fromCast() calls this for floating-point ->
/// unsigned-integer casts, passing the integer width as ToWidth. Confirm the
/// direction against the solver implementation.
virtual SMTExprRef mkUBVtoFP(const SMTExprRef &From, unsigned ToWidth) = 0;
/// Creates a new symbol, given a name and a sort
virtual SMTExprRef mkSymbol(const char *Name, SMTSortRef Sort) = 0;

// Returns the floating-point rounding mode expression to use.
virtual SMTExprRef getFloatRoundingMode() = 0;

// If a model is available, returns the value of the given bitvector symbol
// as a BitWidth-bit APSInt whose signedness is given by isUnsigned.
virtual llvm::APSInt getBitvector(const SMTExprRef &Exp, unsigned BitWidth,
                                  bool isUnsigned) = 0;

// If a model is available, returns the value of the given boolean symbol.
virtual bool getBoolean(const SMTExprRef &Exp) = 0;

/// Constructs an SMTExprRef from a boolean.
virtual SMTExprRef mkBoolean(const bool b) = 0;

/// Constructs an SMTExprRef from a finite APFloat.
virtual SMTExprRef mkFloat(const llvm::APFloat Float) = 0;

/// Constructs an SMTExprRef from an APSInt and its bit width.
virtual SMTExprRef mkBitvector(const llvm::APSInt Int, unsigned BitWidth) = 0;

/// Given an expression, extracts the value of this operand in the model
/// into Int.
/// NOTE(review): the bool result presumably reports whether a value was
/// available; confirm against the implementation.
virtual bool getInterpretation(const SMTExprRef &Exp, llvm::APSInt &Int) = 0;

/// Given an expression, extracts the value of this operand in the model
/// into Float.
/// NOTE(review): the bool result presumably reports whether a value was
/// available; confirm against the implementation.
virtual bool getInterpretation(const SMTExprRef &Exp,
                               llvm::APFloat &Float) = 0;

/// Construct an SMTExprRef value from a boolean.
virtual SMTExprRef fromBoolean(const bool Bool) = 0;

/// Construct an SMTExprRef value from a finite APFloat.
virtual SMTExprRef fromAPFloat(const llvm::APFloat &Float) = 0;

/// Construct an SMTExprRef value from an APSInt.
virtual SMTExprRef fromAPSInt(const llvm::APSInt &Int) = 0;

/// Construct a BitWidth-bit SMTExprRef value from the integer written as the
/// string Int (getZeroExpr() passes "0").
virtual SMTExprRef fromInt(const char *Int, uint64_t BitWidth) = 0;

/// Construct an SMTExprRef from a SymbolData, given its symbol ID, type and
/// bit width.
virtual SMTExprRef fromData(const SymbolID ID, const QualType &Ty,
                            uint64_t BitWidth) = 0;

/// Check if the constraints are satisfiable.
virtual ConditionTruthVal check() const = 0;

/// Push the current solver state.
virtual void push() = 0;

/// Pop NumStates previously pushed solver states.
virtual void pop(unsigned NumStates = 1) = 0;

/// Reset the solver and remove all constraints.
virtual void reset() const = 0;

/// Prints a representation of the solver to OS (dump() calls this with
/// llvm::errs()).
virtual void print(raw_ostream &OS) const = 0;
/// Shared pointer for SMTSolvers.
using SMTSolverRef = std::shared_ptr<SMTSolver>;

/// Convenience function to create a Z3Solver object, returned as an owning
/// SMTSolver pointer.
std::unique_ptr<SMTSolver> CreateZ3Solver();