]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - contrib/llvm/tools/clang/include/clang/StaticAnalyzer/Core/PathSensitive/SMTSolver.h
Merge clang 7.0.1 and several follow-up changes
[FreeBSD/FreeBSD.git] / contrib / llvm / tools / clang / include / clang / StaticAnalyzer / Core / PathSensitive / SMTSolver.h
1 //== SMTSolver.h ------------------------------------------------*- C++ -*--==//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 //  This file defines a SMT generic Solver API, which will be the base class
11 //  for every SMT solver specific class.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #ifndef LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTSOLVER_H
16 #define LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTSOLVER_H
17
18 #include "clang/AST/Expr.h"
19 #include "clang/StaticAnalyzer/Core/PathSensitive/APSIntType.h"
20 #include "clang/StaticAnalyzer/Core/PathSensitive/ConstraintManager.h"
21 #include "clang/StaticAnalyzer/Core/PathSensitive/SMTExpr.h"
22 #include "clang/StaticAnalyzer/Core/PathSensitive/SMTSort.h"
23 #include "clang/StaticAnalyzer/Core/PathSensitive/SymbolManager.h"
24
25 namespace clang {
26 namespace ento {
27
28 /// Generic base class for SMT Solvers
29 ///
30 /// This class is responsible for wrapping all sorts and expression generation,
31 /// through the mk* methods. It also provides methods to create SMT expressions
32 /// straight from clang's AST, through the from* methods.
33 class SMTSolver {
34 public:
35   SMTSolver() = default;
36   virtual ~SMTSolver() = default;
37
38   LLVM_DUMP_METHOD void dump() const { print(llvm::errs()); }
39
40   // Returns an appropriate floating-point sort for the given bitwidth.
41   SMTSortRef getFloatSort(unsigned BitWidth) {
42     switch (BitWidth) {
43     case 16:
44       return getFloat16Sort();
45     case 32:
46       return getFloat32Sort();
47     case 64:
48       return getFloat64Sort();
49     case 128:
50       return getFloat128Sort();
51     default:;
52     }
53     llvm_unreachable("Unsupported floating-point bitwidth!");
54   }
55
56   // Returns an appropriate sort, given a QualType and it's bit width.
57   SMTSortRef mkSort(const QualType &Ty, unsigned BitWidth) {
58     if (Ty->isBooleanType())
59       return getBoolSort();
60
61     if (Ty->isRealFloatingType())
62       return getFloatSort(BitWidth);
63
64     return getBitvectorSort(BitWidth);
65   }
66
67   /// Constructs an SMTExprRef from an unary operator.
68   SMTExprRef fromUnOp(const UnaryOperator::Opcode Op, const SMTExprRef &Exp) {
69     switch (Op) {
70     case UO_Minus:
71       return mkBVNeg(Exp);
72
73     case UO_Not:
74       return mkBVNot(Exp);
75
76     case UO_LNot:
77       return mkNot(Exp);
78
79     default:;
80     }
81     llvm_unreachable("Unimplemented opcode");
82   }
83
84   /// Constructs an SMTExprRef from a floating-point unary operator.
85   SMTExprRef fromFloatUnOp(const UnaryOperator::Opcode Op,
86                            const SMTExprRef &Exp) {
87     switch (Op) {
88     case UO_Minus:
89       return mkFPNeg(Exp);
90
91     case UO_LNot:
92       return fromUnOp(Op, Exp);
93
94     default:;
95     }
96     llvm_unreachable("Unimplemented opcode");
97   }
98
99   /// Construct an SMTExprRef from a n-ary binary operator.
100   SMTExprRef fromNBinOp(const BinaryOperator::Opcode Op,
101                         const std::vector<SMTExprRef> &ASTs) {
102     assert(!ASTs.empty());
103
104     if (Op != BO_LAnd && Op != BO_LOr)
105       llvm_unreachable("Unimplemented opcode");
106
107     SMTExprRef res = ASTs.front();
108     for (std::size_t i = 1; i < ASTs.size(); ++i)
109       res = (Op == BO_LAnd) ? mkAnd(res, ASTs[i]) : mkOr(res, ASTs[i]);
110     return res;
111   }
112
113   /// Construct an SMTExprRef from a binary operator.
114   SMTExprRef fromBinOp(const SMTExprRef &LHS, const BinaryOperator::Opcode Op,
115                        const SMTExprRef &RHS, bool isSigned) {
116     assert(*getSort(LHS) == *getSort(RHS) && "AST's must have the same sort!");
117
118     switch (Op) {
119     // Multiplicative operators
120     case BO_Mul:
121       return mkBVMul(LHS, RHS);
122
123     case BO_Div:
124       return isSigned ? mkBVSDiv(LHS, RHS) : mkBVUDiv(LHS, RHS);
125
126     case BO_Rem:
127       return isSigned ? mkBVSRem(LHS, RHS) : mkBVURem(LHS, RHS);
128
129     // Additive operators
130     case BO_Add:
131       return mkBVAdd(LHS, RHS);
132
133     case BO_Sub:
134       return mkBVSub(LHS, RHS);
135
136     // Bitwise shift operators
137     case BO_Shl:
138       return mkBVShl(LHS, RHS);
139
140     case BO_Shr:
141       return isSigned ? mkBVAshr(LHS, RHS) : mkBVLshr(LHS, RHS);
142
143     // Relational operators
144     case BO_LT:
145       return isSigned ? mkBVSlt(LHS, RHS) : mkBVUlt(LHS, RHS);
146
147     case BO_GT:
148       return isSigned ? mkBVSgt(LHS, RHS) : mkBVUgt(LHS, RHS);
149
150     case BO_LE:
151       return isSigned ? mkBVSle(LHS, RHS) : mkBVUle(LHS, RHS);
152
153     case BO_GE:
154       return isSigned ? mkBVSge(LHS, RHS) : mkBVUge(LHS, RHS);
155
156     // Equality operators
157     case BO_EQ:
158       return mkEqual(LHS, RHS);
159
160     case BO_NE:
161       return fromUnOp(UO_LNot, fromBinOp(LHS, BO_EQ, RHS, isSigned));
162
163     // Bitwise operators
164     case BO_And:
165       return mkBVAnd(LHS, RHS);
166
167     case BO_Xor:
168       return mkBVXor(LHS, RHS);
169
170     case BO_Or:
171       return mkBVOr(LHS, RHS);
172
173     // Logical operators
174     case BO_LAnd:
175       return mkAnd(LHS, RHS);
176
177     case BO_LOr:
178       return mkOr(LHS, RHS);
179
180     default:;
181     }
182     llvm_unreachable("Unimplemented opcode");
183   }
184
185   /// Construct an SMTExprRef from a special floating-point binary operator.
186   SMTExprRef fromFloatSpecialBinOp(const SMTExprRef &LHS,
187                                    const BinaryOperator::Opcode Op,
188                                    const llvm::APFloat::fltCategory &RHS) {
189     switch (Op) {
190     // Equality operators
191     case BO_EQ:
192       switch (RHS) {
193       case llvm::APFloat::fcInfinity:
194         return mkFPIsInfinite(LHS);
195
196       case llvm::APFloat::fcNaN:
197         return mkFPIsNaN(LHS);
198
199       case llvm::APFloat::fcNormal:
200         return mkFPIsNormal(LHS);
201
202       case llvm::APFloat::fcZero:
203         return mkFPIsZero(LHS);
204       }
205       break;
206
207     case BO_NE:
208       return fromFloatUnOp(UO_LNot, fromFloatSpecialBinOp(LHS, BO_EQ, RHS));
209
210     default:;
211     }
212
213     llvm_unreachable("Unimplemented opcode");
214   }
215
216   /// Construct an SMTExprRef from a floating-point binary operator.
217   SMTExprRef fromFloatBinOp(const SMTExprRef &LHS,
218                             const BinaryOperator::Opcode Op,
219                             const SMTExprRef &RHS) {
220     assert(*getSort(LHS) == *getSort(RHS) && "AST's must have the same sort!");
221
222     switch (Op) {
223     // Multiplicative operators
224     case BO_Mul:
225       return mkFPMul(LHS, RHS);
226
227     case BO_Div:
228       return mkFPDiv(LHS, RHS);
229
230     case BO_Rem:
231       return mkFPRem(LHS, RHS);
232
233       // Additive operators
234     case BO_Add:
235       return mkFPAdd(LHS, RHS);
236
237     case BO_Sub:
238       return mkFPSub(LHS, RHS);
239
240       // Relational operators
241     case BO_LT:
242       return mkFPLt(LHS, RHS);
243
244     case BO_GT:
245       return mkFPGt(LHS, RHS);
246
247     case BO_LE:
248       return mkFPLe(LHS, RHS);
249
250     case BO_GE:
251       return mkFPGe(LHS, RHS);
252
253       // Equality operators
254     case BO_EQ:
255       return mkFPEqual(LHS, RHS);
256
257     case BO_NE:
258       return fromFloatUnOp(UO_LNot, fromFloatBinOp(LHS, BO_EQ, RHS));
259
260       // Logical operators
261     case BO_LAnd:
262     case BO_LOr:
263       return fromBinOp(LHS, Op, RHS, false);
264
265     default:;
266     }
267
268     llvm_unreachable("Unimplemented opcode");
269   }
270
271   /// Construct an SMTExprRef from a QualType FromTy to a QualType ToTy, and
272   /// their bit widths.
273   SMTExprRef fromCast(const SMTExprRef &Exp, QualType ToTy, uint64_t ToBitWidth,
274                       QualType FromTy, uint64_t FromBitWidth) {
275     if ((FromTy->isIntegralOrEnumerationType() &&
276          ToTy->isIntegralOrEnumerationType()) ||
277         (FromTy->isAnyPointerType() ^ ToTy->isAnyPointerType()) ||
278         (FromTy->isBlockPointerType() ^ ToTy->isBlockPointerType()) ||
279         (FromTy->isReferenceType() ^ ToTy->isReferenceType())) {
280
281       if (FromTy->isBooleanType()) {
282         assert(ToBitWidth > 0 && "BitWidth must be positive!");
283         return mkIte(Exp, mkBitvector(llvm::APSInt("1"), ToBitWidth),
284                      mkBitvector(llvm::APSInt("0"), ToBitWidth));
285       }
286
287       if (ToBitWidth > FromBitWidth)
288         return FromTy->isSignedIntegerOrEnumerationType()
289                    ? mkBVSignExt(ToBitWidth - FromBitWidth, Exp)
290                    : mkBVZeroExt(ToBitWidth - FromBitWidth, Exp);
291
292       if (ToBitWidth < FromBitWidth)
293         return mkBVExtract(ToBitWidth - 1, 0, Exp);
294
295       // Both are bitvectors with the same width, ignore the type cast
296       return Exp;
297     }
298
299     if (FromTy->isRealFloatingType() && ToTy->isRealFloatingType()) {
300       if (ToBitWidth != FromBitWidth)
301         return mkFPtoFP(Exp, getFloatSort(ToBitWidth));
302
303       return Exp;
304     }
305
306     if (FromTy->isIntegralOrEnumerationType() && ToTy->isRealFloatingType()) {
307       SMTSortRef Sort = getFloatSort(ToBitWidth);
308       return FromTy->isSignedIntegerOrEnumerationType() ? mkFPtoSBV(Exp, Sort)
309                                                         : mkFPtoUBV(Exp, Sort);
310     }
311
312     if (FromTy->isRealFloatingType() && ToTy->isIntegralOrEnumerationType())
313       return ToTy->isSignedIntegerOrEnumerationType()
314                  ? mkSBVtoFP(Exp, ToBitWidth)
315                  : mkUBVtoFP(Exp, ToBitWidth);
316
317     llvm_unreachable("Unsupported explicit type cast!");
318   }
319
320   // Callback function for doCast parameter on APSInt type.
321   llvm::APSInt castAPSInt(const llvm::APSInt &V, QualType ToTy,
322                           uint64_t ToWidth, QualType FromTy,
323                           uint64_t FromWidth) {
324     APSIntType TargetType(ToWidth, !ToTy->isSignedIntegerOrEnumerationType());
325     return TargetType.convert(V);
326   }
327
328   // Generate an SMTExprRef that represents the given symbolic expression.
329   // Sets the hasComparison parameter if the expression has a comparison
330   // operator.
331   // Sets the RetTy parameter to the final return type after promotions and
332   // casts.
333   SMTExprRef getExpr(ASTContext &Ctx, SymbolRef Sym, QualType *RetTy = nullptr,
334                      bool *hasComparison = nullptr) {
335     if (hasComparison) {
336       *hasComparison = false;
337     }
338
339     return getSymExpr(Ctx, Sym, RetTy, hasComparison);
340   }
341
342   // Generate an SMTExprRef that compares the expression to zero.
343   SMTExprRef getZeroExpr(ASTContext &Ctx, const SMTExprRef &Exp, QualType Ty,
344                          bool Assumption) {
345
346     if (Ty->isRealFloatingType()) {
347       llvm::APFloat Zero =
348           llvm::APFloat::getZero(Ctx.getFloatTypeSemantics(Ty));
349       return fromFloatBinOp(Exp, Assumption ? BO_EQ : BO_NE, fromAPFloat(Zero));
350     }
351
352     if (Ty->isIntegralOrEnumerationType() || Ty->isAnyPointerType() ||
353         Ty->isBlockPointerType() || Ty->isReferenceType()) {
354
355       // Skip explicit comparison for boolean types
356       bool isSigned = Ty->isSignedIntegerOrEnumerationType();
357       if (Ty->isBooleanType())
358         return Assumption ? fromUnOp(UO_LNot, Exp) : Exp;
359
360       return fromBinOp(Exp, Assumption ? BO_EQ : BO_NE,
361                        fromInt("0", Ctx.getTypeSize(Ty)), isSigned);
362     }
363
364     llvm_unreachable("Unsupported type for zero value!");
365   }
366
367   // Recursive implementation to unpack and generate symbolic expression.
368   // Sets the hasComparison and RetTy parameters. See getExpr().
369   SMTExprRef getSymExpr(ASTContext &Ctx, SymbolRef Sym, QualType *RetTy,
370                         bool *hasComparison) {
371     if (const SymbolData *SD = dyn_cast<SymbolData>(Sym)) {
372       if (RetTy)
373         *RetTy = Sym->getType();
374
375       return fromData(SD->getSymbolID(), Sym->getType(),
376                       Ctx.getTypeSize(Sym->getType()));
377     }
378
379     if (const SymbolCast *SC = dyn_cast<SymbolCast>(Sym)) {
380       if (RetTy)
381         *RetTy = Sym->getType();
382
383       QualType FromTy;
384       SMTExprRef Exp =
385           getSymExpr(Ctx, SC->getOperand(), &FromTy, hasComparison);
386       // Casting an expression with a comparison invalidates it. Note that this
387       // must occur after the recursive call above.
388       // e.g. (signed char) (x > 0)
389       if (hasComparison)
390         *hasComparison = false;
391       return getCastExpr(Ctx, Exp, FromTy, Sym->getType());
392     }
393
394     if (const BinarySymExpr *BSE = dyn_cast<BinarySymExpr>(Sym)) {
395       SMTExprRef Exp = getSymBinExpr(Ctx, BSE, hasComparison, RetTy);
396       // Set the hasComparison parameter, in post-order traversal order.
397       if (hasComparison)
398         *hasComparison = BinaryOperator::isComparisonOp(BSE->getOpcode());
399       return Exp;
400     }
401
402     llvm_unreachable("Unsupported SymbolRef type!");
403   }
404
405   // Wrapper to generate SMTExprRef from SymbolCast data.
406   SMTExprRef getCastExpr(ASTContext &Ctx, const SMTExprRef &Exp,
407                          QualType FromTy, QualType ToTy) {
408     return fromCast(Exp, ToTy, Ctx.getTypeSize(ToTy), FromTy,
409                     Ctx.getTypeSize(FromTy));
410   }
411
412   // Wrapper to generate SMTExprRef from BinarySymExpr.
413   // Sets the hasComparison and RetTy parameters. See getSMTExprRef().
414   SMTExprRef getSymBinExpr(ASTContext &Ctx, const BinarySymExpr *BSE,
415                            bool *hasComparison, QualType *RetTy) {
416     QualType LTy, RTy;
417     BinaryOperator::Opcode Op = BSE->getOpcode();
418
419     if (const SymIntExpr *SIE = dyn_cast<SymIntExpr>(BSE)) {
420       SMTExprRef LHS = getSymExpr(Ctx, SIE->getLHS(), &LTy, hasComparison);
421       llvm::APSInt NewRInt;
422       std::tie(NewRInt, RTy) = fixAPSInt(Ctx, SIE->getRHS());
423       SMTExprRef RHS = fromAPSInt(NewRInt);
424       return getBinExpr(Ctx, LHS, LTy, Op, RHS, RTy, RetTy);
425     }
426
427     if (const IntSymExpr *ISE = dyn_cast<IntSymExpr>(BSE)) {
428       llvm::APSInt NewLInt;
429       std::tie(NewLInt, LTy) = fixAPSInt(Ctx, ISE->getLHS());
430       SMTExprRef LHS = fromAPSInt(NewLInt);
431       SMTExprRef RHS = getSymExpr(Ctx, ISE->getRHS(), &RTy, hasComparison);
432       return getBinExpr(Ctx, LHS, LTy, Op, RHS, RTy, RetTy);
433     }
434
435     if (const SymSymExpr *SSM = dyn_cast<SymSymExpr>(BSE)) {
436       SMTExprRef LHS = getSymExpr(Ctx, SSM->getLHS(), &LTy, hasComparison);
437       SMTExprRef RHS = getSymExpr(Ctx, SSM->getRHS(), &RTy, hasComparison);
438       return getBinExpr(Ctx, LHS, LTy, Op, RHS, RTy, RetTy);
439     }
440
441     llvm_unreachable("Unsupported BinarySymExpr type!");
442   }
443
444   // Wrapper to generate SMTExprRef from unpacked binary symbolic expression.
445   // Sets the RetTy parameter. See getSMTExprRef().
446   SMTExprRef getBinExpr(ASTContext &Ctx, const SMTExprRef &LHS, QualType LTy,
447                         BinaryOperator::Opcode Op, const SMTExprRef &RHS,
448                         QualType RTy, QualType *RetTy) {
449     SMTExprRef NewLHS = LHS;
450     SMTExprRef NewRHS = RHS;
451     doTypeConversion(Ctx, NewLHS, NewRHS, LTy, RTy);
452
453     // Update the return type parameter if the output type has changed.
454     if (RetTy) {
455       // A boolean result can be represented as an integer type in C/C++, but at
456       // this point we only care about the SMT sorts. Set it as a boolean type
457       // to avoid subsequent SMT errors.
458       if (BinaryOperator::isComparisonOp(Op) ||
459           BinaryOperator::isLogicalOp(Op)) {
460         *RetTy = Ctx.BoolTy;
461       } else {
462         *RetTy = LTy;
463       }
464
465       // If the two operands are pointers and the operation is a subtraction,
466       // the result is of type ptrdiff_t, which is signed
467       if (LTy->isAnyPointerType() && RTy->isAnyPointerType() && Op == BO_Sub) {
468         *RetTy = Ctx.getPointerDiffType();
469       }
470     }
471
472     return LTy->isRealFloatingType()
473                ? fromFloatBinOp(NewLHS, Op, NewRHS)
474                : fromBinOp(NewLHS, Op, NewRHS,
475                            LTy->isSignedIntegerOrEnumerationType());
476   }
477
478   // Wrapper to generate SMTExprRef from a range. If From == To, an equality
479   // will be created instead.
480   SMTExprRef getRangeExpr(ASTContext &Ctx, SymbolRef Sym,
481                           const llvm::APSInt &From, const llvm::APSInt &To,
482                           bool InRange) {
483     // Convert lower bound
484     QualType FromTy;
485     llvm::APSInt NewFromInt;
486     std::tie(NewFromInt, FromTy) = fixAPSInt(Ctx, From);
487     SMTExprRef FromExp = fromAPSInt(NewFromInt);
488
489     // Convert symbol
490     QualType SymTy;
491     SMTExprRef Exp = getExpr(Ctx, Sym, &SymTy);
492
493     // Construct single (in)equality
494     if (From == To)
495       return getBinExpr(Ctx, Exp, SymTy, InRange ? BO_EQ : BO_NE, FromExp,
496                         FromTy, /*RetTy=*/nullptr);
497
498     QualType ToTy;
499     llvm::APSInt NewToInt;
500     std::tie(NewToInt, ToTy) = fixAPSInt(Ctx, To);
501     SMTExprRef ToExp = fromAPSInt(NewToInt);
502     assert(FromTy == ToTy && "Range values have different types!");
503
504     // Construct two (in)equalities, and a logical and/or
505     SMTExprRef LHS = getBinExpr(Ctx, Exp, SymTy, InRange ? BO_GE : BO_LT,
506                                 FromExp, FromTy, /*RetTy=*/nullptr);
507     SMTExprRef RHS =
508         getBinExpr(Ctx, Exp, SymTy, InRange ? BO_LE : BO_GT, ToExp, ToTy,
509                    /*RetTy=*/nullptr);
510
511     return fromBinOp(LHS, InRange ? BO_LAnd : BO_LOr, RHS,
512                      SymTy->isSignedIntegerOrEnumerationType());
513   }
514
515   // Recover the QualType of an APSInt.
516   // TODO: Refactor to put elsewhere
517   QualType getAPSIntType(ASTContext &Ctx, const llvm::APSInt &Int) {
518     return Ctx.getIntTypeForBitwidth(Int.getBitWidth(), Int.isSigned());
519   }
520
521   // Get the QualTy for the input APSInt, and fix it if it has a bitwidth of 1.
522   std::pair<llvm::APSInt, QualType> fixAPSInt(ASTContext &Ctx,
523                                               const llvm::APSInt &Int) {
524     llvm::APSInt NewInt;
525
526     // FIXME: This should be a cast from a 1-bit integer type to a boolean type,
527     // but the former is not available in Clang. Instead, extend the APSInt
528     // directly.
529     if (Int.getBitWidth() == 1 && getAPSIntType(Ctx, Int).isNull()) {
530       NewInt = Int.extend(Ctx.getTypeSize(Ctx.BoolTy));
531     } else
532       NewInt = Int;
533
534     return std::make_pair(NewInt, getAPSIntType(Ctx, NewInt));
535   }
536
537   // Perform implicit type conversion on binary symbolic expressions.
538   // May modify all input parameters.
539   // TODO: Refactor to use built-in conversion functions
540   void doTypeConversion(ASTContext &Ctx, SMTExprRef &LHS, SMTExprRef &RHS,
541                         QualType &LTy, QualType &RTy) {
542     assert(!LTy.isNull() && !RTy.isNull() && "Input type is null!");
543
544     // Perform type conversion
545     if ((LTy->isIntegralOrEnumerationType() &&
546          RTy->isIntegralOrEnumerationType()) &&
547         (LTy->isArithmeticType() && RTy->isArithmeticType())) {
548       doIntTypeConversion<SMTExprRef, &SMTSolver::fromCast>(Ctx, LHS, LTy, RHS,
549                                                             RTy);
550       return;
551     }
552
553     if (LTy->isRealFloatingType() || RTy->isRealFloatingType()) {
554       doFloatTypeConversion<SMTExprRef, &SMTSolver::fromCast>(Ctx, LHS, LTy,
555                                                               RHS, RTy);
556       return;
557     }
558
559     if ((LTy->isAnyPointerType() || RTy->isAnyPointerType()) ||
560         (LTy->isBlockPointerType() || RTy->isBlockPointerType()) ||
561         (LTy->isReferenceType() || RTy->isReferenceType())) {
562       // TODO: Refactor to Sema::FindCompositePointerType(), and
563       // Sema::CheckCompareOperands().
564
565       uint64_t LBitWidth = Ctx.getTypeSize(LTy);
566       uint64_t RBitWidth = Ctx.getTypeSize(RTy);
567
568       // Cast the non-pointer type to the pointer type.
569       // TODO: Be more strict about this.
570       if ((LTy->isAnyPointerType() ^ RTy->isAnyPointerType()) ||
571           (LTy->isBlockPointerType() ^ RTy->isBlockPointerType()) ||
572           (LTy->isReferenceType() ^ RTy->isReferenceType())) {
573         if (LTy->isNullPtrType() || LTy->isBlockPointerType() ||
574             LTy->isReferenceType()) {
575           LHS = fromCast(LHS, RTy, RBitWidth, LTy, LBitWidth);
576           LTy = RTy;
577         } else {
578           RHS = fromCast(RHS, LTy, LBitWidth, RTy, RBitWidth);
579           RTy = LTy;
580         }
581       }
582
583       // Cast the void pointer type to the non-void pointer type.
584       // For void types, this assumes that the casted value is equal to the
585       // value of the original pointer, and does not account for alignment
586       // requirements.
587       if (LTy->isVoidPointerType() ^ RTy->isVoidPointerType()) {
588         assert((Ctx.getTypeSize(LTy) == Ctx.getTypeSize(RTy)) &&
589                "Pointer types have different bitwidths!");
590         if (RTy->isVoidPointerType())
591           RTy = LTy;
592         else
593           LTy = RTy;
594       }
595
596       if (LTy == RTy)
597         return;
598     }
599
600     // Fallback: for the solver, assume that these types don't really matter
601     if ((LTy.getCanonicalType() == RTy.getCanonicalType()) ||
602         (LTy->isObjCObjectPointerType() && RTy->isObjCObjectPointerType())) {
603       LTy = RTy;
604       return;
605     }
606
607     // TODO: Refine behavior for invalid type casts
608   }
609
610   // Perform implicit integer type conversion.
611   // May modify all input parameters.
612   // TODO: Refactor to use Sema::handleIntegerConversion()
613   template <typename T, T (SMTSolver::*doCast)(const T &, QualType, uint64_t,
614                                                QualType, uint64_t)>
615   void doIntTypeConversion(ASTContext &Ctx, T &LHS, QualType &LTy, T &RHS,
616                            QualType &RTy) {
617
618     uint64_t LBitWidth = Ctx.getTypeSize(LTy);
619     uint64_t RBitWidth = Ctx.getTypeSize(RTy);
620
621     assert(!LTy.isNull() && !RTy.isNull() && "Input type is null!");
622     // Always perform integer promotion before checking type equality.
623     // Otherwise, e.g. (bool) a + (bool) b could trigger a backend assertion
624     if (LTy->isPromotableIntegerType()) {
625       QualType NewTy = Ctx.getPromotedIntegerType(LTy);
626       uint64_t NewBitWidth = Ctx.getTypeSize(NewTy);
627       LHS = (this->*doCast)(LHS, NewTy, NewBitWidth, LTy, LBitWidth);
628       LTy = NewTy;
629       LBitWidth = NewBitWidth;
630     }
631     if (RTy->isPromotableIntegerType()) {
632       QualType NewTy = Ctx.getPromotedIntegerType(RTy);
633       uint64_t NewBitWidth = Ctx.getTypeSize(NewTy);
634       RHS = (this->*doCast)(RHS, NewTy, NewBitWidth, RTy, RBitWidth);
635       RTy = NewTy;
636       RBitWidth = NewBitWidth;
637     }
638
639     if (LTy == RTy)
640       return;
641
642     // Perform integer type conversion
643     // Note: Safe to skip updating bitwidth because this must terminate
644     bool isLSignedTy = LTy->isSignedIntegerOrEnumerationType();
645     bool isRSignedTy = RTy->isSignedIntegerOrEnumerationType();
646
647     int order = Ctx.getIntegerTypeOrder(LTy, RTy);
648     if (isLSignedTy == isRSignedTy) {
649       // Same signedness; use the higher-ranked type
650       if (order == 1) {
651         RHS = (this->*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth);
652         RTy = LTy;
653       } else {
654         LHS = (this->*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth);
655         LTy = RTy;
656       }
657     } else if (order != (isLSignedTy ? 1 : -1)) {
658       // The unsigned type has greater than or equal rank to the
659       // signed type, so use the unsigned type
660       if (isRSignedTy) {
661         RHS = (this->*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth);
662         RTy = LTy;
663       } else {
664         LHS = (this->*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth);
665         LTy = RTy;
666       }
667     } else if (LBitWidth != RBitWidth) {
668       // The two types are different widths; if we are here, that
669       // means the signed type is larger than the unsigned type, so
670       // use the signed type.
671       if (isLSignedTy) {
672         RHS = (this->*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth);
673         RTy = LTy;
674       } else {
675         LHS = (this->*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth);
676         LTy = RTy;
677       }
678     } else {
679       // The signed type is higher-ranked than the unsigned type,
680       // but isn't actually any bigger (like unsigned int and long
681       // on most 32-bit systems).  Use the unsigned type corresponding
682       // to the signed type.
683       QualType NewTy =
684           Ctx.getCorrespondingUnsignedType(isLSignedTy ? LTy : RTy);
685       RHS = (this->*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth);
686       RTy = NewTy;
687       LHS = (this->*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth);
688       LTy = NewTy;
689     }
690   }
691
692   // Perform implicit floating-point type conversion.
693   // May modify all input parameters.
694   // TODO: Refactor to use Sema::handleFloatConversion()
695   template <typename T, T (SMTSolver::*doCast)(const T &, QualType, uint64_t,
696                                                QualType, uint64_t)>
697   void doFloatTypeConversion(ASTContext &Ctx, T &LHS, QualType &LTy, T &RHS,
698                              QualType &RTy) {
699
700     uint64_t LBitWidth = Ctx.getTypeSize(LTy);
701     uint64_t RBitWidth = Ctx.getTypeSize(RTy);
702
703     // Perform float-point type promotion
704     if (!LTy->isRealFloatingType()) {
705       LHS = (this->*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth);
706       LTy = RTy;
707       LBitWidth = RBitWidth;
708     }
709     if (!RTy->isRealFloatingType()) {
710       RHS = (this->*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth);
711       RTy = LTy;
712       RBitWidth = LBitWidth;
713     }
714
715     if (LTy == RTy)
716       return;
717
718     // If we have two real floating types, convert the smaller operand to the
719     // bigger result
720     // Note: Safe to skip updating bitwidth because this must terminate
721     int order = Ctx.getFloatingTypeOrder(LTy, RTy);
722     if (order > 0) {
723       RHS = (this->*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth);
724       RTy = LTy;
725     } else if (order == 0) {
726       LHS = (this->*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth);
727       LTy = RTy;
728     } else {
729       llvm_unreachable("Unsupported floating-point type cast!");
730     }
731   }
732
733   // Returns a boolean sort.
734   virtual SMTSortRef getBoolSort() = 0;
735
736   // Returns an appropriate bitvector sort for the given bitwidth.
737   virtual SMTSortRef getBitvectorSort(const unsigned BitWidth) = 0;
738
739   // Returns a floating-point sort of width 16
740   virtual SMTSortRef getFloat16Sort() = 0;
741
742   // Returns a floating-point sort of width 32
743   virtual SMTSortRef getFloat32Sort() = 0;
744
745   // Returns a floating-point sort of width 64
746   virtual SMTSortRef getFloat64Sort() = 0;
747
748   // Returns a floating-point sort of width 128
749   virtual SMTSortRef getFloat128Sort() = 0;
750
751   // Returns an appropriate sort for the given AST.
752   virtual SMTSortRef getSort(const SMTExprRef &AST) = 0;
753
754   // Returns a new SMTExprRef from an SMTExpr
755   virtual SMTExprRef newExprRef(const SMTExpr &E) const = 0;
756
757   /// Given a constraint, adds it to the solver
758   virtual void addConstraint(const SMTExprRef &Exp) const = 0;
759
760   /// Creates a bitvector addition operation
761   virtual SMTExprRef mkBVAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
762
763   /// Creates a bitvector subtraction operation
764   virtual SMTExprRef mkBVSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
765
766   /// Creates a bitvector multiplication operation
767   virtual SMTExprRef mkBVMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
768
769   /// Creates a bitvector signed modulus operation
770   virtual SMTExprRef mkBVSRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
771
772   /// Creates a bitvector unsigned modulus operation
773   virtual SMTExprRef mkBVURem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
774
775   /// Creates a bitvector signed division operation
776   virtual SMTExprRef mkBVSDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
777
778   /// Creates a bitvector unsigned division operation
779   virtual SMTExprRef mkBVUDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
780
781   /// Creates a bitvector logical shift left operation
782   virtual SMTExprRef mkBVShl(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
783
784   /// Creates a bitvector arithmetic shift right operation
785   virtual SMTExprRef mkBVAshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
786
787   /// Creates a bitvector logical shift right operation
788   virtual SMTExprRef mkBVLshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
789
790   /// Creates a bitvector negation operation
791   virtual SMTExprRef mkBVNeg(const SMTExprRef &Exp) = 0;
792
793   /// Creates a bitvector not operation
794   virtual SMTExprRef mkBVNot(const SMTExprRef &Exp) = 0;
795
796   /// Creates a bitvector xor operation
797   virtual SMTExprRef mkBVXor(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
798
799   /// Creates a bitvector or operation
800   virtual SMTExprRef mkBVOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
801
802   /// Creates a bitvector and operation
803   virtual SMTExprRef mkBVAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
804
805   /// Creates a bitvector unsigned less-than operation
806   virtual SMTExprRef mkBVUlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
807
808   /// Creates a bitvector signed less-than operation
809   virtual SMTExprRef mkBVSlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
810
811   /// Creates a bitvector unsigned greater-than operation
812   virtual SMTExprRef mkBVUgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
813
814   /// Creates a bitvector signed greater-than operation
815   virtual SMTExprRef mkBVSgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
816
817   /// Creates a bitvector unsigned less-equal-than operation
818   virtual SMTExprRef mkBVUle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
819
820   /// Creates a bitvector signed less-equal-than operation
821   virtual SMTExprRef mkBVSle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
822
823   /// Creates a bitvector unsigned greater-equal-than operation
824   virtual SMTExprRef mkBVUge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
825
826   /// Creates a bitvector signed greater-equal-than operation
827   virtual SMTExprRef mkBVSge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
828
829   /// Creates a boolean not operation
830   virtual SMTExprRef mkNot(const SMTExprRef &Exp) = 0;
831
832   /// Creates a boolean equality operation
833   virtual SMTExprRef mkEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
834
835   /// Creates a boolean and operation
836   virtual SMTExprRef mkAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
837
838   /// Creates a boolean or operation
839   virtual SMTExprRef mkOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
840
841   /// Creates a boolean ite operation
842   virtual SMTExprRef mkIte(const SMTExprRef &Cond, const SMTExprRef &T,
843                            const SMTExprRef &F) = 0;
844
845   /// Creates a bitvector sign extension operation
846   virtual SMTExprRef mkBVSignExt(unsigned i, const SMTExprRef &Exp) = 0;
847
848   /// Creates a bitvector zero extension operation
849   virtual SMTExprRef mkBVZeroExt(unsigned i, const SMTExprRef &Exp) = 0;
850
851   /// Creates a bitvector extract operation
852   virtual SMTExprRef mkBVExtract(unsigned High, unsigned Low,
853                                  const SMTExprRef &Exp) = 0;
854
855   /// Creates a bitvector concat operation
856   virtual SMTExprRef mkBVConcat(const SMTExprRef &LHS,
857                                 const SMTExprRef &RHS) = 0;
858
859   /// Creates a floating-point negation operation
860   virtual SMTExprRef mkFPNeg(const SMTExprRef &Exp) = 0;
861
862   /// Creates a floating-point isInfinite operation
863   virtual SMTExprRef mkFPIsInfinite(const SMTExprRef &Exp) = 0;
864
865   /// Creates a floating-point isNaN operation
866   virtual SMTExprRef mkFPIsNaN(const SMTExprRef &Exp) = 0;
867
868   /// Creates a floating-point isNormal operation
869   virtual SMTExprRef mkFPIsNormal(const SMTExprRef &Exp) = 0;
870
871   /// Creates a floating-point isZero operation
872   virtual SMTExprRef mkFPIsZero(const SMTExprRef &Exp) = 0;
873
874   /// Creates a floating-point multiplication operation
875   virtual SMTExprRef mkFPMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
876
877   /// Creates a floating-point division operation
878   virtual SMTExprRef mkFPDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
879
880   /// Creates a floating-point remainder operation
881   virtual SMTExprRef mkFPRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
882
883   /// Creates a floating-point addition operation
884   virtual SMTExprRef mkFPAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
885
886   /// Creates a floating-point subtraction operation
887   virtual SMTExprRef mkFPSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
888
889   /// Creates a floating-point less-than operation
890   virtual SMTExprRef mkFPLt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
891
892   /// Creates a floating-point greater-than operation
893   virtual SMTExprRef mkFPGt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
894
895   /// Creates a floating-point less-than-or-equal operation
896   virtual SMTExprRef mkFPLe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
897
898   /// Creates a floating-point greater-than-or-equal operation
899   virtual SMTExprRef mkFPGe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
900
901   /// Creates a floating-point equality operation
902   virtual SMTExprRef mkFPEqual(const SMTExprRef &LHS,
903                                const SMTExprRef &RHS) = 0;
904
905   /// Creates a floating-point conversion from floatint-point to floating-point
906   /// operation
907   virtual SMTExprRef mkFPtoFP(const SMTExprRef &From, const SMTSortRef &To) = 0;
908
909   /// Creates a floating-point conversion from floatint-point to signed
910   /// bitvector operation
911   virtual SMTExprRef mkFPtoSBV(const SMTExprRef &From,
912                                const SMTSortRef &To) = 0;
913
914   /// Creates a floating-point conversion from floatint-point to unsigned
915   /// bitvector operation
916   virtual SMTExprRef mkFPtoUBV(const SMTExprRef &From,
917                                const SMTSortRef &To) = 0;
918
919   /// Creates a floating-point conversion from signed bitvector to
920   /// floatint-point operation
921   virtual SMTExprRef mkSBVtoFP(const SMTExprRef &From, unsigned ToWidth) = 0;
922
923   /// Creates a floating-point conversion from unsigned bitvector to
924   /// floatint-point operation
925   virtual SMTExprRef mkUBVtoFP(const SMTExprRef &From, unsigned ToWidth) = 0;
926
927   /// Creates a new symbol, given a name and a sort
928   virtual SMTExprRef mkSymbol(const char *Name, SMTSortRef Sort) = 0;
929
930   // Returns an appropriate floating-point rounding mode.
931   virtual SMTExprRef getFloatRoundingMode() = 0;
932
933   // If the a model is available, returns the value of a given bitvector symbol
934   virtual llvm::APSInt getBitvector(const SMTExprRef &Exp, unsigned BitWidth,
935                                     bool isUnsigned) = 0;
936
937   // If the a model is available, returns the value of a given boolean symbol
938   virtual bool getBoolean(const SMTExprRef &Exp) = 0;
939
940   /// Constructs an SMTExprRef from a boolean.
941   virtual SMTExprRef mkBoolean(const bool b) = 0;
942
943   /// Constructs an SMTExprRef from a finite APFloat.
944   virtual SMTExprRef mkFloat(const llvm::APFloat Float) = 0;
945
946   /// Constructs an SMTExprRef from an APSInt and its bit width
947   virtual SMTExprRef mkBitvector(const llvm::APSInt Int, unsigned BitWidth) = 0;
948
949   /// Given an expression, extract the value of this operand in the model.
950   virtual bool getInterpretation(const SMTExprRef &Exp, llvm::APSInt &Int) = 0;
951
952   /// Given an expression extract the value of this operand in the model.
953   virtual bool getInterpretation(const SMTExprRef &Exp,
954                                  llvm::APFloat &Float) = 0;
955
956   /// Construct an SMTExprRef value from a boolean.
957   virtual SMTExprRef fromBoolean(const bool Bool) = 0;
958
959   /// Construct an SMTExprRef value from a finite APFloat.
960   virtual SMTExprRef fromAPFloat(const llvm::APFloat &Float) = 0;
961
962   /// Construct an SMTExprRef value from an APSInt.
963   virtual SMTExprRef fromAPSInt(const llvm::APSInt &Int) = 0;
964
965   /// Construct an SMTExprRef value from an integer.
966   virtual SMTExprRef fromInt(const char *Int, uint64_t BitWidth) = 0;
967
968   /// Construct an SMTExprRef from a SymbolData.
969   virtual SMTExprRef fromData(const SymbolID ID, const QualType &Ty,
970                               uint64_t BitWidth) = 0;
971
972   /// Check if the constraints are satisfiable
973   virtual ConditionTruthVal check() const = 0;
974
975   /// Push the current solver state
976   virtual void push() = 0;
977
978   /// Pop the previous solver state
979   virtual void pop(unsigned NumStates = 1) = 0;
980
981   /// Reset the solver and remove all constraints.
982   virtual void reset() const = 0;
983
984   virtual void print(raw_ostream &OS) const = 0;
985 };
986
987 /// Shared pointer for SMTSolvers.
988 using SMTSolverRef = std::shared_ptr<SMTSolver>;
989
990 /// Convenience method to create and Z3Solver object
991 std::unique_ptr<SMTSolver> CreateZ3Solver();
992
993 } // namespace ento
994 } // namespace clang
995
996 #endif