1 //===- Consumed.cpp --------------------------------------------*- C++ --*-===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // An intra-procedural analysis for checking consumed properties. This is based,
11 // in part, on research on linear types.
13 //===----------------------------------------------------------------------===//
15 #include "clang/Analysis/Analyses/Consumed.h"
16 #include "clang/AST/ASTContext.h"
17 #include "clang/AST/Attr.h"
18 #include "clang/AST/DeclCXX.h"
19 #include "clang/AST/ExprCXX.h"
20 #include "clang/AST/RecursiveASTVisitor.h"
21 #include "clang/AST/StmtCXX.h"
22 #include "clang/AST/StmtVisitor.h"
23 #include "clang/AST/Type.h"
24 #include "clang/Analysis/Analyses/PostOrderCFGView.h"
25 #include "clang/Analysis/AnalysisDeclContext.h"
26 #include "clang/Analysis/CFG.h"
27 #include "clang/Basic/OperatorKinds.h"
28 #include "clang/Basic/SourceLocation.h"
29 #include "llvm/ADT/DenseMap.h"
32 // TODO: Adjust states of args to constructors in the same way that arguments to
33 // function calls are handled.
34 // TODO: Use information from tests in for- and while-loop conditional.
35 // TODO: Add notes about the actual and expected state for
36 // TODO: Correctly identify unreachable blocks when chaining boolean operators.
37 // TODO: Adjust the parser and AttributesList class to support lists of
39 // TODO: Warn about unreachable code.
40 // TODO: Switch to using a bitmap to track unreachable blocks.
41 // TODO: Handle variable definitions, e.g. bool valid = x.isValid();
42 // if (valid) ...; (Deferred)
43 // TODO: Take notes on state transitions to provide better warning messages.
45 // TODO: Test nested conditionals: 1) Checking the same value multiple times,
46 // and 2) Checking different values. (Deferred)
48 using namespace clang;
49 using namespace consumed;
51 // Key method definition: the destructor is defined out of line here (empty
// body) rather than inline in the header. This is presumably so that
// ConsumedWarningsHandlerBase has a key function and its vtable is emitted
// only in this translation unit.
// NOTE(review): this only holds if the destructor is declared virtual and
// non-inline in Consumed.h; confirm against the header.
52 ConsumedWarningsHandlerBase::~ConsumedWarningsHandlerBase() {}
54 static SourceLocation getFirstStmtLoc(const CFGBlock *Block) {
55 // Find the source location of the first statement in the block, if the block
57 for (const auto &B : *Block)
58 if (Optional<CFGStmt> CS = B.getAs<CFGStmt>())
59 return CS->getStmt()->getLocStart();
62 // If we have one successor, return the first statement in that block
63 if (Block->succ_size() == 1 && *Block->succ_begin())
64 return getFirstStmtLoc(*Block->succ_begin());
66 return SourceLocation();
69 static SourceLocation getLastStmtLoc(const CFGBlock *Block) {
70 // Find the source location of the last statement in the block, if the block
72 if (const Stmt *StmtNode = Block->getTerminator()) {
73 return StmtNode->getLocStart();
75 for (CFGBlock::const_reverse_iterator BI = Block->rbegin(),
76 BE = Block->rend(); BI != BE; ++BI) {
77 if (Optional<CFGStmt> CS = BI->getAs<CFGStmt>())
78 return CS->getStmt()->getLocStart();
82 // If we have one successor, return the first statement in that block
84 if (Block->succ_size() == 1 && *Block->succ_begin())
85 Loc = getFirstStmtLoc(*Block->succ_begin());
89 // If we have one predecessor, return the last statement in that block
90 if (Block->pred_size() == 1 && *Block->pred_begin())
91 return getLastStmtLoc(*Block->pred_begin());
96 static ConsumedState invertConsumedUnconsumed(ConsumedState State) {
101 return CS_Unconsumed;
107 llvm_unreachable("invalid enum");
110 static bool isCallableInState(const CallableWhenAttr *CWAttr,
111 ConsumedState State) {
113 for (const auto &S : CWAttr->callableStates()) {
114 ConsumedState MappedAttrState = CS_None;
117 case CallableWhenAttr::Unknown:
118 MappedAttrState = CS_Unknown;
121 case CallableWhenAttr::Unconsumed:
122 MappedAttrState = CS_Unconsumed;
125 case CallableWhenAttr::Consumed:
126 MappedAttrState = CS_Consumed;
130 if (MappedAttrState == State)
138 static bool isConsumableType(const QualType &QT) {
139 if (QT->isPointerType() || QT->isReferenceType())
142 if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
143 return RD->hasAttr<ConsumableAttr>();
148 static bool isAutoCastType(const QualType &QT) {
149 if (QT->isPointerType() || QT->isReferenceType())
152 if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
153 return RD->hasAttr<ConsumableAutoCastAttr>();
158 static bool isSetOnReadPtrType(const QualType &QT) {
159 if (const CXXRecordDecl *RD = QT->getPointeeCXXRecordDecl())
160 return RD->hasAttr<ConsumableSetOnReadAttr>();
165 static bool isKnownState(ConsumedState State) {
174 llvm_unreachable("invalid enum");
177 static bool isRValueRef(QualType ParamType) {
178 return ParamType->isRValueReferenceType();
181 static bool isTestingFunction(const FunctionDecl *FunDecl) {
182 return FunDecl->hasAttr<TestTypestateAttr>();
185 static bool isPointerOrRef(QualType ParamType) {
186 return ParamType->isPointerType() || ParamType->isReferenceType();
189 static ConsumedState mapConsumableAttrState(const QualType QT) {
190 assert(isConsumableType(QT));
192 const ConsumableAttr *CAttr =
193 QT->getAsCXXRecordDecl()->getAttr<ConsumableAttr>();
195 switch (CAttr->getDefaultState()) {
196 case ConsumableAttr::Unknown:
198 case ConsumableAttr::Unconsumed:
199 return CS_Unconsumed;
200 case ConsumableAttr::Consumed:
203 llvm_unreachable("invalid enum");
207 mapParamTypestateAttrState(const ParamTypestateAttr *PTAttr) {
208 switch (PTAttr->getParamState()) {
209 case ParamTypestateAttr::Unknown:
211 case ParamTypestateAttr::Unconsumed:
212 return CS_Unconsumed;
213 case ParamTypestateAttr::Consumed:
216 llvm_unreachable("invalid_enum");
220 mapReturnTypestateAttrState(const ReturnTypestateAttr *RTSAttr) {
221 switch (RTSAttr->getState()) {
222 case ReturnTypestateAttr::Unknown:
224 case ReturnTypestateAttr::Unconsumed:
225 return CS_Unconsumed;
226 case ReturnTypestateAttr::Consumed:
229 llvm_unreachable("invalid enum");
232 static ConsumedState mapSetTypestateAttrState(const SetTypestateAttr *STAttr) {
233 switch (STAttr->getNewState()) {
234 case SetTypestateAttr::Unknown:
236 case SetTypestateAttr::Unconsumed:
237 return CS_Unconsumed;
238 case SetTypestateAttr::Consumed:
241 llvm_unreachable("invalid_enum");
244 static StringRef stateToString(ConsumedState State) {
246 case consumed::CS_None:
249 case consumed::CS_Unknown:
252 case consumed::CS_Unconsumed:
255 case consumed::CS_Consumed:
258 llvm_unreachable("invalid enum");
261 static ConsumedState testsFor(const FunctionDecl *FunDecl) {
262 assert(isTestingFunction(FunDecl));
263 switch (FunDecl->getAttr<TestTypestateAttr>()->getTestState()) {
264 case TestTypestateAttr::Unconsumed:
265 return CS_Unconsumed;
266 case TestTypestateAttr::Consumed:
269 llvm_unreachable("invalid enum");
273 struct VarTestResult {
275 ConsumedState TestsFor;
277 } // end anonymous::VarTestResult
287 class PropagationInfo {
298 const BinaryOperator *Source;
306 VarTestResult VarTest;
308 const CXXBindTemporaryExpr *Tmp;
313 PropagationInfo() : InfoType(IT_None) {}
315 PropagationInfo(const VarTestResult &VarTest)
316 : InfoType(IT_VarTest), VarTest(VarTest) {}
318 PropagationInfo(const VarDecl *Var, ConsumedState TestsFor)
319 : InfoType(IT_VarTest) {
322 VarTest.TestsFor = TestsFor;
325 PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
326 const VarTestResult <est, const VarTestResult &RTest)
327 : InfoType(IT_BinTest) {
329 BinTest.Source = Source;
331 BinTest.LTest = LTest;
332 BinTest.RTest = RTest;
335 PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
336 const VarDecl *LVar, ConsumedState LTestsFor,
337 const VarDecl *RVar, ConsumedState RTestsFor)
338 : InfoType(IT_BinTest) {
340 BinTest.Source = Source;
342 BinTest.LTest.Var = LVar;
343 BinTest.LTest.TestsFor = LTestsFor;
344 BinTest.RTest.Var = RVar;
345 BinTest.RTest.TestsFor = RTestsFor;
348 PropagationInfo(ConsumedState State)
349 : InfoType(IT_State), State(State) {}
351 PropagationInfo(const VarDecl *Var) : InfoType(IT_Var), Var(Var) {}
352 PropagationInfo(const CXXBindTemporaryExpr *Tmp)
353 : InfoType(IT_Tmp), Tmp(Tmp) {}
355 const ConsumedState & getState() const {
356 assert(InfoType == IT_State);
360 const VarTestResult & getVarTest() const {
361 assert(InfoType == IT_VarTest);
365 const VarTestResult & getLTest() const {
366 assert(InfoType == IT_BinTest);
367 return BinTest.LTest;
370 const VarTestResult & getRTest() const {
371 assert(InfoType == IT_BinTest);
372 return BinTest.RTest;
375 const VarDecl * getVar() const {
376 assert(InfoType == IT_Var);
380 const CXXBindTemporaryExpr * getTmp() const {
381 assert(InfoType == IT_Tmp);
385 ConsumedState getAsState(const ConsumedStateMap *StateMap) const {
386 assert(isVar() || isTmp() || isState());
389 return StateMap->getState(Var);
391 return StateMap->getState(Tmp);
398 EffectiveOp testEffectiveOp() const {
399 assert(InfoType == IT_BinTest);
403 const BinaryOperator * testSourceNode() const {
404 assert(InfoType == IT_BinTest);
405 return BinTest.Source;
408 inline bool isValid() const { return InfoType != IT_None; }
409 inline bool isState() const { return InfoType == IT_State; }
410 inline bool isVarTest() const { return InfoType == IT_VarTest; }
411 inline bool isBinTest() const { return InfoType == IT_BinTest; }
412 inline bool isVar() const { return InfoType == IT_Var; }
413 inline bool isTmp() const { return InfoType == IT_Tmp; }
415 bool isTest() const {
416 return InfoType == IT_VarTest || InfoType == IT_BinTest;
419 bool isPointerToValue() const {
420 return InfoType == IT_Var || InfoType == IT_Tmp;
423 PropagationInfo invertTest() const {
424 assert(InfoType == IT_VarTest || InfoType == IT_BinTest);
426 if (InfoType == IT_VarTest) {
427 return PropagationInfo(VarTest.Var,
428 invertConsumedUnconsumed(VarTest.TestsFor));
430 } else if (InfoType == IT_BinTest) {
431 return PropagationInfo(BinTest.Source,
432 BinTest.EOp == EO_And ? EO_Or : EO_And,
433 BinTest.LTest.Var, invertConsumedUnconsumed(BinTest.LTest.TestsFor),
434 BinTest.RTest.Var, invertConsumedUnconsumed(BinTest.RTest.TestsFor));
436 return PropagationInfo();
442 setStateForVarOrTmp(ConsumedStateMap *StateMap, const PropagationInfo &PInfo,
443 ConsumedState State) {
445 assert(PInfo.isVar() || PInfo.isTmp());
448 StateMap->setState(PInfo.getVar(), State);
450 StateMap->setState(PInfo.getTmp(), State);
453 class ConsumedStmtVisitor : public ConstStmtVisitor<ConsumedStmtVisitor> {
455 typedef llvm::DenseMap<const Stmt *, PropagationInfo> MapType;
456 typedef std::pair<const Stmt *, PropagationInfo> PairType;
457 typedef MapType::iterator InfoEntry;
458 typedef MapType::const_iterator ConstInfoEntry;
460 AnalysisDeclContext &AC;
461 ConsumedAnalyzer &Analyzer;
462 ConsumedStateMap *StateMap;
463 MapType PropagationMap;
465 InfoEntry findInfo(const Expr *E) {
466 if (auto Cleanups = dyn_cast<ExprWithCleanups>(E))
467 if (!Cleanups->cleanupsHaveSideEffects())
468 E = Cleanups->getSubExpr();
469 return PropagationMap.find(E->IgnoreParens());
471 ConstInfoEntry findInfo(const Expr *E) const {
472 if (auto Cleanups = dyn_cast<ExprWithCleanups>(E))
473 if (!Cleanups->cleanupsHaveSideEffects())
474 E = Cleanups->getSubExpr();
475 return PropagationMap.find(E->IgnoreParens());
477 void insertInfo(const Expr *E, const PropagationInfo &PI) {
478 PropagationMap.insert(PairType(E->IgnoreParens(), PI));
481 void forwardInfo(const Expr *From, const Expr *To);
482 void copyInfo(const Expr *From, const Expr *To, ConsumedState CS);
483 ConsumedState getInfo(const Expr *From);
484 void setInfo(const Expr *To, ConsumedState NS);
485 void propagateReturnType(const Expr *Call, const FunctionDecl *Fun);
488 void checkCallability(const PropagationInfo &PInfo,
489 const FunctionDecl *FunDecl,
490 SourceLocation BlameLoc);
491 bool handleCall(const CallExpr *Call, const Expr *ObjArg,
492 const FunctionDecl *FunD);
494 void VisitBinaryOperator(const BinaryOperator *BinOp);
495 void VisitCallExpr(const CallExpr *Call);
496 void VisitCastExpr(const CastExpr *Cast);
497 void VisitCXXBindTemporaryExpr(const CXXBindTemporaryExpr *Temp);
498 void VisitCXXConstructExpr(const CXXConstructExpr *Call);
499 void VisitCXXMemberCallExpr(const CXXMemberCallExpr *Call);
500 void VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *Call);
501 void VisitDeclRefExpr(const DeclRefExpr *DeclRef);
502 void VisitDeclStmt(const DeclStmt *DelcS);
503 void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *Temp);
504 void VisitMemberExpr(const MemberExpr *MExpr);
505 void VisitParmVarDecl(const ParmVarDecl *Param);
506 void VisitReturnStmt(const ReturnStmt *Ret);
507 void VisitUnaryOperator(const UnaryOperator *UOp);
508 void VisitVarDecl(const VarDecl *Var);
510 ConsumedStmtVisitor(AnalysisDeclContext &AC, ConsumedAnalyzer &Analyzer,
511 ConsumedStateMap *StateMap)
512 : AC(AC), Analyzer(Analyzer), StateMap(StateMap) {}
514 PropagationInfo getInfo(const Expr *StmtNode) const {
515 ConstInfoEntry Entry = findInfo(StmtNode);
517 if (Entry != PropagationMap.end())
518 return Entry->second;
520 return PropagationInfo();
523 void reset(ConsumedStateMap *NewStateMap) {
524 StateMap = NewStateMap;
529 void ConsumedStmtVisitor::forwardInfo(const Expr *From, const Expr *To) {
530 InfoEntry Entry = findInfo(From);
531 if (Entry != PropagationMap.end())
532 insertInfo(To, Entry->second);
536 // Create a new state for To, which is initialized to the state of From.
537 // If NS is not CS_None, sets the state of From to NS.
538 void ConsumedStmtVisitor::copyInfo(const Expr *From, const Expr *To,
540 InfoEntry Entry = findInfo(From);
541 if (Entry != PropagationMap.end()) {
542 PropagationInfo& PInfo = Entry->second;
543 ConsumedState CS = PInfo.getAsState(StateMap);
545 insertInfo(To, PropagationInfo(CS));
546 if (NS != CS_None && PInfo.isPointerToValue())
547 setStateForVarOrTmp(StateMap, PInfo, NS);
552 // Get the ConsumedState for From
553 ConsumedState ConsumedStmtVisitor::getInfo(const Expr *From) {
554 InfoEntry Entry = findInfo(From);
555 if (Entry != PropagationMap.end()) {
556 PropagationInfo& PInfo = Entry->second;
557 return PInfo.getAsState(StateMap);
563 // If we already have info for To then update it, otherwise create a new entry.
564 void ConsumedStmtVisitor::setInfo(const Expr *To, ConsumedState NS) {
565 InfoEntry Entry = findInfo(To);
566 if (Entry != PropagationMap.end()) {
567 PropagationInfo& PInfo = Entry->second;
568 if (PInfo.isPointerToValue())
569 setStateForVarOrTmp(StateMap, PInfo, NS);
570 } else if (NS != CS_None) {
571 insertInfo(To, PropagationInfo(NS));
577 void ConsumedStmtVisitor::checkCallability(const PropagationInfo &PInfo,
578 const FunctionDecl *FunDecl,
579 SourceLocation BlameLoc) {
580 assert(!PInfo.isTest());
582 const CallableWhenAttr *CWAttr = FunDecl->getAttr<CallableWhenAttr>();
587 ConsumedState VarState = StateMap->getState(PInfo.getVar());
589 if (VarState == CS_None || isCallableInState(CWAttr, VarState))
592 Analyzer.WarningsHandler.warnUseInInvalidState(
593 FunDecl->getNameAsString(), PInfo.getVar()->getNameAsString(),
594 stateToString(VarState), BlameLoc);
597 ConsumedState TmpState = PInfo.getAsState(StateMap);
599 if (TmpState == CS_None || isCallableInState(CWAttr, TmpState))
602 Analyzer.WarningsHandler.warnUseOfTempInInvalidState(
603 FunDecl->getNameAsString(), stateToString(TmpState), BlameLoc);
608 // Factors out common behavior for function, method, and operator calls.
609 // Check parameters and set parameter state if necessary.
610 // Returns true if the state of ObjArg is set, or false otherwise.
611 bool ConsumedStmtVisitor::handleCall(const CallExpr *Call, const Expr *ObjArg,
612 const FunctionDecl *FunD) {
614 if (isa<CXXOperatorCallExpr>(Call) && isa<CXXMethodDecl>(FunD))
615 Offset = 1; // first argument is 'this'
617 // check explicit parameters
618 for (unsigned Index = Offset; Index < Call->getNumArgs(); ++Index) {
619 // Skip variable argument lists.
620 if (Index - Offset >= FunD->getNumParams())
623 const ParmVarDecl *Param = FunD->getParamDecl(Index - Offset);
624 QualType ParamType = Param->getType();
626 InfoEntry Entry = findInfo(Call->getArg(Index));
628 if (Entry == PropagationMap.end() || Entry->second.isTest())
630 PropagationInfo PInfo = Entry->second;
632 // Check that the parameter is in the correct state.
633 if (ParamTypestateAttr *PTA = Param->getAttr<ParamTypestateAttr>()) {
634 ConsumedState ParamState = PInfo.getAsState(StateMap);
635 ConsumedState ExpectedState = mapParamTypestateAttrState(PTA);
637 if (ParamState != ExpectedState)
638 Analyzer.WarningsHandler.warnParamTypestateMismatch(
639 Call->getArg(Index)->getExprLoc(),
640 stateToString(ExpectedState), stateToString(ParamState));
643 if (!(Entry->second.isVar() || Entry->second.isTmp()))
646 // Adjust state on the caller side.
647 if (isRValueRef(ParamType))
648 setStateForVarOrTmp(StateMap, PInfo, consumed::CS_Consumed);
649 else if (ReturnTypestateAttr *RT = Param->getAttr<ReturnTypestateAttr>())
650 setStateForVarOrTmp(StateMap, PInfo, mapReturnTypestateAttrState(RT));
651 else if (isPointerOrRef(ParamType) &&
652 (!ParamType->getPointeeType().isConstQualified() ||
653 isSetOnReadPtrType(ParamType)))
654 setStateForVarOrTmp(StateMap, PInfo, consumed::CS_Unknown);
660 // check implicit 'self' parameter, if present
661 InfoEntry Entry = findInfo(ObjArg);
662 if (Entry != PropagationMap.end()) {
663 PropagationInfo PInfo = Entry->second;
664 checkCallability(PInfo, FunD, Call->getExprLoc());
666 if (SetTypestateAttr *STA = FunD->getAttr<SetTypestateAttr>()) {
668 StateMap->setState(PInfo.getVar(), mapSetTypestateAttrState(STA));
671 else if (PInfo.isTmp()) {
672 StateMap->setState(PInfo.getTmp(), mapSetTypestateAttrState(STA));
676 else if (isTestingFunction(FunD) && PInfo.isVar()) {
677 PropagationMap.insert(PairType(Call,
678 PropagationInfo(PInfo.getVar(), testsFor(FunD))));
685 void ConsumedStmtVisitor::propagateReturnType(const Expr *Call,
686 const FunctionDecl *Fun) {
687 QualType RetType = Fun->getCallResultType();
688 if (RetType->isReferenceType())
689 RetType = RetType->getPointeeType();
691 if (isConsumableType(RetType)) {
692 ConsumedState ReturnState;
693 if (ReturnTypestateAttr *RTA = Fun->getAttr<ReturnTypestateAttr>())
694 ReturnState = mapReturnTypestateAttrState(RTA);
696 ReturnState = mapConsumableAttrState(RetType);
698 PropagationMap.insert(PairType(Call, PropagationInfo(ReturnState)));
703 void ConsumedStmtVisitor::VisitBinaryOperator(const BinaryOperator *BinOp) {
704 switch (BinOp->getOpcode()) {
707 InfoEntry LEntry = findInfo(BinOp->getLHS()),
708 REntry = findInfo(BinOp->getRHS());
710 VarTestResult LTest, RTest;
712 if (LEntry != PropagationMap.end() && LEntry->second.isVarTest()) {
713 LTest = LEntry->second.getVarTest();
717 LTest.TestsFor = CS_None;
720 if (REntry != PropagationMap.end() && REntry->second.isVarTest()) {
721 RTest = REntry->second.getVarTest();
725 RTest.TestsFor = CS_None;
728 if (!(LTest.Var == nullptr && RTest.Var == nullptr))
729 PropagationMap.insert(PairType(BinOp, PropagationInfo(BinOp,
730 static_cast<EffectiveOp>(BinOp->getOpcode() == BO_LOr), LTest, RTest)));
737 forwardInfo(BinOp->getLHS(), BinOp);
745 void ConsumedStmtVisitor::VisitCallExpr(const CallExpr *Call) {
746 const FunctionDecl *FunDecl = Call->getDirectCallee();
750 // Special case for the std::move function.
751 // TODO: Make this more specific. (Deferred)
752 if (Call->isCallToStdMove()) {
753 copyInfo(Call->getArg(0), Call, CS_Consumed);
757 handleCall(Call, nullptr, FunDecl);
758 propagateReturnType(Call, FunDecl);
761 void ConsumedStmtVisitor::VisitCastExpr(const CastExpr *Cast) {
762 forwardInfo(Cast->getSubExpr(), Cast);
765 void ConsumedStmtVisitor::VisitCXXBindTemporaryExpr(
766 const CXXBindTemporaryExpr *Temp) {
768 InfoEntry Entry = findInfo(Temp->getSubExpr());
770 if (Entry != PropagationMap.end() && !Entry->second.isTest()) {
771 StateMap->setState(Temp, Entry->second.getAsState(StateMap));
772 PropagationMap.insert(PairType(Temp, PropagationInfo(Temp)));
776 void ConsumedStmtVisitor::VisitCXXConstructExpr(const CXXConstructExpr *Call) {
777 CXXConstructorDecl *Constructor = Call->getConstructor();
779 ASTContext &CurrContext = AC.getASTContext();
780 QualType ThisType = Constructor->getThisType(CurrContext)->getPointeeType();
782 if (!isConsumableType(ThisType))
785 // FIXME: What should happen if someone annotates the move constructor?
786 if (ReturnTypestateAttr *RTA = Constructor->getAttr<ReturnTypestateAttr>()) {
787 // TODO: Adjust state of args appropriately.
788 ConsumedState RetState = mapReturnTypestateAttrState(RTA);
789 PropagationMap.insert(PairType(Call, PropagationInfo(RetState)));
790 } else if (Constructor->isDefaultConstructor()) {
791 PropagationMap.insert(PairType(Call,
792 PropagationInfo(consumed::CS_Consumed)));
793 } else if (Constructor->isMoveConstructor()) {
794 copyInfo(Call->getArg(0), Call, CS_Consumed);
795 } else if (Constructor->isCopyConstructor()) {
796 // Copy state from arg. If setStateOnRead then set arg to CS_Unknown.
798 isSetOnReadPtrType(Constructor->getThisType(CurrContext)) ?
799 CS_Unknown : CS_None;
800 copyInfo(Call->getArg(0), Call, NS);
802 // TODO: Adjust state of args appropriately.
803 ConsumedState RetState = mapConsumableAttrState(ThisType);
804 PropagationMap.insert(PairType(Call, PropagationInfo(RetState)));
809 void ConsumedStmtVisitor::VisitCXXMemberCallExpr(
810 const CXXMemberCallExpr *Call) {
811 CXXMethodDecl* MD = Call->getMethodDecl();
815 handleCall(Call, Call->getImplicitObjectArgument(), MD);
816 propagateReturnType(Call, MD);
820 void ConsumedStmtVisitor::VisitCXXOperatorCallExpr(
821 const CXXOperatorCallExpr *Call) {
823 const FunctionDecl *FunDecl =
824 dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee());
825 if (!FunDecl) return;
827 if (Call->getOperator() == OO_Equal) {
828 ConsumedState CS = getInfo(Call->getArg(1));
829 if (!handleCall(Call, Call->getArg(0), FunDecl))
830 setInfo(Call->getArg(0), CS);
834 if (const CXXMemberCallExpr *MCall = dyn_cast<CXXMemberCallExpr>(Call))
835 handleCall(MCall, MCall->getImplicitObjectArgument(), FunDecl);
837 handleCall(Call, Call->getArg(0), FunDecl);
839 propagateReturnType(Call, FunDecl);
842 void ConsumedStmtVisitor::VisitDeclRefExpr(const DeclRefExpr *DeclRef) {
843 if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclRef->getDecl()))
844 if (StateMap->getState(Var) != consumed::CS_None)
845 PropagationMap.insert(PairType(DeclRef, PropagationInfo(Var)));
848 void ConsumedStmtVisitor::VisitDeclStmt(const DeclStmt *DeclS) {
849 for (const auto *DI : DeclS->decls())
850 if (isa<VarDecl>(DI))
851 VisitVarDecl(cast<VarDecl>(DI));
853 if (DeclS->isSingleDecl())
854 if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclS->getSingleDecl()))
855 PropagationMap.insert(PairType(DeclS, PropagationInfo(Var)));
858 void ConsumedStmtVisitor::VisitMaterializeTemporaryExpr(
859 const MaterializeTemporaryExpr *Temp) {
861 forwardInfo(Temp->GetTemporaryExpr(), Temp);
864 void ConsumedStmtVisitor::VisitMemberExpr(const MemberExpr *MExpr) {
865 forwardInfo(MExpr->getBase(), MExpr);
869 void ConsumedStmtVisitor::VisitParmVarDecl(const ParmVarDecl *Param) {
870 QualType ParamType = Param->getType();
871 ConsumedState ParamState = consumed::CS_None;
873 if (const ParamTypestateAttr *PTA = Param->getAttr<ParamTypestateAttr>())
874 ParamState = mapParamTypestateAttrState(PTA);
875 else if (isConsumableType(ParamType))
876 ParamState = mapConsumableAttrState(ParamType);
877 else if (isRValueRef(ParamType) &&
878 isConsumableType(ParamType->getPointeeType()))
879 ParamState = mapConsumableAttrState(ParamType->getPointeeType());
880 else if (ParamType->isReferenceType() &&
881 isConsumableType(ParamType->getPointeeType()))
882 ParamState = consumed::CS_Unknown;
884 if (ParamState != CS_None)
885 StateMap->setState(Param, ParamState);
888 void ConsumedStmtVisitor::VisitReturnStmt(const ReturnStmt *Ret) {
889 ConsumedState ExpectedState = Analyzer.getExpectedReturnState();
891 if (ExpectedState != CS_None) {
892 InfoEntry Entry = findInfo(Ret->getRetValue());
894 if (Entry != PropagationMap.end()) {
895 ConsumedState RetState = Entry->second.getAsState(StateMap);
897 if (RetState != ExpectedState)
898 Analyzer.WarningsHandler.warnReturnTypestateMismatch(
899 Ret->getReturnLoc(), stateToString(ExpectedState),
900 stateToString(RetState));
904 StateMap->checkParamsForReturnTypestate(Ret->getLocStart(),
905 Analyzer.WarningsHandler);
908 void ConsumedStmtVisitor::VisitUnaryOperator(const UnaryOperator *UOp) {
909 InfoEntry Entry = findInfo(UOp->getSubExpr());
910 if (Entry == PropagationMap.end()) return;
912 switch (UOp->getOpcode()) {
914 PropagationMap.insert(PairType(UOp, Entry->second));
918 if (Entry->second.isTest())
919 PropagationMap.insert(PairType(UOp, Entry->second.invertTest()));
927 // TODO: See if I need to check for reference types here.
928 void ConsumedStmtVisitor::VisitVarDecl(const VarDecl *Var) {
929 if (isConsumableType(Var->getType())) {
930 if (Var->hasInit()) {
931 MapType::iterator VIT = findInfo(Var->getInit()->IgnoreImplicit());
932 if (VIT != PropagationMap.end()) {
933 PropagationInfo PInfo = VIT->second;
934 ConsumedState St = PInfo.getAsState(StateMap);
936 if (St != consumed::CS_None) {
937 StateMap->setState(Var, St);
943 StateMap->setState(Var, consumed::CS_Unknown);
946 }} // end clang::consumed::ConsumedStmtVisitor
951 static void splitVarStateForIf(const IfStmt *IfNode, const VarTestResult &Test,
952 ConsumedStateMap *ThenStates,
953 ConsumedStateMap *ElseStates) {
954 ConsumedState VarState = ThenStates->getState(Test.Var);
956 if (VarState == CS_Unknown) {
957 ThenStates->setState(Test.Var, Test.TestsFor);
958 ElseStates->setState(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
960 } else if (VarState == invertConsumedUnconsumed(Test.TestsFor)) {
961 ThenStates->markUnreachable();
963 } else if (VarState == Test.TestsFor) {
964 ElseStates->markUnreachable();
968 static void splitVarStateForIfBinOp(const PropagationInfo &PInfo,
969 ConsumedStateMap *ThenStates,
970 ConsumedStateMap *ElseStates) {
971 const VarTestResult <est = PInfo.getLTest(),
972 &RTest = PInfo.getRTest();
974 ConsumedState LState = LTest.Var ? ThenStates->getState(LTest.Var) : CS_None,
975 RState = RTest.Var ? ThenStates->getState(RTest.Var) : CS_None;
978 if (PInfo.testEffectiveOp() == EO_And) {
979 if (LState == CS_Unknown) {
980 ThenStates->setState(LTest.Var, LTest.TestsFor);
982 } else if (LState == invertConsumedUnconsumed(LTest.TestsFor)) {
983 ThenStates->markUnreachable();
985 } else if (LState == LTest.TestsFor && isKnownState(RState)) {
986 if (RState == RTest.TestsFor)
987 ElseStates->markUnreachable();
989 ThenStates->markUnreachable();
993 if (LState == CS_Unknown) {
994 ElseStates->setState(LTest.Var,
995 invertConsumedUnconsumed(LTest.TestsFor));
997 } else if (LState == LTest.TestsFor) {
998 ElseStates->markUnreachable();
1000 } else if (LState == invertConsumedUnconsumed(LTest.TestsFor) &&
1001 isKnownState(RState)) {
1003 if (RState == RTest.TestsFor)
1004 ElseStates->markUnreachable();
1006 ThenStates->markUnreachable();
1012 if (PInfo.testEffectiveOp() == EO_And) {
1013 if (RState == CS_Unknown)
1014 ThenStates->setState(RTest.Var, RTest.TestsFor);
1015 else if (RState == invertConsumedUnconsumed(RTest.TestsFor))
1016 ThenStates->markUnreachable();
1019 if (RState == CS_Unknown)
1020 ElseStates->setState(RTest.Var,
1021 invertConsumedUnconsumed(RTest.TestsFor));
1022 else if (RState == RTest.TestsFor)
1023 ElseStates->markUnreachable();
1028 bool ConsumedBlockInfo::allBackEdgesVisited(const CFGBlock *CurrBlock,
1029 const CFGBlock *TargetBlock) {
1031 assert(CurrBlock && "Block pointer must not be NULL");
1032 assert(TargetBlock && "TargetBlock pointer must not be NULL");
1034 unsigned int CurrBlockOrder = VisitOrder[CurrBlock->getBlockID()];
1035 for (CFGBlock::const_pred_iterator PI = TargetBlock->pred_begin(),
1036 PE = TargetBlock->pred_end(); PI != PE; ++PI) {
1037 if (*PI && CurrBlockOrder < VisitOrder[(*PI)->getBlockID()] )
1043 void ConsumedBlockInfo::addInfo(
1044 const CFGBlock *Block, ConsumedStateMap *StateMap,
1045 std::unique_ptr<ConsumedStateMap> &OwnedStateMap) {
1047 assert(Block && "Block pointer must not be NULL");
1049 auto &Entry = StateMapsArray[Block->getBlockID()];
1052 Entry->intersect(*StateMap);
1053 } else if (OwnedStateMap)
1054 Entry = std::move(OwnedStateMap);
1056 Entry = llvm::make_unique<ConsumedStateMap>(*StateMap);
1059 void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
1060 std::unique_ptr<ConsumedStateMap> StateMap) {
1062 assert(Block && "Block pointer must not be NULL");
1064 auto &Entry = StateMapsArray[Block->getBlockID()];
1067 Entry->intersect(*StateMap);
1069 Entry = std::move(StateMap);
1073 ConsumedStateMap* ConsumedBlockInfo::borrowInfo(const CFGBlock *Block) {
1074 assert(Block && "Block pointer must not be NULL");
1075 assert(StateMapsArray[Block->getBlockID()] && "Block has no block info");
1077 return StateMapsArray[Block->getBlockID()].get();
1080 void ConsumedBlockInfo::discardInfo(const CFGBlock *Block) {
1081 StateMapsArray[Block->getBlockID()] = nullptr;
1084 std::unique_ptr<ConsumedStateMap>
1085 ConsumedBlockInfo::getInfo(const CFGBlock *Block) {
1086 assert(Block && "Block pointer must not be NULL");
1088 auto &Entry = StateMapsArray[Block->getBlockID()];
1089 return isBackEdgeTarget(Block) ? llvm::make_unique<ConsumedStateMap>(*Entry)
1093 bool ConsumedBlockInfo::isBackEdge(const CFGBlock *From, const CFGBlock *To) {
1094 assert(From && "From block must not be NULL");
1095 assert(To && "From block must not be NULL");
1097 return VisitOrder[From->getBlockID()] > VisitOrder[To->getBlockID()];
1100 bool ConsumedBlockInfo::isBackEdgeTarget(const CFGBlock *Block) {
1101 assert(Block && "Block pointer must not be NULL");
1103 // Anything with less than two predecessors can't be the target of a back
1105 if (Block->pred_size() < 2)
1108 unsigned int BlockVisitOrder = VisitOrder[Block->getBlockID()];
1109 for (CFGBlock::const_pred_iterator PI = Block->pred_begin(),
1110 PE = Block->pred_end(); PI != PE; ++PI) {
1111 if (*PI && BlockVisitOrder < VisitOrder[(*PI)->getBlockID()])
1117 void ConsumedStateMap::checkParamsForReturnTypestate(SourceLocation BlameLoc,
1118 ConsumedWarningsHandlerBase &WarningsHandler) const {
1120 for (const auto &DM : VarMap) {
1121 if (isa<ParmVarDecl>(DM.first)) {
1122 const ParmVarDecl *Param = cast<ParmVarDecl>(DM.first);
1123 const ReturnTypestateAttr *RTA = Param->getAttr<ReturnTypestateAttr>();
1128 ConsumedState ExpectedState = mapReturnTypestateAttrState(RTA);
1129 if (DM.second != ExpectedState)
1130 WarningsHandler.warnParamReturnTypestateMismatch(BlameLoc,
1131 Param->getNameAsString(), stateToString(ExpectedState),
1132 stateToString(DM.second));
1137 void ConsumedStateMap::clearTemporaries() {
1141 ConsumedState ConsumedStateMap::getState(const VarDecl *Var) const {
1142 VarMapType::const_iterator Entry = VarMap.find(Var);
1144 if (Entry != VarMap.end())
1145 return Entry->second;
1151 ConsumedStateMap::getState(const CXXBindTemporaryExpr *Tmp) const {
1152 TmpMapType::const_iterator Entry = TmpMap.find(Tmp);
1154 if (Entry != TmpMap.end())
1155 return Entry->second;
1160 void ConsumedStateMap::intersect(const ConsumedStateMap &Other) {
1161 ConsumedState LocalState;
1163 if (this->From && this->From == Other.From && !Other.Reachable) {
1164 this->markUnreachable();
1168 for (const auto &DM : Other.VarMap) {
1169 LocalState = this->getState(DM.first);
1171 if (LocalState == CS_None)
1174 if (LocalState != DM.second)
1175 VarMap[DM.first] = CS_Unknown;
1179 void ConsumedStateMap::intersectAtLoopHead(const CFGBlock *LoopHead,
1180 const CFGBlock *LoopBack, const ConsumedStateMap *LoopBackStates,
1181 ConsumedWarningsHandlerBase &WarningsHandler) {
1183 ConsumedState LocalState;
1184 SourceLocation BlameLoc = getLastStmtLoc(LoopBack);
1186 for (const auto &DM : LoopBackStates->VarMap) {
1187 LocalState = this->getState(DM.first);
1189 if (LocalState == CS_None)
1192 if (LocalState != DM.second) {
1193 VarMap[DM.first] = CS_Unknown;
1194 WarningsHandler.warnLoopStateMismatch(BlameLoc,
1195 DM.first->getNameAsString());
1200 void ConsumedStateMap::markUnreachable() {
1201 this->Reachable = false;
1206 void ConsumedStateMap::setState(const VarDecl *Var, ConsumedState State) {
1207 VarMap[Var] = State;
1210 void ConsumedStateMap::setState(const CXXBindTemporaryExpr *Tmp,
1211 ConsumedState State) {
1212 TmpMap[Tmp] = State;
1215 void ConsumedStateMap::remove(const CXXBindTemporaryExpr *Tmp) {
1219 bool ConsumedStateMap::operator!=(const ConsumedStateMap *Other) const {
1220 for (const auto &DM : Other->VarMap)
1221 if (this->getState(DM.first) != DM.second)
1226 void ConsumedAnalyzer::determineExpectedReturnState(AnalysisDeclContext &AC,
1227 const FunctionDecl *D) {
1228 QualType ReturnType;
1229 if (const CXXConstructorDecl *Constructor = dyn_cast<CXXConstructorDecl>(D)) {
1230 ASTContext &CurrContext = AC.getASTContext();
1231 ReturnType = Constructor->getThisType(CurrContext)->getPointeeType();
1233 ReturnType = D->getCallResultType();
1235 if (const ReturnTypestateAttr *RTSAttr = D->getAttr<ReturnTypestateAttr>()) {
1236 const CXXRecordDecl *RD = ReturnType->getAsCXXRecordDecl();
1237 if (!RD || !RD->hasAttr<ConsumableAttr>()) {
1238 // FIXME: This should be removed when template instantiation propagates
1239 // attributes at template specialization definition, not
1240 // declaration. When it is removed the test needs to be enabled
1241 // in SemaDeclAttr.cpp.
1242 WarningsHandler.warnReturnTypestateForUnconsumableType(
1243 RTSAttr->getLocation(), ReturnType.getAsString());
1244 ExpectedReturnState = CS_None;
1246 ExpectedReturnState = mapReturnTypestateAttrState(RTSAttr);
1247 } else if (isConsumableType(ReturnType)) {
1248 if (isAutoCastType(ReturnType)) // We can auto-cast the state to the
1249 ExpectedReturnState = CS_None; // expected state.
1251 ExpectedReturnState = mapConsumableAttrState(ReturnType);
1254 ExpectedReturnState = CS_None;
1257 bool ConsumedAnalyzer::splitState(const CFGBlock *CurrBlock,
1258 const ConsumedStmtVisitor &Visitor) {
1260 std::unique_ptr<ConsumedStateMap> FalseStates(
1261 new ConsumedStateMap(*CurrStates));
1262 PropagationInfo PInfo;
1264 if (const IfStmt *IfNode =
1265 dyn_cast_or_null<IfStmt>(CurrBlock->getTerminator().getStmt())) {
1267 const Expr *Cond = IfNode->getCond();
1269 PInfo = Visitor.getInfo(Cond);
1270 if (!PInfo.isValid() && isa<BinaryOperator>(Cond))
1271 PInfo = Visitor.getInfo(cast<BinaryOperator>(Cond)->getRHS());
1273 if (PInfo.isVarTest()) {
1274 CurrStates->setSource(Cond);
1275 FalseStates->setSource(Cond);
1276 splitVarStateForIf(IfNode, PInfo.getVarTest(), CurrStates.get(),
1279 } else if (PInfo.isBinTest()) {
1280 CurrStates->setSource(PInfo.testSourceNode());
1281 FalseStates->setSource(PInfo.testSourceNode());
1282 splitVarStateForIfBinOp(PInfo, CurrStates.get(), FalseStates.get());
1288 } else if (const BinaryOperator *BinOp =
1289 dyn_cast_or_null<BinaryOperator>(CurrBlock->getTerminator().getStmt())) {
1291 PInfo = Visitor.getInfo(BinOp->getLHS());
1292 if (!PInfo.isVarTest()) {
1293 if ((BinOp = dyn_cast_or_null<BinaryOperator>(BinOp->getLHS()))) {
1294 PInfo = Visitor.getInfo(BinOp->getRHS());
1296 if (!PInfo.isVarTest())
1304 CurrStates->setSource(BinOp);
1305 FalseStates->setSource(BinOp);
1307 const VarTestResult &Test = PInfo.getVarTest();
1308 ConsumedState VarState = CurrStates->getState(Test.Var);
1310 if (BinOp->getOpcode() == BO_LAnd) {
1311 if (VarState == CS_Unknown)
1312 CurrStates->setState(Test.Var, Test.TestsFor);
1313 else if (VarState == invertConsumedUnconsumed(Test.TestsFor))
1314 CurrStates->markUnreachable();
1316 } else if (BinOp->getOpcode() == BO_LOr) {
1317 if (VarState == CS_Unknown)
1318 FalseStates->setState(Test.Var,
1319 invertConsumedUnconsumed(Test.TestsFor));
1320 else if (VarState == Test.TestsFor)
1321 FalseStates->markUnreachable();
1328 CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin();
1331 BlockInfo.addInfo(*SI, std::move(CurrStates));
1333 CurrStates = nullptr;
1336 BlockInfo.addInfo(*SI, std::move(FalseStates));
1341 void ConsumedAnalyzer::run(AnalysisDeclContext &AC) {
1342 const FunctionDecl *D = dyn_cast_or_null<FunctionDecl>(AC.getDecl());
1346 CFG *CFGraph = AC.getCFG();
1350 determineExpectedReturnState(AC, D);
1352 PostOrderCFGView *SortedGraph = AC.getAnalysis<PostOrderCFGView>();
1353 // AC.getCFG()->viewCFG(LangOptions());
1355 BlockInfo = ConsumedBlockInfo(CFGraph->getNumBlockIDs(), SortedGraph);
1357 CurrStates = llvm::make_unique<ConsumedStateMap>();
1358 ConsumedStmtVisitor Visitor(AC, *this, CurrStates.get());
1360 // Add all trackable parameters to the state map.
1361 for (const auto *PI : D->parameters())
1362 Visitor.VisitParmVarDecl(PI);
1364 // Visit all of the function's basic blocks.
1365 for (const auto *CurrBlock : *SortedGraph) {
1367 CurrStates = BlockInfo.getInfo(CurrBlock);
1372 } else if (!CurrStates->isReachable()) {
1373 CurrStates = nullptr;
1377 Visitor.reset(CurrStates.get());
1379 // Visit all of the basic block's statements.
1380 for (const auto &B : *CurrBlock) {
1381 switch (B.getKind()) {
1382 case CFGElement::Statement:
1383 Visitor.Visit(B.castAs<CFGStmt>().getStmt());
1386 case CFGElement::TemporaryDtor: {
1387 const CFGTemporaryDtor &DTor = B.castAs<CFGTemporaryDtor>();
1388 const CXXBindTemporaryExpr *BTE = DTor.getBindTemporaryExpr();
1390 Visitor.checkCallability(PropagationInfo(BTE),
1391 DTor.getDestructorDecl(AC.getASTContext()),
1393 CurrStates->remove(BTE);
1397 case CFGElement::AutomaticObjectDtor: {
1398 const CFGAutomaticObjDtor &DTor = B.castAs<CFGAutomaticObjDtor>();
1399 SourceLocation Loc = DTor.getTriggerStmt()->getLocEnd();
1400 const VarDecl *Var = DTor.getVarDecl();
1402 Visitor.checkCallability(PropagationInfo(Var),
1403 DTor.getDestructorDecl(AC.getASTContext()),
1413 // TODO: Handle other forms of branching with precision, including while-
1414 // and for-loops. (Deferred)
1415 if (!splitState(CurrBlock, Visitor)) {
1416 CurrStates->setSource(nullptr);
1418 if (CurrBlock->succ_size() > 1 ||
1419 (CurrBlock->succ_size() == 1 &&
1420 (*CurrBlock->succ_begin())->pred_size() > 1)) {
1422 auto *RawState = CurrStates.get();
1424 for (CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin(),
1425 SE = CurrBlock->succ_end(); SI != SE; ++SI) {
1427 if (*SI == nullptr) continue;
1429 if (BlockInfo.isBackEdge(CurrBlock, *SI)) {
1430 BlockInfo.borrowInfo(*SI)->intersectAtLoopHead(
1431 *SI, CurrBlock, RawState, WarningsHandler);
1433 if (BlockInfo.allBackEdgesVisited(CurrBlock, *SI))
1434 BlockInfo.discardInfo(*SI);
1436 BlockInfo.addInfo(*SI, RawState, CurrStates);
1440 CurrStates = nullptr;
1444 if (CurrBlock == &AC.getCFG()->getExit() &&
1445 D->getCallResultType()->isVoidType())
1446 CurrStates->checkParamsForReturnTypestate(D->getLocation(),
1448 } // End of block iterator.
1450 // Delete the last existing state map.
1451 CurrStates = nullptr;
1453 WarningsHandler.emitDiagnostics();
1455 }} // end namespace clang::consumed