1 //===--- TransProtectedScope.cpp - Transformations to ARC mode ------------===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // Adds braces to case statements that "contain" the initialization of a retaining
11 // variable and therefore trigger the "switch case is in protected scope" error.
13 //===----------------------------------------------------------------------===//
15 #include "Transforms.h"
16 #include "Internals.h"
17 #include "clang/AST/ASTContext.h"
18 #include "clang/Sema/SemaDiagnostic.h"
20 using namespace clang;
21 using namespace arcmt;
22 using namespace trans;
// AST visitor that inspects every DeclRefExpr reachable from the statement it
// is asked to traverse. It keeps a reference to a caller-owned vector of
// DeclRefExpr pointers (Refs).
26 class LocalRefsCollector : public RecursiveASTVisitor<LocalRefsCollector> {
27 SmallVectorImpl<DeclRefExpr *> &Refs;
// Takes the output vector by reference.
// NOTE(review): the member-initializer list is not visible in this excerpt;
// presumably it binds Refs to refs -- confirm.
30 LocalRefsCollector(SmallVectorImpl<DeclRefExpr *> &refs)
// Visitor callback for each DeclRefExpr. It filters on references whose
// declaration lives in a function or method context (after looking through
// to the redeclaration context).
// NOTE(review): the statement guarded by the inner 'if' and the return value
// are not visible here; presumably E is appended to Refs and traversal
// continues -- confirm.
33 bool VisitDeclRefExpr(DeclRefExpr *E) {
34 if (ValueDecl *D = E->getDecl())
35 if (D->getDeclContext()->getRedeclContext()->isFunctionOrMethod())
// Default construction: no switch case attached (SC is null) and the state is
// St_Unchecked.
50 CaseInfo() : SC(nullptr), State(St_Unchecked) {}
// Records the switch case S together with the source range attributed to it.
// The state starts out as St_Unchecked.
51 CaseInfo(SwitchCase *S, SourceRange Range)
52 : SC(S), Range(Range), State(St_Unchecked) {}
// AST visitor that looks at each SwitchStmt and builds CaseInfo entries for
// its case statements. It holds a ParentMap (PMap) and a reference to the
// caller-owned output vector (Cases).
55 class CaseCollector : public RecursiveASTVisitor<CaseCollector> {
57 SmallVectorImpl<CaseInfo> &Cases;
// Stores the parent map and the output vector supplied by the caller.
60 CaseCollector(ParentMap &PMap, SmallVectorImpl<CaseInfo> &Cases)
61 : PMap(PMap), Cases(Cases) { }
// Visitor callback for each switch statement. It works in two passes over the
// switch's case list:
//   1. compare the parent (as computed by getCaseParent) of every later case
//      with that of the first case;
//   2. build one CaseInfo per case, with a range computed from neighbouring
//      case locations.
// NOTE(review): the loop headers, the null check on the case list, the
// reaction to a parent mismatch and the return values are not visible in this
// excerpt -- confirm against the full file.
63 bool VisitSwitchStmt(SwitchStmt *S) {
64 SwitchCase *Curr = S->getSwitchCaseList();
// Parent of the first case in the list; the reference for the comparison below.
67 Stmt *Parent = getCaseParent(Curr);
68 Curr = Curr->getNextSwitchCase();
69 // Make sure all case statements are in the same scope.
71 if (getCaseParent(Curr) != Parent)
73 Curr = Curr->getNextSwitchCase();
// Second pass. NextLoc starts at the end of the switch and is then moved to
// the beginning of each visited case, so every case gets the range
// [its own begin, begin of the previously visited case or end of the switch).
76 SourceLocation NextLoc = S->getEndLoc();
77 Curr = S->getSwitchCaseList();
78 // We iterate over case statements in reverse source-order.
// NOTE(review): the call that receives this CaseInfo is not visible here;
// presumably it is appended to Cases -- confirm.
81 CaseInfo(Curr, SourceRange(Curr->getBeginLoc(), NextLoc)));
82 NextLoc = Curr->getBeginLoc();
83 Curr = Curr->getNextSwitchCase();
// Walks up the parent map from S, skipping any ancestors that are themselves
// SwitchCase or LabelStmt nodes, and stops at the first ancestor of another
// kind (or at null).
// NOTE(review): the return statement is not visible; presumably Parent is
// returned -- confirm.
88 Stmt *getCaseParent(SwitchCase *S) {
89 Stmt *Parent = PMap.getParent(S);
90 while (Parent && (isa<SwitchCase>(Parent) || isa<LabelStmt>(Parent)))
91 Parent = PMap.getParent(Parent);
// Processes the err_switch_into_protected_scope diagnostics located inside
// one body. For each such error it inspects the notes that follow it and, for
// the case statement each note falls into, tries to wrap that case in braces
// (see tryFixing). All work is driven from the constructor.
96 class ProtectedScopeFixer {
// Case statements collected from the body, with their source ranges and
// per-case fix state.
99 SmallVector<CaseInfo, 16> Cases;
// DeclRefExprs collected from the body by LocalRefsCollector.
100 SmallVector<DeclRefExpr *, 16> LocalRefs;
// Collects cases and local references from the body's top statement, then
// scans a copy of the captured diagnostics for protected-scope errors whose
// location lies within the body.
// NOTE(review): the loop header around the diagnostic scan and the way the
// iterator I is advanced are not visible in this excerpt -- confirm.
103 ProtectedScopeFixer(BodyContext &BodyCtx)
104 : Pass(BodyCtx.getMigrationContext().Pass),
105 SM(Pass.Ctx.getSourceManager()) {
107 CaseCollector(BodyCtx.getParentMap(), Cases)
108 .TraverseStmt(BodyCtx.getTopStmt());
109 LocalRefsCollector(LocalRefs).TraverseStmt(BodyCtx.getTopStmt());
111 SourceRange BodyRange = BodyCtx.getTopStmt()->getSourceRange();
112 const CapturedDiagList &DiagList = Pass.getDiags();
113 // Copy the diagnostics so we don't have to worry about invalidating iterators
114 // from the diagnostic list.
115 SmallVector<StoredDiagnostic, 16> StoredDiags;
116 StoredDiags.append(DiagList.begin(), DiagList.end());
117 SmallVectorImpl<StoredDiagnostic>::iterator
118 I = StoredDiags.begin(), E = StoredDiags.end();
// Only protected-scope errors located inside this body are handled. I is
// passed by reference, so the handler can advance it.
120 if (I->getID() == diag::err_switch_into_protected_scope &&
121 isInRange(I->getLocation(), BodyRange)) {
122 handleProtectedScopeError(I, E);
// Handles one protected-scope error. DiagI must point at the error on entry
// (asserted) and is taken by reference; DiagE is the end of the diagnostic
// copy. The visible loop walks diagnostics while they are notes, passing each
// to handleProtectedNote and remembering whether any of them was not handled.
// The edits are made under a Transaction on Pass.TA.
// NOTE(review): the step that moves DiagI past the error itself, the loop
// increment, and the condition guarding the final clearDiagnostic are not
// visible; presumably the error is cleared only when handledAllNotes is
// true -- confirm.
129 void handleProtectedScopeError(
130 SmallVectorImpl<StoredDiagnostic>::iterator &DiagI,
131 SmallVectorImpl<StoredDiagnostic>::iterator DiagE){
132 Transaction Trans(Pass.TA);
133 assert(DiagI->getID() == diag::err_switch_into_protected_scope);
134 SourceLocation ErrLoc = DiagI->getLocation();
135 bool handledAllNotes = true;
137 for (; DiagI != DiagE && DiagI->getLevel() == DiagnosticsEngine::Note;
139 if (!handleProtectedNote(*DiagI))
140 handledAllNotes = false;
144 Pass.TA.clearDiagnostic(diag::err_switch_into_protected_scope, ErrLoc);
// Handles one note attached to a protected-scope error. It looks for the
// collected case whose range contains the note's location; when that case is
// in state St_Fixed, the note diagnostic is cleared.
// NOTE(review): the statement run for an St_Unchecked case and all return
// statements are not visible; presumably tryFixing(info) is called and the
// result reports whether the case ended up fixed -- confirm.
147 bool handleProtectedNote(const StoredDiagnostic &Diag) {
148 assert(Diag.getLevel() == DiagnosticsEngine::Note);
150 for (unsigned i = 0; i != Cases.size(); i++) {
151 CaseInfo &info = Cases[i];
152 if (isInRange(Diag.getLocation(), info.Range)) {
154 if (info.State == CaseInfo::St_Unchecked)
// By this point the case must have been classified as fixed or unfixable.
156 assert(info.State != CaseInfo::St_Unchecked);
158 if (info.State == CaseInfo::St_Fixed) {
159 Pass.TA.clearDiagnostic(Diag.getID(), Diag.getLocation());
// Classifies a previously unchecked case. A case with a variable that is
// declared inside its range but referenced outside it is marked
// St_CannotFix. Otherwise " {" is inserted after the case's colon and "}\n"
// at the end of the case's range, and the case is marked St_Fixed.
// NOTE(review): the early exit after St_CannotFix is not visible in this
// excerpt; presumably the function returns there -- confirm.
169 void tryFixing(CaseInfo &info) {
170 assert(info.State == CaseInfo::St_Unchecked);
171 if (hasVarReferencedOutside(info)) {
172 info.State = CaseInfo::St_CannotFix;
176 Pass.TA.insertAfterToken(info.SC->getColonLoc(), " {");
177 Pass.TA.insert(info.Range.getEnd(), "}\n");
178 info.State = CaseInfo::St_Fixed;
// Scans the collected local references for one whose declaration is located
// inside the case's range while the reference itself is located outside it.
// NOTE(review): the return statements are not visible; presumably true is
// returned on the first such reference and false otherwise -- confirm.
181 bool hasVarReferencedOutside(CaseInfo &info) {
182 for (unsigned i = 0, e = LocalRefs.size(); i != e; ++i) {
183 DeclRefExpr *DRE = LocalRefs[i];
184 if (isInRange(DRE->getDecl()->getLocation(), info.Range) &&
185 !isInRange(DRE->getLocation(), info.Range))
// Half-open containment test in translation-unit order: Loc is in range when
// it is not before R's begin and is strictly before R's end.
// NOTE(review): lines between the signature and the return are not visible
// in this excerpt; there may be an additional early-out -- confirm.
191 bool isInRange(SourceLocation Loc, SourceRange R) {
194 return !SM.isBeforeInTranslationUnit(Loc, R.getBegin()) &&
195 SM.isBeforeInTranslationUnit(Loc, R.getEnd());
199 } // anonymous namespace
// Entry point for one body: constructs a ProtectedScopeFixer for BodyCtx.
// The fixer's constructor performs the collection and diagnostic handling,
// so the local object is not used further in the visible lines.
201 void ProtectedScopeTraverser::traverseBody(BodyContext &BodyCtx) {
202 ProtectedScopeFixer Fix(BodyCtx);