//===--- Transformer.cpp - Transformer library implementation ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "clang/Tooling/Refactoring/Transformer.h"
10 #include "clang/AST/Expr.h"
11 #include "clang/ASTMatchers/ASTMatchFinder.h"
12 #include "clang/ASTMatchers/ASTMatchers.h"
13 #include "clang/Basic/Diagnostic.h"
14 #include "clang/Basic/SourceLocation.h"
15 #include "clang/Rewrite/Core/Rewriter.h"
16 #include "clang/Tooling/Refactoring/AtomicChange.h"
17 #include "clang/Tooling/Refactoring/SourceCode.h"
18 #include "llvm/ADT/Optional.h"
19 #include "llvm/ADT/StringRef.h"
20 #include "llvm/Support/Errc.h"
21 #include "llvm/Support/Error.h"
27 using namespace clang;
28 using namespace tooling;
30 using ast_matchers::MatchFinder;
31 using ast_matchers::internal::DynTypedMatcher;
32 using ast_type_traits::ASTNodeKind;
33 using ast_type_traits::DynTypedNode;
35 using llvm::StringError;
37 using MatchResult = MatchFinder::MatchResult;
39 // Did the text at this location originate in a macro definition (aka. body)?
42 // #define NESTED(x) x
43 // #define MACRO(y) { int y = NESTED(3); }
44 // if (true) MACRO(foo)
46 // The if statement expands to
48 // if (true) { int foo = 3; }
52 // For SourceManager SM, SM.isMacroArgExpansion(Loc1) and
53 // SM.isMacroArgExpansion(Loc2) are both true, but isOriginMacroBody(sm, Loc1)
54 // is false, because "foo" originated in the source file (as an argument to a
55 // macro), whereas isOriginMacroBody(SM, Loc2) is true, because "3" originated
56 // in the definition of MACRO.
57 static bool isOriginMacroBody(const clang::SourceManager &SM,
58 clang::SourceLocation Loc) {
59 while (Loc.isMacroID()) {
60 if (SM.isMacroBodyExpansion(Loc))
62 // Otherwise, it must be in an argument, so we continue searching up the
63 // invocation stack. getImmediateMacroCallerLoc() gives the location of the
64 // argument text, inside the call text.
65 Loc = SM.getImmediateMacroCallerLoc(Loc);
70 Expected<SmallVector<tooling::detail::Transformation, 1>>
71 tooling::detail::translateEdits(const MatchResult &Result,
72 llvm::ArrayRef<ASTEdit> Edits) {
73 SmallVector<tooling::detail::Transformation, 1> Transformations;
74 for (const auto &Edit : Edits) {
75 Expected<CharSourceRange> Range = Edit.TargetRange(Result);
77 return Range.takeError();
78 if (Range->isInvalid() ||
79 isOriginMacroBody(*Result.SourceManager, Range->getBegin()))
80 return SmallVector<Transformation, 0>();
81 auto Replacement = Edit.Replacement(Result);
83 return Replacement.takeError();
84 tooling::detail::Transformation T;
86 T.Replacement = std::move(*Replacement);
87 Transformations.push_back(std::move(T));
89 return Transformations;
92 ASTEdit tooling::change(RangeSelector S, TextGenerator Replacement) {
94 E.TargetRange = std::move(S);
95 E.Replacement = std::move(Replacement);
99 RewriteRule tooling::makeRule(DynTypedMatcher M, SmallVector<ASTEdit, 1> Edits,
100 TextGenerator Explanation) {
101 return RewriteRule{{RewriteRule::Case{
102 std::move(M), std::move(Edits), std::move(Explanation), {}}}};
105 void tooling::addInclude(RewriteRule &Rule, StringRef Header,
106 IncludeFormat Format) {
107 for (auto &Case : Rule.Cases)
108 Case.AddedIncludes.emplace_back(Header.str(), Format);
111 // Determines whether A is a base type of B in the class hierarchy, including
112 // the implicit relationship of Type and QualType.
113 static bool isBaseOf(ASTNodeKind A, ASTNodeKind B) {
114 static auto TypeKind = ASTNodeKind::getFromNodeKind<Type>();
115 static auto QualKind = ASTNodeKind::getFromNodeKind<QualType>();
116 /// Mimic the implicit conversions of Matcher<>.
117 /// - From Matcher<Type> to Matcher<QualType>
118 /// - From Matcher<Base> to Matcher<Derived>
119 return (A.isSame(TypeKind) && B.isSame(QualKind)) || A.isBaseOf(B);
122 // Try to find a common kind to which all of the rule's matchers can be
125 findCommonKind(const SmallVectorImpl<RewriteRule::Case> &Cases) {
126 assert(!Cases.empty() && "Rule must have at least one case.");
127 ASTNodeKind JoinKind = Cases[0].Matcher.getSupportedKind();
128 // Find a (least) Kind K, for which M.canConvertTo(K) holds, for all matchers
130 for (const auto &Case : Cases) {
131 auto K = Case.Matcher.getSupportedKind();
132 if (isBaseOf(JoinKind, K)) {
136 if (K.isSame(JoinKind) || isBaseOf(K, JoinKind))
137 // JoinKind is already the lowest.
139 // K and JoinKind are unrelated -- there is no least common kind.
140 return ASTNodeKind();
145 // Binds each rule's matcher to a unique (and deterministic) tag based on
147 static std::vector<DynTypedMatcher>
148 taggedMatchers(StringRef TagBase,
149 const SmallVectorImpl<RewriteRule::Case> &Cases) {
150 std::vector<DynTypedMatcher> Matchers;
151 Matchers.reserve(Cases.size());
153 for (const auto &Case : Cases) {
154 std::string Tag = (TagBase + Twine(count)).str();
156 auto M = Case.Matcher.tryBind(Tag);
157 assert(M && "RewriteRule matchers should be bindable.");
158 Matchers.push_back(*std::move(M));
163 // Simply gathers the contents of the various rules into a single rule. The
164 // actual work to combine these into an ordered choice is deferred to matcher
166 RewriteRule tooling::applyFirst(ArrayRef<RewriteRule> Rules) {
168 for (auto &Rule : Rules)
169 R.Cases.append(Rule.Cases.begin(), Rule.Cases.end());
173 static DynTypedMatcher joinCaseMatchers(const RewriteRule &Rule) {
174 assert(!Rule.Cases.empty() && "Rule must have at least one case.");
175 if (Rule.Cases.size() == 1)
176 return Rule.Cases[0].Matcher;
178 auto CommonKind = findCommonKind(Rule.Cases);
179 assert(!CommonKind.isNone() && "Cases must have compatible matchers.");
180 return DynTypedMatcher::constructVariadic(
181 DynTypedMatcher::VO_AnyOf, CommonKind, taggedMatchers("Tag", Rule.Cases));
184 DynTypedMatcher tooling::detail::buildMatcher(const RewriteRule &Rule) {
185 DynTypedMatcher M = joinCaseMatchers(Rule);
186 M.setAllowBind(true);
187 // `tryBind` is guaranteed to succeed, because `AllowBind` was set to true.
188 return *M.tryBind(RewriteRule::RootID);
191 // Finds the case that was "selected" -- that is, whose matcher triggered the
193 const RewriteRule::Case &
194 tooling::detail::findSelectedCase(const MatchResult &Result,
195 const RewriteRule &Rule) {
196 if (Rule.Cases.size() == 1)
197 return Rule.Cases[0];
199 auto &NodesMap = Result.Nodes.getMap();
200 for (size_t i = 0, N = Rule.Cases.size(); i < N; ++i) {
201 std::string Tag = ("Tag" + Twine(i)).str();
202 if (NodesMap.find(Tag) != NodesMap.end())
203 return Rule.Cases[i];
205 llvm_unreachable("No tag found for this rule.");
208 constexpr llvm::StringLiteral RewriteRule::RootID;
210 void Transformer::registerMatchers(MatchFinder *MatchFinder) {
211 MatchFinder->addDynamicMatcher(tooling::detail::buildMatcher(Rule), this);
214 void Transformer::run(const MatchResult &Result) {
215 if (Result.Context->getDiagnostics().hasErrorOccurred())
218 // Verify the existence and validity of the AST node that roots this rule.
219 auto &NodesMap = Result.Nodes.getMap();
220 auto Root = NodesMap.find(RewriteRule::RootID);
221 assert(Root != NodesMap.end() && "Transformation failed: missing root node.");
222 SourceLocation RootLoc = Result.SourceManager->getExpansionLoc(
223 Root->second.getSourceRange().getBegin());
224 assert(RootLoc.isValid() && "Invalid location for Root node of match.");
226 RewriteRule::Case Case = tooling::detail::findSelectedCase(Result, Rule);
227 auto Transformations = tooling::detail::translateEdits(Result, Case.Edits);
228 if (!Transformations) {
229 Consumer(Transformations.takeError());
233 if (Transformations->empty()) {
234 // No rewrite applied (but no error encountered either).
235 RootLoc.print(llvm::errs() << "note: skipping match at loc ",
236 *Result.SourceManager);
237 llvm::errs() << "\n";
241 // Record the results in the AtomicChange.
242 AtomicChange AC(*Result.SourceManager, RootLoc);
243 for (const auto &T : *Transformations) {
244 if (auto Err = AC.replace(*Result.SourceManager, T.Range, T.Replacement)) {
245 Consumer(std::move(Err));
250 for (const auto &I : Case.AddedIncludes) {
251 auto &Header = I.first;
253 case IncludeFormat::Quoted:
254 AC.addHeader(Header);
256 case IncludeFormat::Angled:
257 AC.addHeader((llvm::Twine("<") + Header + ">").str());
262 Consumer(std::move(AC));