1 //===- Consumed.cpp --------------------------------------------*- C++ --*-===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
// An intra-procedural analysis for checking consumed properties. This is based,
11 // in part, on research on linear types.
13 //===----------------------------------------------------------------------===//
15 #include "clang/Analysis/Analyses/Consumed.h"
16 #include "clang/AST/ASTContext.h"
17 #include "clang/AST/Attr.h"
18 #include "clang/AST/DeclCXX.h"
19 #include "clang/AST/ExprCXX.h"
20 #include "clang/AST/RecursiveASTVisitor.h"
21 #include "clang/AST/StmtCXX.h"
22 #include "clang/AST/StmtVisitor.h"
23 #include "clang/AST/Type.h"
24 #include "clang/Analysis/Analyses/PostOrderCFGView.h"
25 #include "clang/Analysis/AnalysisContext.h"
26 #include "clang/Analysis/CFG.h"
27 #include "clang/Basic/OperatorKinds.h"
28 #include "clang/Basic/SourceLocation.h"
29 #include "llvm/ADT/DenseMap.h"
32 // TODO: Adjust states of args to constructors in the same way that arguments to
33 // function calls are handled.
34 // TODO: Use information from tests in for- and while-loop conditional.
35 // TODO: Add notes about the actual and expected state for
36 // TODO: Correctly identify unreachable blocks when chaining boolean operators.
37 // TODO: Adjust the parser and AttributesList class to support lists of
39 // TODO: Warn about unreachable code.
40 // TODO: Switch to using a bitmap to track unreachable blocks.
41 // TODO: Handle variable definitions, e.g. bool valid = x.isValid();
42 // if (valid) ...; (Deferred)
43 // TODO: Take notes on state transitions to provide better warning messages.
// TODO: Test nested conditionals: 1) Checking the same value multiple times,
//       and 2) Checking different values. (Deferred)
48 using namespace clang;
49 using namespace consumed;
51 // Key method definition
52 ConsumedWarningsHandlerBase::~ConsumedWarningsHandlerBase() {}
54 static SourceLocation getFirstStmtLoc(const CFGBlock *Block) {
55 // Find the source location of the first statement in the block, if the block
57 for (const auto &B : *Block)
58 if (Optional<CFGStmt> CS = B.getAs<CFGStmt>())
59 return CS->getStmt()->getLocStart();
62 // If we have one successor, return the first statement in that block
63 if (Block->succ_size() == 1 && *Block->succ_begin())
64 return getFirstStmtLoc(*Block->succ_begin());
66 return SourceLocation();
69 static SourceLocation getLastStmtLoc(const CFGBlock *Block) {
70 // Find the source location of the last statement in the block, if the block
72 if (const Stmt *StmtNode = Block->getTerminator()) {
73 return StmtNode->getLocStart();
75 for (CFGBlock::const_reverse_iterator BI = Block->rbegin(),
76 BE = Block->rend(); BI != BE; ++BI) {
77 if (Optional<CFGStmt> CS = BI->getAs<CFGStmt>())
78 return CS->getStmt()->getLocStart();
82 // If we have one successor, return the first statement in that block
84 if (Block->succ_size() == 1 && *Block->succ_begin())
85 Loc = getFirstStmtLoc(*Block->succ_begin());
89 // If we have one predecessor, return the last statement in that block
90 if (Block->pred_size() == 1 && *Block->pred_begin())
91 return getLastStmtLoc(*Block->pred_begin());
96 static ConsumedState invertConsumedUnconsumed(ConsumedState State) {
101 return CS_Unconsumed;
107 llvm_unreachable("invalid enum");
110 static bool isCallableInState(const CallableWhenAttr *CWAttr,
111 ConsumedState State) {
113 for (const auto &S : CWAttr->callableStates()) {
114 ConsumedState MappedAttrState = CS_None;
117 case CallableWhenAttr::Unknown:
118 MappedAttrState = CS_Unknown;
121 case CallableWhenAttr::Unconsumed:
122 MappedAttrState = CS_Unconsumed;
125 case CallableWhenAttr::Consumed:
126 MappedAttrState = CS_Consumed;
130 if (MappedAttrState == State)
138 static bool isConsumableType(const QualType &QT) {
139 if (QT->isPointerType() || QT->isReferenceType())
142 if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
143 return RD->hasAttr<ConsumableAttr>();
148 static bool isAutoCastType(const QualType &QT) {
149 if (QT->isPointerType() || QT->isReferenceType())
152 if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
153 return RD->hasAttr<ConsumableAutoCastAttr>();
158 static bool isSetOnReadPtrType(const QualType &QT) {
159 if (const CXXRecordDecl *RD = QT->getPointeeCXXRecordDecl())
160 return RD->hasAttr<ConsumableSetOnReadAttr>();
165 static bool isKnownState(ConsumedState State) {
174 llvm_unreachable("invalid enum");
177 static bool isRValueRef(QualType ParamType) {
178 return ParamType->isRValueReferenceType();
181 static bool isTestingFunction(const FunctionDecl *FunDecl) {
182 return FunDecl->hasAttr<TestTypestateAttr>();
185 static bool isPointerOrRef(QualType ParamType) {
186 return ParamType->isPointerType() || ParamType->isReferenceType();
189 static ConsumedState mapConsumableAttrState(const QualType QT) {
190 assert(isConsumableType(QT));
192 const ConsumableAttr *CAttr =
193 QT->getAsCXXRecordDecl()->getAttr<ConsumableAttr>();
195 switch (CAttr->getDefaultState()) {
196 case ConsumableAttr::Unknown:
198 case ConsumableAttr::Unconsumed:
199 return CS_Unconsumed;
200 case ConsumableAttr::Consumed:
203 llvm_unreachable("invalid enum");
207 mapParamTypestateAttrState(const ParamTypestateAttr *PTAttr) {
208 switch (PTAttr->getParamState()) {
209 case ParamTypestateAttr::Unknown:
211 case ParamTypestateAttr::Unconsumed:
212 return CS_Unconsumed;
213 case ParamTypestateAttr::Consumed:
216 llvm_unreachable("invalid_enum");
220 mapReturnTypestateAttrState(const ReturnTypestateAttr *RTSAttr) {
221 switch (RTSAttr->getState()) {
222 case ReturnTypestateAttr::Unknown:
224 case ReturnTypestateAttr::Unconsumed:
225 return CS_Unconsumed;
226 case ReturnTypestateAttr::Consumed:
229 llvm_unreachable("invalid enum");
232 static ConsumedState mapSetTypestateAttrState(const SetTypestateAttr *STAttr) {
233 switch (STAttr->getNewState()) {
234 case SetTypestateAttr::Unknown:
236 case SetTypestateAttr::Unconsumed:
237 return CS_Unconsumed;
238 case SetTypestateAttr::Consumed:
241 llvm_unreachable("invalid_enum");
244 static StringRef stateToString(ConsumedState State) {
246 case consumed::CS_None:
249 case consumed::CS_Unknown:
252 case consumed::CS_Unconsumed:
255 case consumed::CS_Consumed:
258 llvm_unreachable("invalid enum");
261 static ConsumedState testsFor(const FunctionDecl *FunDecl) {
262 assert(isTestingFunction(FunDecl));
263 switch (FunDecl->getAttr<TestTypestateAttr>()->getTestState()) {
264 case TestTypestateAttr::Unconsumed:
265 return CS_Unconsumed;
266 case TestTypestateAttr::Consumed:
269 llvm_unreachable("invalid enum");
273 struct VarTestResult {
275 ConsumedState TestsFor;
277 } // end anonymous::VarTestResult
287 class PropagationInfo {
298 const BinaryOperator *Source;
306 VarTestResult VarTest;
308 const CXXBindTemporaryExpr *Tmp;
313 PropagationInfo() : InfoType(IT_None) {}
315 PropagationInfo(const VarTestResult &VarTest)
316 : InfoType(IT_VarTest), VarTest(VarTest) {}
318 PropagationInfo(const VarDecl *Var, ConsumedState TestsFor)
319 : InfoType(IT_VarTest) {
322 VarTest.TestsFor = TestsFor;
325 PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
326 const VarTestResult <est, const VarTestResult &RTest)
327 : InfoType(IT_BinTest) {
329 BinTest.Source = Source;
331 BinTest.LTest = LTest;
332 BinTest.RTest = RTest;
335 PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
336 const VarDecl *LVar, ConsumedState LTestsFor,
337 const VarDecl *RVar, ConsumedState RTestsFor)
338 : InfoType(IT_BinTest) {
340 BinTest.Source = Source;
342 BinTest.LTest.Var = LVar;
343 BinTest.LTest.TestsFor = LTestsFor;
344 BinTest.RTest.Var = RVar;
345 BinTest.RTest.TestsFor = RTestsFor;
348 PropagationInfo(ConsumedState State)
349 : InfoType(IT_State), State(State) {}
351 PropagationInfo(const VarDecl *Var) : InfoType(IT_Var), Var(Var) {}
352 PropagationInfo(const CXXBindTemporaryExpr *Tmp)
353 : InfoType(IT_Tmp), Tmp(Tmp) {}
355 const ConsumedState & getState() const {
356 assert(InfoType == IT_State);
360 const VarTestResult & getVarTest() const {
361 assert(InfoType == IT_VarTest);
365 const VarTestResult & getLTest() const {
366 assert(InfoType == IT_BinTest);
367 return BinTest.LTest;
370 const VarTestResult & getRTest() const {
371 assert(InfoType == IT_BinTest);
372 return BinTest.RTest;
375 const VarDecl * getVar() const {
376 assert(InfoType == IT_Var);
380 const CXXBindTemporaryExpr * getTmp() const {
381 assert(InfoType == IT_Tmp);
385 ConsumedState getAsState(const ConsumedStateMap *StateMap) const {
386 assert(isVar() || isTmp() || isState());
389 return StateMap->getState(Var);
391 return StateMap->getState(Tmp);
398 EffectiveOp testEffectiveOp() const {
399 assert(InfoType == IT_BinTest);
403 const BinaryOperator * testSourceNode() const {
404 assert(InfoType == IT_BinTest);
405 return BinTest.Source;
408 inline bool isValid() const { return InfoType != IT_None; }
409 inline bool isState() const { return InfoType == IT_State; }
410 inline bool isVarTest() const { return InfoType == IT_VarTest; }
411 inline bool isBinTest() const { return InfoType == IT_BinTest; }
412 inline bool isVar() const { return InfoType == IT_Var; }
413 inline bool isTmp() const { return InfoType == IT_Tmp; }
415 bool isTest() const {
416 return InfoType == IT_VarTest || InfoType == IT_BinTest;
419 bool isPointerToValue() const {
420 return InfoType == IT_Var || InfoType == IT_Tmp;
423 PropagationInfo invertTest() const {
424 assert(InfoType == IT_VarTest || InfoType == IT_BinTest);
426 if (InfoType == IT_VarTest) {
427 return PropagationInfo(VarTest.Var,
428 invertConsumedUnconsumed(VarTest.TestsFor));
430 } else if (InfoType == IT_BinTest) {
431 return PropagationInfo(BinTest.Source,
432 BinTest.EOp == EO_And ? EO_Or : EO_And,
433 BinTest.LTest.Var, invertConsumedUnconsumed(BinTest.LTest.TestsFor),
434 BinTest.RTest.Var, invertConsumedUnconsumed(BinTest.RTest.TestsFor));
436 return PropagationInfo();
442 setStateForVarOrTmp(ConsumedStateMap *StateMap, const PropagationInfo &PInfo,
443 ConsumedState State) {
445 assert(PInfo.isVar() || PInfo.isTmp());
448 StateMap->setState(PInfo.getVar(), State);
450 StateMap->setState(PInfo.getTmp(), State);
453 class ConsumedStmtVisitor : public ConstStmtVisitor<ConsumedStmtVisitor> {
455 typedef llvm::DenseMap<const Stmt *, PropagationInfo> MapType;
456 typedef std::pair<const Stmt *, PropagationInfo> PairType;
457 typedef MapType::iterator InfoEntry;
458 typedef MapType::const_iterator ConstInfoEntry;
460 AnalysisDeclContext &AC;
461 ConsumedAnalyzer &Analyzer;
462 ConsumedStateMap *StateMap;
463 MapType PropagationMap;
465 InfoEntry findInfo(const Expr *E) {
466 if (auto Cleanups = dyn_cast<ExprWithCleanups>(E))
467 if (!Cleanups->cleanupsHaveSideEffects())
468 E = Cleanups->getSubExpr();
469 return PropagationMap.find(E->IgnoreParens());
471 ConstInfoEntry findInfo(const Expr *E) const {
472 if (auto Cleanups = dyn_cast<ExprWithCleanups>(E))
473 if (!Cleanups->cleanupsHaveSideEffects())
474 E = Cleanups->getSubExpr();
475 return PropagationMap.find(E->IgnoreParens());
477 void insertInfo(const Expr *E, const PropagationInfo &PI) {
478 PropagationMap.insert(PairType(E->IgnoreParens(), PI));
481 void forwardInfo(const Expr *From, const Expr *To);
482 void copyInfo(const Expr *From, const Expr *To, ConsumedState CS);
483 ConsumedState getInfo(const Expr *From);
484 void setInfo(const Expr *To, ConsumedState NS);
485 void propagateReturnType(const Expr *Call, const FunctionDecl *Fun);
488 void checkCallability(const PropagationInfo &PInfo,
489 const FunctionDecl *FunDecl,
490 SourceLocation BlameLoc);
491 bool handleCall(const CallExpr *Call, const Expr *ObjArg,
492 const FunctionDecl *FunD);
494 void VisitBinaryOperator(const BinaryOperator *BinOp);
495 void VisitCallExpr(const CallExpr *Call);
496 void VisitCastExpr(const CastExpr *Cast);
497 void VisitCXXBindTemporaryExpr(const CXXBindTemporaryExpr *Temp);
498 void VisitCXXConstructExpr(const CXXConstructExpr *Call);
499 void VisitCXXMemberCallExpr(const CXXMemberCallExpr *Call);
500 void VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *Call);
501 void VisitDeclRefExpr(const DeclRefExpr *DeclRef);
502 void VisitDeclStmt(const DeclStmt *DelcS);
503 void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *Temp);
504 void VisitMemberExpr(const MemberExpr *MExpr);
505 void VisitParmVarDecl(const ParmVarDecl *Param);
506 void VisitReturnStmt(const ReturnStmt *Ret);
507 void VisitUnaryOperator(const UnaryOperator *UOp);
508 void VisitVarDecl(const VarDecl *Var);
510 ConsumedStmtVisitor(AnalysisDeclContext &AC, ConsumedAnalyzer &Analyzer,
511 ConsumedStateMap *StateMap)
512 : AC(AC), Analyzer(Analyzer), StateMap(StateMap) {}
514 PropagationInfo getInfo(const Expr *StmtNode) const {
515 ConstInfoEntry Entry = findInfo(StmtNode);
517 if (Entry != PropagationMap.end())
518 return Entry->second;
520 return PropagationInfo();
523 void reset(ConsumedStateMap *NewStateMap) {
524 StateMap = NewStateMap;
529 void ConsumedStmtVisitor::forwardInfo(const Expr *From, const Expr *To) {
530 InfoEntry Entry = findInfo(From);
531 if (Entry != PropagationMap.end())
532 insertInfo(To, Entry->second);
536 // Create a new state for To, which is initialized to the state of From.
537 // If NS is not CS_None, sets the state of From to NS.
538 void ConsumedStmtVisitor::copyInfo(const Expr *From, const Expr *To,
540 InfoEntry Entry = findInfo(From);
541 if (Entry != PropagationMap.end()) {
542 PropagationInfo& PInfo = Entry->second;
543 ConsumedState CS = PInfo.getAsState(StateMap);
545 insertInfo(To, PropagationInfo(CS));
546 if (NS != CS_None && PInfo.isPointerToValue())
547 setStateForVarOrTmp(StateMap, PInfo, NS);
552 // Get the ConsumedState for From
553 ConsumedState ConsumedStmtVisitor::getInfo(const Expr *From) {
554 InfoEntry Entry = findInfo(From);
555 if (Entry != PropagationMap.end()) {
556 PropagationInfo& PInfo = Entry->second;
557 return PInfo.getAsState(StateMap);
563 // If we already have info for To then update it, otherwise create a new entry.
564 void ConsumedStmtVisitor::setInfo(const Expr *To, ConsumedState NS) {
565 InfoEntry Entry = findInfo(To);
566 if (Entry != PropagationMap.end()) {
567 PropagationInfo& PInfo = Entry->second;
568 if (PInfo.isPointerToValue())
569 setStateForVarOrTmp(StateMap, PInfo, NS);
570 } else if (NS != CS_None) {
571 insertInfo(To, PropagationInfo(NS));
577 void ConsumedStmtVisitor::checkCallability(const PropagationInfo &PInfo,
578 const FunctionDecl *FunDecl,
579 SourceLocation BlameLoc) {
580 assert(!PInfo.isTest());
582 const CallableWhenAttr *CWAttr = FunDecl->getAttr<CallableWhenAttr>();
587 ConsumedState VarState = StateMap->getState(PInfo.getVar());
589 if (VarState == CS_None || isCallableInState(CWAttr, VarState))
592 Analyzer.WarningsHandler.warnUseInInvalidState(
593 FunDecl->getNameAsString(), PInfo.getVar()->getNameAsString(),
594 stateToString(VarState), BlameLoc);
597 ConsumedState TmpState = PInfo.getAsState(StateMap);
599 if (TmpState == CS_None || isCallableInState(CWAttr, TmpState))
602 Analyzer.WarningsHandler.warnUseOfTempInInvalidState(
603 FunDecl->getNameAsString(), stateToString(TmpState), BlameLoc);
608 // Factors out common behavior for function, method, and operator calls.
609 // Check parameters and set parameter state if necessary.
610 // Returns true if the state of ObjArg is set, or false otherwise.
611 bool ConsumedStmtVisitor::handleCall(const CallExpr *Call, const Expr *ObjArg,
612 const FunctionDecl *FunD) {
614 if (isa<CXXOperatorCallExpr>(Call) && isa<CXXMethodDecl>(FunD))
615 Offset = 1; // first argument is 'this'
617 // check explicit parameters
618 for (unsigned Index = Offset; Index < Call->getNumArgs(); ++Index) {
619 // Skip variable argument lists.
620 if (Index - Offset >= FunD->getNumParams())
623 const ParmVarDecl *Param = FunD->getParamDecl(Index - Offset);
624 QualType ParamType = Param->getType();
626 InfoEntry Entry = findInfo(Call->getArg(Index));
628 if (Entry == PropagationMap.end() || Entry->second.isTest())
630 PropagationInfo PInfo = Entry->second;
632 // Check that the parameter is in the correct state.
633 if (ParamTypestateAttr *PTA = Param->getAttr<ParamTypestateAttr>()) {
634 ConsumedState ParamState = PInfo.getAsState(StateMap);
635 ConsumedState ExpectedState = mapParamTypestateAttrState(PTA);
637 if (ParamState != ExpectedState)
638 Analyzer.WarningsHandler.warnParamTypestateMismatch(
639 Call->getArg(Index)->getExprLoc(),
640 stateToString(ExpectedState), stateToString(ParamState));
643 if (!(Entry->second.isVar() || Entry->second.isTmp()))
646 // Adjust state on the caller side.
647 if (isRValueRef(ParamType))
648 setStateForVarOrTmp(StateMap, PInfo, consumed::CS_Consumed);
649 else if (ReturnTypestateAttr *RT = Param->getAttr<ReturnTypestateAttr>())
650 setStateForVarOrTmp(StateMap, PInfo, mapReturnTypestateAttrState(RT));
651 else if (isPointerOrRef(ParamType) &&
652 (!ParamType->getPointeeType().isConstQualified() ||
653 isSetOnReadPtrType(ParamType)))
654 setStateForVarOrTmp(StateMap, PInfo, consumed::CS_Unknown);
660 // check implicit 'self' parameter, if present
661 InfoEntry Entry = findInfo(ObjArg);
662 if (Entry != PropagationMap.end()) {
663 PropagationInfo PInfo = Entry->second;
664 checkCallability(PInfo, FunD, Call->getExprLoc());
666 if (SetTypestateAttr *STA = FunD->getAttr<SetTypestateAttr>()) {
668 StateMap->setState(PInfo.getVar(), mapSetTypestateAttrState(STA));
671 else if (PInfo.isTmp()) {
672 StateMap->setState(PInfo.getTmp(), mapSetTypestateAttrState(STA));
676 else if (isTestingFunction(FunD) && PInfo.isVar()) {
677 PropagationMap.insert(PairType(Call,
678 PropagationInfo(PInfo.getVar(), testsFor(FunD))));
685 void ConsumedStmtVisitor::propagateReturnType(const Expr *Call,
686 const FunctionDecl *Fun) {
687 QualType RetType = Fun->getCallResultType();
688 if (RetType->isReferenceType())
689 RetType = RetType->getPointeeType();
691 if (isConsumableType(RetType)) {
692 ConsumedState ReturnState;
693 if (ReturnTypestateAttr *RTA = Fun->getAttr<ReturnTypestateAttr>())
694 ReturnState = mapReturnTypestateAttrState(RTA);
696 ReturnState = mapConsumableAttrState(RetType);
698 PropagationMap.insert(PairType(Call, PropagationInfo(ReturnState)));
703 void ConsumedStmtVisitor::VisitBinaryOperator(const BinaryOperator *BinOp) {
704 switch (BinOp->getOpcode()) {
707 InfoEntry LEntry = findInfo(BinOp->getLHS()),
708 REntry = findInfo(BinOp->getRHS());
710 VarTestResult LTest, RTest;
712 if (LEntry != PropagationMap.end() && LEntry->second.isVarTest()) {
713 LTest = LEntry->second.getVarTest();
717 LTest.TestsFor = CS_None;
720 if (REntry != PropagationMap.end() && REntry->second.isVarTest()) {
721 RTest = REntry->second.getVarTest();
725 RTest.TestsFor = CS_None;
728 if (!(LTest.Var == nullptr && RTest.Var == nullptr))
729 PropagationMap.insert(PairType(BinOp, PropagationInfo(BinOp,
730 static_cast<EffectiveOp>(BinOp->getOpcode() == BO_LOr), LTest, RTest)));
737 forwardInfo(BinOp->getLHS(), BinOp);
745 void ConsumedStmtVisitor::VisitCallExpr(const CallExpr *Call) {
746 const FunctionDecl *FunDecl = Call->getDirectCallee();
750 // Special case for the std::move function.
751 // TODO: Make this more specific. (Deferred)
752 if (Call->getNumArgs() == 1 && FunDecl->getNameAsString() == "move" &&
753 FunDecl->isInStdNamespace()) {
754 copyInfo(Call->getArg(0), Call, CS_Consumed);
758 handleCall(Call, nullptr, FunDecl);
759 propagateReturnType(Call, FunDecl);
762 void ConsumedStmtVisitor::VisitCastExpr(const CastExpr *Cast) {
763 forwardInfo(Cast->getSubExpr(), Cast);
766 void ConsumedStmtVisitor::VisitCXXBindTemporaryExpr(
767 const CXXBindTemporaryExpr *Temp) {
769 InfoEntry Entry = findInfo(Temp->getSubExpr());
771 if (Entry != PropagationMap.end() && !Entry->second.isTest()) {
772 StateMap->setState(Temp, Entry->second.getAsState(StateMap));
773 PropagationMap.insert(PairType(Temp, PropagationInfo(Temp)));
777 void ConsumedStmtVisitor::VisitCXXConstructExpr(const CXXConstructExpr *Call) {
778 CXXConstructorDecl *Constructor = Call->getConstructor();
780 ASTContext &CurrContext = AC.getASTContext();
781 QualType ThisType = Constructor->getThisType(CurrContext)->getPointeeType();
783 if (!isConsumableType(ThisType))
786 // FIXME: What should happen if someone annotates the move constructor?
787 if (ReturnTypestateAttr *RTA = Constructor->getAttr<ReturnTypestateAttr>()) {
788 // TODO: Adjust state of args appropriately.
789 ConsumedState RetState = mapReturnTypestateAttrState(RTA);
790 PropagationMap.insert(PairType(Call, PropagationInfo(RetState)));
791 } else if (Constructor->isDefaultConstructor()) {
792 PropagationMap.insert(PairType(Call,
793 PropagationInfo(consumed::CS_Consumed)));
794 } else if (Constructor->isMoveConstructor()) {
795 copyInfo(Call->getArg(0), Call, CS_Consumed);
796 } else if (Constructor->isCopyConstructor()) {
797 // Copy state from arg. If setStateOnRead then set arg to CS_Unknown.
799 isSetOnReadPtrType(Constructor->getThisType(CurrContext)) ?
800 CS_Unknown : CS_None;
801 copyInfo(Call->getArg(0), Call, NS);
803 // TODO: Adjust state of args appropriately.
804 ConsumedState RetState = mapConsumableAttrState(ThisType);
805 PropagationMap.insert(PairType(Call, PropagationInfo(RetState)));
810 void ConsumedStmtVisitor::VisitCXXMemberCallExpr(
811 const CXXMemberCallExpr *Call) {
812 CXXMethodDecl* MD = Call->getMethodDecl();
816 handleCall(Call, Call->getImplicitObjectArgument(), MD);
817 propagateReturnType(Call, MD);
821 void ConsumedStmtVisitor::VisitCXXOperatorCallExpr(
822 const CXXOperatorCallExpr *Call) {
824 const FunctionDecl *FunDecl =
825 dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee());
826 if (!FunDecl) return;
828 if (Call->getOperator() == OO_Equal) {
829 ConsumedState CS = getInfo(Call->getArg(1));
830 if (!handleCall(Call, Call->getArg(0), FunDecl))
831 setInfo(Call->getArg(0), CS);
835 if (const CXXMemberCallExpr *MCall = dyn_cast<CXXMemberCallExpr>(Call))
836 handleCall(MCall, MCall->getImplicitObjectArgument(), FunDecl);
838 handleCall(Call, Call->getArg(0), FunDecl);
840 propagateReturnType(Call, FunDecl);
843 void ConsumedStmtVisitor::VisitDeclRefExpr(const DeclRefExpr *DeclRef) {
844 if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclRef->getDecl()))
845 if (StateMap->getState(Var) != consumed::CS_None)
846 PropagationMap.insert(PairType(DeclRef, PropagationInfo(Var)));
849 void ConsumedStmtVisitor::VisitDeclStmt(const DeclStmt *DeclS) {
850 for (const auto *DI : DeclS->decls())
851 if (isa<VarDecl>(DI))
852 VisitVarDecl(cast<VarDecl>(DI));
854 if (DeclS->isSingleDecl())
855 if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclS->getSingleDecl()))
856 PropagationMap.insert(PairType(DeclS, PropagationInfo(Var)));
859 void ConsumedStmtVisitor::VisitMaterializeTemporaryExpr(
860 const MaterializeTemporaryExpr *Temp) {
862 forwardInfo(Temp->GetTemporaryExpr(), Temp);
865 void ConsumedStmtVisitor::VisitMemberExpr(const MemberExpr *MExpr) {
866 forwardInfo(MExpr->getBase(), MExpr);
870 void ConsumedStmtVisitor::VisitParmVarDecl(const ParmVarDecl *Param) {
871 QualType ParamType = Param->getType();
872 ConsumedState ParamState = consumed::CS_None;
874 if (const ParamTypestateAttr *PTA = Param->getAttr<ParamTypestateAttr>())
875 ParamState = mapParamTypestateAttrState(PTA);
876 else if (isConsumableType(ParamType))
877 ParamState = mapConsumableAttrState(ParamType);
878 else if (isRValueRef(ParamType) &&
879 isConsumableType(ParamType->getPointeeType()))
880 ParamState = mapConsumableAttrState(ParamType->getPointeeType());
881 else if (ParamType->isReferenceType() &&
882 isConsumableType(ParamType->getPointeeType()))
883 ParamState = consumed::CS_Unknown;
885 if (ParamState != CS_None)
886 StateMap->setState(Param, ParamState);
889 void ConsumedStmtVisitor::VisitReturnStmt(const ReturnStmt *Ret) {
890 ConsumedState ExpectedState = Analyzer.getExpectedReturnState();
892 if (ExpectedState != CS_None) {
893 InfoEntry Entry = findInfo(Ret->getRetValue());
895 if (Entry != PropagationMap.end()) {
896 ConsumedState RetState = Entry->second.getAsState(StateMap);
898 if (RetState != ExpectedState)
899 Analyzer.WarningsHandler.warnReturnTypestateMismatch(
900 Ret->getReturnLoc(), stateToString(ExpectedState),
901 stateToString(RetState));
905 StateMap->checkParamsForReturnTypestate(Ret->getLocStart(),
906 Analyzer.WarningsHandler);
909 void ConsumedStmtVisitor::VisitUnaryOperator(const UnaryOperator *UOp) {
910 InfoEntry Entry = findInfo(UOp->getSubExpr());
911 if (Entry == PropagationMap.end()) return;
913 switch (UOp->getOpcode()) {
915 PropagationMap.insert(PairType(UOp, Entry->second));
919 if (Entry->second.isTest())
920 PropagationMap.insert(PairType(UOp, Entry->second.invertTest()));
928 // TODO: See if I need to check for reference types here.
929 void ConsumedStmtVisitor::VisitVarDecl(const VarDecl *Var) {
930 if (isConsumableType(Var->getType())) {
931 if (Var->hasInit()) {
932 MapType::iterator VIT = findInfo(Var->getInit()->IgnoreImplicit());
933 if (VIT != PropagationMap.end()) {
934 PropagationInfo PInfo = VIT->second;
935 ConsumedState St = PInfo.getAsState(StateMap);
937 if (St != consumed::CS_None) {
938 StateMap->setState(Var, St);
944 StateMap->setState(Var, consumed::CS_Unknown);
947 }} // end clang::consumed::ConsumedStmtVisitor
952 static void splitVarStateForIf(const IfStmt *IfNode, const VarTestResult &Test,
953 ConsumedStateMap *ThenStates,
954 ConsumedStateMap *ElseStates) {
955 ConsumedState VarState = ThenStates->getState(Test.Var);
957 if (VarState == CS_Unknown) {
958 ThenStates->setState(Test.Var, Test.TestsFor);
959 ElseStates->setState(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
961 } else if (VarState == invertConsumedUnconsumed(Test.TestsFor)) {
962 ThenStates->markUnreachable();
964 } else if (VarState == Test.TestsFor) {
965 ElseStates->markUnreachable();
969 static void splitVarStateForIfBinOp(const PropagationInfo &PInfo,
970 ConsumedStateMap *ThenStates,
971 ConsumedStateMap *ElseStates) {
972 const VarTestResult <est = PInfo.getLTest(),
973 &RTest = PInfo.getRTest();
975 ConsumedState LState = LTest.Var ? ThenStates->getState(LTest.Var) : CS_None,
976 RState = RTest.Var ? ThenStates->getState(RTest.Var) : CS_None;
979 if (PInfo.testEffectiveOp() == EO_And) {
980 if (LState == CS_Unknown) {
981 ThenStates->setState(LTest.Var, LTest.TestsFor);
983 } else if (LState == invertConsumedUnconsumed(LTest.TestsFor)) {
984 ThenStates->markUnreachable();
986 } else if (LState == LTest.TestsFor && isKnownState(RState)) {
987 if (RState == RTest.TestsFor)
988 ElseStates->markUnreachable();
990 ThenStates->markUnreachable();
994 if (LState == CS_Unknown) {
995 ElseStates->setState(LTest.Var,
996 invertConsumedUnconsumed(LTest.TestsFor));
998 } else if (LState == LTest.TestsFor) {
999 ElseStates->markUnreachable();
1001 } else if (LState == invertConsumedUnconsumed(LTest.TestsFor) &&
1002 isKnownState(RState)) {
1004 if (RState == RTest.TestsFor)
1005 ElseStates->markUnreachable();
1007 ThenStates->markUnreachable();
1013 if (PInfo.testEffectiveOp() == EO_And) {
1014 if (RState == CS_Unknown)
1015 ThenStates->setState(RTest.Var, RTest.TestsFor);
1016 else if (RState == invertConsumedUnconsumed(RTest.TestsFor))
1017 ThenStates->markUnreachable();
1020 if (RState == CS_Unknown)
1021 ElseStates->setState(RTest.Var,
1022 invertConsumedUnconsumed(RTest.TestsFor));
1023 else if (RState == RTest.TestsFor)
1024 ElseStates->markUnreachable();
1029 bool ConsumedBlockInfo::allBackEdgesVisited(const CFGBlock *CurrBlock,
1030 const CFGBlock *TargetBlock) {
1032 assert(CurrBlock && "Block pointer must not be NULL");
1033 assert(TargetBlock && "TargetBlock pointer must not be NULL");
1035 unsigned int CurrBlockOrder = VisitOrder[CurrBlock->getBlockID()];
1036 for (CFGBlock::const_pred_iterator PI = TargetBlock->pred_begin(),
1037 PE = TargetBlock->pred_end(); PI != PE; ++PI) {
1038 if (*PI && CurrBlockOrder < VisitOrder[(*PI)->getBlockID()] )
1044 void ConsumedBlockInfo::addInfo(
1045 const CFGBlock *Block, ConsumedStateMap *StateMap,
1046 std::unique_ptr<ConsumedStateMap> &OwnedStateMap) {
1048 assert(Block && "Block pointer must not be NULL");
1050 auto &Entry = StateMapsArray[Block->getBlockID()];
1053 Entry->intersect(*StateMap);
1054 } else if (OwnedStateMap)
1055 Entry = std::move(OwnedStateMap);
1057 Entry = llvm::make_unique<ConsumedStateMap>(*StateMap);
1060 void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
1061 std::unique_ptr<ConsumedStateMap> StateMap) {
1063 assert(Block && "Block pointer must not be NULL");
1065 auto &Entry = StateMapsArray[Block->getBlockID()];
1068 Entry->intersect(*StateMap);
1070 Entry = std::move(StateMap);
1074 ConsumedStateMap* ConsumedBlockInfo::borrowInfo(const CFGBlock *Block) {
1075 assert(Block && "Block pointer must not be NULL");
1076 assert(StateMapsArray[Block->getBlockID()] && "Block has no block info");
1078 return StateMapsArray[Block->getBlockID()].get();
1081 void ConsumedBlockInfo::discardInfo(const CFGBlock *Block) {
1082 StateMapsArray[Block->getBlockID()] = nullptr;
1085 std::unique_ptr<ConsumedStateMap>
1086 ConsumedBlockInfo::getInfo(const CFGBlock *Block) {
1087 assert(Block && "Block pointer must not be NULL");
1089 auto &Entry = StateMapsArray[Block->getBlockID()];
1090 return isBackEdgeTarget(Block) ? llvm::make_unique<ConsumedStateMap>(*Entry)
1094 bool ConsumedBlockInfo::isBackEdge(const CFGBlock *From, const CFGBlock *To) {
1095 assert(From && "From block must not be NULL");
1096 assert(To && "From block must not be NULL");
1098 return VisitOrder[From->getBlockID()] > VisitOrder[To->getBlockID()];
1101 bool ConsumedBlockInfo::isBackEdgeTarget(const CFGBlock *Block) {
1102 assert(Block && "Block pointer must not be NULL");
1104 // Anything with less than two predecessors can't be the target of a back
1106 if (Block->pred_size() < 2)
1109 unsigned int BlockVisitOrder = VisitOrder[Block->getBlockID()];
1110 for (CFGBlock::const_pred_iterator PI = Block->pred_begin(),
1111 PE = Block->pred_end(); PI != PE; ++PI) {
1112 if (*PI && BlockVisitOrder < VisitOrder[(*PI)->getBlockID()])
1118 void ConsumedStateMap::checkParamsForReturnTypestate(SourceLocation BlameLoc,
1119 ConsumedWarningsHandlerBase &WarningsHandler) const {
1121 for (const auto &DM : VarMap) {
1122 if (isa<ParmVarDecl>(DM.first)) {
1123 const ParmVarDecl *Param = cast<ParmVarDecl>(DM.first);
1124 const ReturnTypestateAttr *RTA = Param->getAttr<ReturnTypestateAttr>();
1129 ConsumedState ExpectedState = mapReturnTypestateAttrState(RTA);
1130 if (DM.second != ExpectedState)
1131 WarningsHandler.warnParamReturnTypestateMismatch(BlameLoc,
1132 Param->getNameAsString(), stateToString(ExpectedState),
1133 stateToString(DM.second));
1138 void ConsumedStateMap::clearTemporaries() {
1142 ConsumedState ConsumedStateMap::getState(const VarDecl *Var) const {
1143 VarMapType::const_iterator Entry = VarMap.find(Var);
1145 if (Entry != VarMap.end())
1146 return Entry->second;
1152 ConsumedStateMap::getState(const CXXBindTemporaryExpr *Tmp) const {
1153 TmpMapType::const_iterator Entry = TmpMap.find(Tmp);
1155 if (Entry != TmpMap.end())
1156 return Entry->second;
1161 void ConsumedStateMap::intersect(const ConsumedStateMap &Other) {
1162 ConsumedState LocalState;
1164 if (this->From && this->From == Other.From && !Other.Reachable) {
1165 this->markUnreachable();
1169 for (const auto &DM : Other.VarMap) {
1170 LocalState = this->getState(DM.first);
1172 if (LocalState == CS_None)
1175 if (LocalState != DM.second)
1176 VarMap[DM.first] = CS_Unknown;
1180 void ConsumedStateMap::intersectAtLoopHead(const CFGBlock *LoopHead,
1181 const CFGBlock *LoopBack, const ConsumedStateMap *LoopBackStates,
1182 ConsumedWarningsHandlerBase &WarningsHandler) {
1184 ConsumedState LocalState;
1185 SourceLocation BlameLoc = getLastStmtLoc(LoopBack);
1187 for (const auto &DM : LoopBackStates->VarMap) {
1188 LocalState = this->getState(DM.first);
1190 if (LocalState == CS_None)
1193 if (LocalState != DM.second) {
1194 VarMap[DM.first] = CS_Unknown;
1195 WarningsHandler.warnLoopStateMismatch(BlameLoc,
1196 DM.first->getNameAsString());
1201 void ConsumedStateMap::markUnreachable() {
1202 this->Reachable = false;
1207 void ConsumedStateMap::setState(const VarDecl *Var, ConsumedState State) {
1208 VarMap[Var] = State;
1211 void ConsumedStateMap::setState(const CXXBindTemporaryExpr *Tmp,
1212 ConsumedState State) {
1213 TmpMap[Tmp] = State;
1216 void ConsumedStateMap::remove(const CXXBindTemporaryExpr *Tmp) {
1220 bool ConsumedStateMap::operator!=(const ConsumedStateMap *Other) const {
1221 for (const auto &DM : Other->VarMap)
1222 if (this->getState(DM.first) != DM.second)
1227 void ConsumedAnalyzer::determineExpectedReturnState(AnalysisDeclContext &AC,
1228 const FunctionDecl *D) {
1229 QualType ReturnType;
1230 if (const CXXConstructorDecl *Constructor = dyn_cast<CXXConstructorDecl>(D)) {
1231 ASTContext &CurrContext = AC.getASTContext();
1232 ReturnType = Constructor->getThisType(CurrContext)->getPointeeType();
1234 ReturnType = D->getCallResultType();
1236 if (const ReturnTypestateAttr *RTSAttr = D->getAttr<ReturnTypestateAttr>()) {
1237 const CXXRecordDecl *RD = ReturnType->getAsCXXRecordDecl();
1238 if (!RD || !RD->hasAttr<ConsumableAttr>()) {
1239 // FIXME: This should be removed when template instantiation propagates
1240 // attributes at template specialization definition, not
1241 // declaration. When it is removed the test needs to be enabled
1242 // in SemaDeclAttr.cpp.
1243 WarningsHandler.warnReturnTypestateForUnconsumableType(
1244 RTSAttr->getLocation(), ReturnType.getAsString());
1245 ExpectedReturnState = CS_None;
1247 ExpectedReturnState = mapReturnTypestateAttrState(RTSAttr);
1248 } else if (isConsumableType(ReturnType)) {
1249 if (isAutoCastType(ReturnType)) // We can auto-cast the state to the
1250 ExpectedReturnState = CS_None; // expected state.
1252 ExpectedReturnState = mapConsumableAttrState(ReturnType);
1255 ExpectedReturnState = CS_None;
1258 bool ConsumedAnalyzer::splitState(const CFGBlock *CurrBlock,
1259 const ConsumedStmtVisitor &Visitor) {
1261 std::unique_ptr<ConsumedStateMap> FalseStates(
1262 new ConsumedStateMap(*CurrStates));
1263 PropagationInfo PInfo;
1265 if (const IfStmt *IfNode =
1266 dyn_cast_or_null<IfStmt>(CurrBlock->getTerminator().getStmt())) {
1268 const Expr *Cond = IfNode->getCond();
1270 PInfo = Visitor.getInfo(Cond);
1271 if (!PInfo.isValid() && isa<BinaryOperator>(Cond))
1272 PInfo = Visitor.getInfo(cast<BinaryOperator>(Cond)->getRHS());
1274 if (PInfo.isVarTest()) {
1275 CurrStates->setSource(Cond);
1276 FalseStates->setSource(Cond);
1277 splitVarStateForIf(IfNode, PInfo.getVarTest(), CurrStates.get(),
1280 } else if (PInfo.isBinTest()) {
1281 CurrStates->setSource(PInfo.testSourceNode());
1282 FalseStates->setSource(PInfo.testSourceNode());
1283 splitVarStateForIfBinOp(PInfo, CurrStates.get(), FalseStates.get());
1289 } else if (const BinaryOperator *BinOp =
1290 dyn_cast_or_null<BinaryOperator>(CurrBlock->getTerminator().getStmt())) {
1292 PInfo = Visitor.getInfo(BinOp->getLHS());
1293 if (!PInfo.isVarTest()) {
1294 if ((BinOp = dyn_cast_or_null<BinaryOperator>(BinOp->getLHS()))) {
1295 PInfo = Visitor.getInfo(BinOp->getRHS());
1297 if (!PInfo.isVarTest())
1305 CurrStates->setSource(BinOp);
1306 FalseStates->setSource(BinOp);
1308 const VarTestResult &Test = PInfo.getVarTest();
1309 ConsumedState VarState = CurrStates->getState(Test.Var);
1311 if (BinOp->getOpcode() == BO_LAnd) {
1312 if (VarState == CS_Unknown)
1313 CurrStates->setState(Test.Var, Test.TestsFor);
1314 else if (VarState == invertConsumedUnconsumed(Test.TestsFor))
1315 CurrStates->markUnreachable();
1317 } else if (BinOp->getOpcode() == BO_LOr) {
1318 if (VarState == CS_Unknown)
1319 FalseStates->setState(Test.Var,
1320 invertConsumedUnconsumed(Test.TestsFor));
1321 else if (VarState == Test.TestsFor)
1322 FalseStates->markUnreachable();
1329 CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin();
1332 BlockInfo.addInfo(*SI, std::move(CurrStates));
1334 CurrStates = nullptr;
1337 BlockInfo.addInfo(*SI, std::move(FalseStates));
1342 void ConsumedAnalyzer::run(AnalysisDeclContext &AC) {
1343 const FunctionDecl *D = dyn_cast_or_null<FunctionDecl>(AC.getDecl());
1347 CFG *CFGraph = AC.getCFG();
1351 determineExpectedReturnState(AC, D);
1353 PostOrderCFGView *SortedGraph = AC.getAnalysis<PostOrderCFGView>();
1354 // AC.getCFG()->viewCFG(LangOptions());
1356 BlockInfo = ConsumedBlockInfo(CFGraph->getNumBlockIDs(), SortedGraph);
1358 CurrStates = llvm::make_unique<ConsumedStateMap>();
1359 ConsumedStmtVisitor Visitor(AC, *this, CurrStates.get());
1361 // Add all trackable parameters to the state map.
1362 for (const auto *PI : D->parameters())
1363 Visitor.VisitParmVarDecl(PI);
1365 // Visit all of the function's basic blocks.
1366 for (const auto *CurrBlock : *SortedGraph) {
1368 CurrStates = BlockInfo.getInfo(CurrBlock);
1373 } else if (!CurrStates->isReachable()) {
1374 CurrStates = nullptr;
1378 Visitor.reset(CurrStates.get());
1380 // Visit all of the basic block's statements.
1381 for (const auto &B : *CurrBlock) {
1382 switch (B.getKind()) {
1383 case CFGElement::Statement:
1384 Visitor.Visit(B.castAs<CFGStmt>().getStmt());
1387 case CFGElement::TemporaryDtor: {
1388 const CFGTemporaryDtor &DTor = B.castAs<CFGTemporaryDtor>();
1389 const CXXBindTemporaryExpr *BTE = DTor.getBindTemporaryExpr();
1391 Visitor.checkCallability(PropagationInfo(BTE),
1392 DTor.getDestructorDecl(AC.getASTContext()),
1394 CurrStates->remove(BTE);
1398 case CFGElement::AutomaticObjectDtor: {
1399 const CFGAutomaticObjDtor &DTor = B.castAs<CFGAutomaticObjDtor>();
1400 SourceLocation Loc = DTor.getTriggerStmt()->getLocEnd();
1401 const VarDecl *Var = DTor.getVarDecl();
1403 Visitor.checkCallability(PropagationInfo(Var),
1404 DTor.getDestructorDecl(AC.getASTContext()),
1414 // TODO: Handle other forms of branching with precision, including while-
1415 // and for-loops. (Deferred)
1416 if (!splitState(CurrBlock, Visitor)) {
1417 CurrStates->setSource(nullptr);
1419 if (CurrBlock->succ_size() > 1 ||
1420 (CurrBlock->succ_size() == 1 &&
1421 (*CurrBlock->succ_begin())->pred_size() > 1)) {
1423 auto *RawState = CurrStates.get();
1425 for (CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin(),
1426 SE = CurrBlock->succ_end(); SI != SE; ++SI) {
1428 if (*SI == nullptr) continue;
1430 if (BlockInfo.isBackEdge(CurrBlock, *SI)) {
1431 BlockInfo.borrowInfo(*SI)->intersectAtLoopHead(
1432 *SI, CurrBlock, RawState, WarningsHandler);
1434 if (BlockInfo.allBackEdgesVisited(CurrBlock, *SI))
1435 BlockInfo.discardInfo(*SI);
1437 BlockInfo.addInfo(*SI, RawState, CurrStates);
1441 CurrStates = nullptr;
1445 if (CurrBlock == &AC.getCFG()->getExit() &&
1446 D->getCallResultType()->isVoidType())
1447 CurrStates->checkParamsForReturnTypestate(D->getLocation(),
1449 } // End of block iterator.
1451 // Delete the last existing state map.
1452 CurrStates = nullptr;
1454 WarningsHandler.emitDiagnostics();
1456 }} // end namespace clang::consumed