1 //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
/// This file contains functions which are used to decide if a loop is worth
/// unrolling. Moreover, these functions manage the stack of loops which is
/// tracked by the ProgramState.
14 //===----------------------------------------------------------------------===//
16 #include "clang/ASTMatchers/ASTMatchers.h"
17 #include "clang/ASTMatchers/ASTMatchFinder.h"
18 #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
19 #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
20 #include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h"
22 using namespace clang;
24 using namespace clang::ast_matchers;
/// Largest total step count (a loop's own steps multiplied by the step bound of
/// the enclosing loop-stack entry) for which a loop is still marked Unrolled.
static constexpr int MAXIMUM_STEP_UNROLLED = 128;
30 enum Kind { Normal, Unrolled } K;
32 const LocationContext *LCtx;
34 LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N)
35 : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {}
38 static LoopState getNormal(const Stmt *S, const LocationContext *L,
40 return LoopState(Normal, S, L, N);
42 static LoopState getUnrolled(const Stmt *S, const LocationContext *L,
44 return LoopState(Unrolled, S, L, N);
46 bool isUnrolled() const { return K == Unrolled; }
47 unsigned getMaxStep() const { return maxStep; }
48 const Stmt *getLoopStmt() const { return LoopStmt; }
49 const LocationContext *getLocationContext() const { return LCtx; }
50 bool operator==(const LoopState &X) const {
51 return K == X.K && LoopStmt == X.LoopStmt;
53 void Profile(llvm::FoldingSetNodeID &ID) const {
55 ID.AddPointer(LoopStmt);
57 ID.AddInteger(maxStep);
61 // The tracked stack of loops. The stack indicates that which loops the
62 // simulated element contained by. The loops are marked depending if we decided
64 // TODO: The loop stack should not need to be in the program state since it is
65 // lexical in nature. Instead, the stack of loops should be tracked in the
67 REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState)
72 static bool isLoopStmt(const Stmt *S) {
73 return S && (isa<ForStmt>(S) || isa<WhileStmt>(S) || isa<DoStmt>(S));
76 ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) {
77 auto LS = State->get<LoopStack>();
78 if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt)
79 State = State->set<LoopStack>(LS.getTail());
83 static internal::Matcher<Stmt> simpleCondition(StringRef BindName) {
84 return binaryOperator(anyOf(hasOperatorName("<"), hasOperatorName(">"),
85 hasOperatorName("<="), hasOperatorName(">="),
86 hasOperatorName("!=")),
87 hasEitherOperand(ignoringParenImpCasts(declRefExpr(
88 to(varDecl(hasType(isInteger())).bind(BindName))))),
89 hasEitherOperand(ignoringParenImpCasts(
90 integerLiteral().bind("boundNum"))))
91 .bind("conditionOperator");
94 static internal::Matcher<Stmt>
95 changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) {
97 unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")),
98 hasUnaryOperand(ignoringParenImpCasts(
99 declRefExpr(to(varDecl(VarNodeMatcher)))))),
100 binaryOperator(isAssignmentOperator(),
101 hasLHS(ignoringParenImpCasts(
102 declRefExpr(to(varDecl(VarNodeMatcher)))))));
105 static internal::Matcher<Stmt>
106 callByRef(internal::Matcher<Decl> VarNodeMatcher) {
107 return callExpr(forEachArgumentWithParam(
108 declRefExpr(to(varDecl(VarNodeMatcher))),
109 parmVarDecl(hasType(references(qualType(unless(isConstQualified())))))));
112 static internal::Matcher<Stmt>
113 assignedToRef(internal::Matcher<Decl> VarNodeMatcher) {
114 return declStmt(hasDescendant(varDecl(
115 allOf(hasType(referenceType()),
116 hasInitializer(anyOf(
117 initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))),
118 declRefExpr(to(varDecl(VarNodeMatcher)))))))));
121 static internal::Matcher<Stmt>
122 getAddrTo(internal::Matcher<Decl> VarNodeMatcher) {
123 return unaryOperator(
124 hasOperatorName("&"),
125 hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher))));
128 static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) {
129 return hasDescendant(stmt(
130 anyOf(gotoStmt(), switchStmt(), returnStmt(),
131 // Escaping and not known mutation of the loop counter is handled
132 // by exclusion of assigning and address-of operators and
133 // pass-by-ref function calls on the loop counter from the body.
134 changeIntBoundNode(equalsBoundNode(NodeName)),
135 callByRef(equalsBoundNode(NodeName)),
136 getAddrTo(equalsBoundNode(NodeName)),
137 assignedToRef(equalsBoundNode(NodeName)))));
140 static internal::Matcher<Stmt> forLoopMatcher() {
142 hasCondition(simpleCondition("initVarName")),
143 // Initialization should match the form: 'int i = 6' or 'i = 42'.
145 anyOf(declStmt(hasSingleDecl(
146 varDecl(allOf(hasInitializer(ignoringParenImpCasts(
147 integerLiteral().bind("initNum"))),
148 equalsBoundNode("initVarName"))))),
149 binaryOperator(hasLHS(declRefExpr(to(varDecl(
150 equalsBoundNode("initVarName"))))),
151 hasRHS(ignoringParenImpCasts(
152 integerLiteral().bind("initNum")))))),
153 // Incrementation should be a simple increment or decrement
155 hasIncrement(unaryOperator(
156 anyOf(hasOperatorName("++"), hasOperatorName("--")),
157 hasUnaryOperand(declRefExpr(
158 to(varDecl(allOf(equalsBoundNode("initVarName"),
159 hasType(isInteger())))))))),
160 unless(hasBody(hasSuspiciousStmt("initVarName")))).bind("forLoop");
163 static bool isPossiblyEscaped(const VarDecl *VD, ExplodedNode *N) {
164 // Global variables assumed as escaped variables.
165 if (VD->hasGlobalStorage())
168 while (!N->pred_empty()) {
169 const Stmt *S = PathDiagnosticLocation::getStmt(N);
171 N = N->getFirstPred();
175 if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) {
176 for (const Decl *D : DS->decls()) {
177 // Once we reach the declaration of the VD we can return.
178 if (D->getCanonicalDecl() == VD)
182 // Check the usage of the pass-by-ref function calls and adress-of operator
183 // on VD and reference initialized by VD.
185 N->getLocationContext()->getAnalysisDeclContext()->getASTContext();
187 match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)),
188 assignedToRef(equalsNode(VD)))),
193 N = N->getFirstPred();
195 llvm_unreachable("Reached root without finding the declaration of VD");
198 bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx,
199 ExplodedNode *Pred, unsigned &maxStep) {
201 if (!isLoopStmt(LoopStmt))
204 // TODO: Match the cases where the bound is not a concrete literal but an
205 // integer with known value
206 auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx);
210 auto CounterVar = Matches[0].getNodeAs<VarDecl>("initVarName");
211 llvm::APInt BoundNum =
212 Matches[0].getNodeAs<IntegerLiteral>("boundNum")->getValue();
213 llvm::APInt InitNum =
214 Matches[0].getNodeAs<IntegerLiteral>("initNum")->getValue();
215 auto CondOp = Matches[0].getNodeAs<BinaryOperator>("conditionOperator");
216 if (InitNum.getBitWidth() != BoundNum.getBitWidth()) {
217 InitNum = InitNum.zextOrSelf(BoundNum.getBitWidth());
218 BoundNum = BoundNum.zextOrSelf(InitNum.getBitWidth());
221 if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE)
222 maxStep = (BoundNum - InitNum + 1).abs().getZExtValue();
224 maxStep = (BoundNum - InitNum).abs().getZExtValue();
226 // Check if the counter of the loop is not escaped before.
227 return !isPossiblyEscaped(CounterVar->getCanonicalDecl(), Pred);
230 bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) {
231 const Stmt *S = nullptr;
232 while (!N->pred_empty()) {
233 if (N->succ_size() > 1)
236 ProgramPoint P = N->getLocation();
237 if (Optional<BlockEntrance> BE = P.getAs<BlockEntrance>())
238 S = BE->getBlock()->getTerminator();
243 N = N->getFirstPred();
246 llvm_unreachable("Reached root without encountering the previous step");
249 // updateLoopStack is called on every basic block, therefore it needs to be fast
250 ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx,
251 ExplodedNode *Pred, unsigned maxVisitOnPath) {
252 auto State = Pred->getState();
253 auto LCtx = Pred->getLocationContext();
255 if (!isLoopStmt(LoopStmt))
258 auto LS = State->get<LoopStack>();
259 if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() &&
260 LCtx == LS.getHead().getLocationContext()) {
261 if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) {
262 State = State->set<LoopStack>(LS.getTail());
263 State = State->add<LoopStack>(
264 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
269 if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) {
270 State = State->add<LoopStack>(
271 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
275 unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep());
277 unsigned innerMaxStep = maxStep * outerStep;
278 if (innerMaxStep > MAXIMUM_STEP_UNROLLED)
279 State = State->add<LoopStack>(
280 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
282 State = State->add<LoopStack>(
283 LoopState::getUnrolled(LoopStmt, LCtx, innerMaxStep));
287 bool isUnrolledState(ProgramStateRef State) {
288 auto LS = State->get<LoopStack>();
289 if (LS.isEmpty() || !LS.getHead().isUnrolled())