//===--- USRFindingAction.cpp - Clang refactoring library -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Provides an action to find the USR for the symbol at <offset>, as well as
/// all additional USRs that must be renamed together with it.
///
//===----------------------------------------------------------------------===//
#include "clang/Tooling/Refactoring/Rename/USRFindingAction.h"
#include "clang/AST/AST.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Basic/FileManager.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendAction.h"
#include "clang/Lex/Lexer.h"
#include "clang/Lex/Preprocessor.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Refactoring.h"
#include "clang/Tooling/Refactoring/Rename/USRFinder.h"
#include "clang/Tooling/Tooling.h"

#include <memory>
#include <set>
#include <string>
#include <vector>
41 const NamedDecl *getCanonicalSymbolDeclaration(const NamedDecl *FoundDecl) {
42 // If FoundDecl is a constructor or destructor, we want to instead take
43 // the Decl of the corresponding class.
44 if (const auto *CtorDecl = dyn_cast<CXXConstructorDecl>(FoundDecl))
45 FoundDecl = CtorDecl->getParent();
46 else if (const auto *DtorDecl = dyn_cast<CXXDestructorDecl>(FoundDecl))
47 FoundDecl = DtorDecl->getParent();
48 // FIXME: (Alex L): Canonicalize implicit template instantions, just like
49 // the indexer does it.
51 // Note: please update the declaration's doc comment every time the
52 // canonicalization rules are changed.
57 // NamedDeclFindingConsumer should delegate finding USRs of given Decl to
58 // AdditionalUSRFinder. AdditionalUSRFinder adds USRs of ctor and dtor if given
59 // Decl refers to class and adds USRs of all overridden methods if Decl refers
61 class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> {
63 AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context)
64 : FoundDecl(FoundDecl), Context(Context) {}
66 std::vector<std::string> Find() {
67 // Fill OverriddenMethods and PartialSpecs storages.
68 TraverseDecl(Context.getTranslationUnitDecl());
69 if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) {
70 addUSRsOfOverridenFunctions(MethodDecl);
71 for (const auto &OverriddenMethod : OverriddenMethods) {
72 if (checkIfOverriddenFunctionAscends(OverriddenMethod))
73 USRSet.insert(getUSRForDecl(OverriddenMethod));
75 addUSRsOfInstantiatedMethods(MethodDecl);
76 } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) {
77 handleCXXRecordDecl(RecordDecl);
78 } else if (const auto *TemplateDecl =
79 dyn_cast<ClassTemplateDecl>(FoundDecl)) {
80 handleClassTemplateDecl(TemplateDecl);
82 USRSet.insert(getUSRForDecl(FoundDecl));
84 return std::vector<std::string>(USRSet.begin(), USRSet.end());
87 bool shouldVisitTemplateInstantiations() const { return true; }
89 bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) {
90 if (MethodDecl->isVirtual())
91 OverriddenMethods.push_back(MethodDecl);
92 if (MethodDecl->getInstantiatedFromMemberFunction())
93 InstantiatedMethods.push_back(MethodDecl);
97 bool VisitClassTemplatePartialSpecializationDecl(
98 const ClassTemplatePartialSpecializationDecl *PartialSpec) {
99 PartialSpecs.push_back(PartialSpec);
104 void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) {
105 RecordDecl = RecordDecl->getDefinition();
106 if (const auto *ClassTemplateSpecDecl =
107 dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl))
108 handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate());
109 addUSRsOfCtorDtors(RecordDecl);
112 void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) {
113 for (const auto *Specialization : TemplateDecl->specializations())
114 addUSRsOfCtorDtors(Specialization);
116 for (const auto *PartialSpec : PartialSpecs) {
117 if (PartialSpec->getSpecializedTemplate() == TemplateDecl)
118 addUSRsOfCtorDtors(PartialSpec);
120 addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl());
123 void addUSRsOfCtorDtors(const CXXRecordDecl *RecordDecl) {
124 RecordDecl = RecordDecl->getDefinition();
126 // Skip if the CXXRecordDecl doesn't have definition.
130 for (const auto *CtorDecl : RecordDecl->ctors())
131 USRSet.insert(getUSRForDecl(CtorDecl));
133 USRSet.insert(getUSRForDecl(RecordDecl->getDestructor()));
134 USRSet.insert(getUSRForDecl(RecordDecl));
137 void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl) {
138 USRSet.insert(getUSRForDecl(MethodDecl));
139 // Recursively visit each OverridenMethod.
140 for (const auto &OverriddenMethod : MethodDecl->overridden_methods())
141 addUSRsOfOverridenFunctions(OverriddenMethod);
144 void addUSRsOfInstantiatedMethods(const CXXMethodDecl *MethodDecl) {
145 // For renaming a class template method, all references of the instantiated
146 // member methods should be renamed too, so add USRs of the instantiated
147 // methods to the USR set.
148 USRSet.insert(getUSRForDecl(MethodDecl));
149 if (const auto *FT = MethodDecl->getInstantiatedFromMemberFunction())
150 USRSet.insert(getUSRForDecl(FT));
151 for (const auto *Method : InstantiatedMethods) {
152 if (USRSet.find(getUSRForDecl(
153 Method->getInstantiatedFromMemberFunction())) != USRSet.end())
154 USRSet.insert(getUSRForDecl(Method));
158 bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) {
159 for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) {
160 if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end())
162 return checkIfOverriddenFunctionAscends(OverriddenMethod);
167 const Decl *FoundDecl;
169 std::set<std::string> USRSet;
170 std::vector<const CXXMethodDecl *> OverriddenMethods;
171 std::vector<const CXXMethodDecl *> InstantiatedMethods;
172 std::vector<const ClassTemplatePartialSpecializationDecl *> PartialSpecs;
176 std::vector<std::string> getUSRsForDeclaration(const NamedDecl *ND,
177 ASTContext &Context) {
178 AdditionalUSRFinder Finder(ND, Context);
179 return Finder.Find();
182 class NamedDeclFindingConsumer : public ASTConsumer {
184 NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets,
185 ArrayRef<std::string> QualifiedNames,
186 std::vector<std::string> &SpellingNames,
187 std::vector<std::vector<std::string>> &USRList,
188 bool Force, bool &ErrorOccurred)
189 : SymbolOffsets(SymbolOffsets), QualifiedNames(QualifiedNames),
190 SpellingNames(SpellingNames), USRList(USRList), Force(Force),
191 ErrorOccurred(ErrorOccurred) {}
194 bool FindSymbol(ASTContext &Context, const SourceManager &SourceMgr,
195 unsigned SymbolOffset, const std::string &QualifiedName) {
196 DiagnosticsEngine &Engine = Context.getDiagnostics();
197 const FileID MainFileID = SourceMgr.getMainFileID();
199 if (SymbolOffset >= SourceMgr.getFileIDSize(MainFileID)) {
200 ErrorOccurred = true;
201 unsigned InvalidOffset = Engine.getCustomDiagID(
202 DiagnosticsEngine::Error,
203 "SourceLocation in file %0 at offset %1 is invalid");
204 Engine.Report(SourceLocation(), InvalidOffset)
205 << SourceMgr.getFileEntryForID(MainFileID)->getName() << SymbolOffset;
209 const SourceLocation Point = SourceMgr.getLocForStartOfFile(MainFileID)
210 .getLocWithOffset(SymbolOffset);
211 const NamedDecl *FoundDecl = QualifiedName.empty()
212 ? getNamedDeclAt(Context, Point)
213 : getNamedDeclFor(Context, QualifiedName);
215 if (FoundDecl == nullptr) {
216 if (QualifiedName.empty()) {
217 FullSourceLoc FullLoc(Point, SourceMgr);
218 unsigned CouldNotFindSymbolAt = Engine.getCustomDiagID(
219 DiagnosticsEngine::Error,
220 "clang-rename could not find symbol (offset %0)");
221 Engine.Report(Point, CouldNotFindSymbolAt) << SymbolOffset;
222 ErrorOccurred = true;
227 SpellingNames.push_back(std::string());
228 USRList.push_back(std::vector<std::string>());
232 unsigned CouldNotFindSymbolNamed = Engine.getCustomDiagID(
233 DiagnosticsEngine::Error, "clang-rename could not find symbol %0");
234 Engine.Report(CouldNotFindSymbolNamed) << QualifiedName;
235 ErrorOccurred = true;
239 FoundDecl = getCanonicalSymbolDeclaration(FoundDecl);
240 SpellingNames.push_back(FoundDecl->getNameAsString());
241 AdditionalUSRFinder Finder(FoundDecl, Context);
242 USRList.push_back(Finder.Find());
246 void HandleTranslationUnit(ASTContext &Context) override {
247 const SourceManager &SourceMgr = Context.getSourceManager();
248 for (unsigned Offset : SymbolOffsets) {
249 if (!FindSymbol(Context, SourceMgr, Offset, ""))
252 for (const std::string &QualifiedName : QualifiedNames) {
253 if (!FindSymbol(Context, SourceMgr, 0, QualifiedName))
258 ArrayRef<unsigned> SymbolOffsets;
259 ArrayRef<std::string> QualifiedNames;
260 std::vector<std::string> &SpellingNames;
261 std::vector<std::vector<std::string>> &USRList;
266 std::unique_ptr<ASTConsumer> USRFindingAction::newASTConsumer() {
267 return llvm::make_unique<NamedDeclFindingConsumer>(
268 SymbolOffsets, QualifiedNames, SpellingNames, USRList, Force,
272 } // end namespace tooling
273 } // end namespace clang