1 //===--- USRFindingAction.cpp - Clang refactoring library -----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
/// Provides an action to find the USR of the symbol at <offset>, as well as
/// all additional USRs.
13 //===----------------------------------------------------------------------===//
15 #include "clang/Tooling/Refactoring/Rename/USRFindingAction.h"
16 #include "clang/AST/AST.h"
17 #include "clang/AST/ASTConsumer.h"
18 #include "clang/AST/ASTContext.h"
19 #include "clang/AST/Decl.h"
20 #include "clang/AST/RecursiveASTVisitor.h"
21 #include "clang/Basic/FileManager.h"
22 #include "clang/Frontend/CompilerInstance.h"
23 #include "clang/Frontend/FrontendAction.h"
24 #include "clang/Lex/Lexer.h"
25 #include "clang/Lex/Preprocessor.h"
26 #include "clang/Tooling/CommonOptionsParser.h"
27 #include "clang/Tooling/Refactoring.h"
28 #include "clang/Tooling/Refactoring/Rename/USRFinder.h"
29 #include "clang/Tooling/Tooling.h"
41 const NamedDecl *getCanonicalSymbolDeclaration(const NamedDecl *FoundDecl) {
42 // If FoundDecl is a constructor or destructor, we want to instead take
43 // the Decl of the corresponding class.
44 if (const auto *CtorDecl = dyn_cast<CXXConstructorDecl>(FoundDecl))
45 FoundDecl = CtorDecl->getParent();
46 else if (const auto *DtorDecl = dyn_cast<CXXDestructorDecl>(FoundDecl))
47 FoundDecl = DtorDecl->getParent();
48 // FIXME: (Alex L): Canonicalize implicit template instantions, just like
49 // the indexer does it.
51 // Note: please update the declaration's doc comment every time the
52 // canonicalization rules are changed.
57 // NamedDeclFindingConsumer should delegate finding USRs of given Decl to
58 // AdditionalUSRFinder. AdditionalUSRFinder adds USRs of ctor and dtor if given
59 // Decl refers to class and adds USRs of all overridden methods if Decl refers
61 class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> {
63 AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context)
64 : FoundDecl(FoundDecl), Context(Context) {}
66 std::vector<std::string> Find() {
67 // Fill OverriddenMethods and PartialSpecs storages.
68 TraverseDecl(Context.getTranslationUnitDecl());
69 if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) {
70 addUSRsOfOverridenFunctions(MethodDecl);
71 for (const auto &OverriddenMethod : OverriddenMethods) {
72 if (checkIfOverriddenFunctionAscends(OverriddenMethod))
73 USRSet.insert(getUSRForDecl(OverriddenMethod));
75 addUSRsOfInstantiatedMethods(MethodDecl);
76 } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) {
77 handleCXXRecordDecl(RecordDecl);
78 } else if (const auto *TemplateDecl =
79 dyn_cast<ClassTemplateDecl>(FoundDecl)) {
80 handleClassTemplateDecl(TemplateDecl);
82 USRSet.insert(getUSRForDecl(FoundDecl));
84 return std::vector<std::string>(USRSet.begin(), USRSet.end());
87 bool shouldVisitTemplateInstantiations() const { return true; }
89 bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) {
90 if (MethodDecl->isVirtual())
91 OverriddenMethods.push_back(MethodDecl);
92 if (MethodDecl->getInstantiatedFromMemberFunction())
93 InstantiatedMethods.push_back(MethodDecl);
97 bool VisitClassTemplatePartialSpecializationDecl(
98 const ClassTemplatePartialSpecializationDecl *PartialSpec) {
99 PartialSpecs.push_back(PartialSpec);
104 void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) {
105 if (!RecordDecl->getDefinition()) {
106 USRSet.insert(getUSRForDecl(RecordDecl));
109 RecordDecl = RecordDecl->getDefinition();
110 if (const auto *ClassTemplateSpecDecl =
111 dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl))
112 handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate());
113 addUSRsOfCtorDtors(RecordDecl);
116 void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) {
117 for (const auto *Specialization : TemplateDecl->specializations())
118 addUSRsOfCtorDtors(Specialization);
120 for (const auto *PartialSpec : PartialSpecs) {
121 if (PartialSpec->getSpecializedTemplate() == TemplateDecl)
122 addUSRsOfCtorDtors(PartialSpec);
124 addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl());
127 void addUSRsOfCtorDtors(const CXXRecordDecl *RecordDecl) {
128 RecordDecl = RecordDecl->getDefinition();
130 // Skip if the CXXRecordDecl doesn't have definition.
134 for (const auto *CtorDecl : RecordDecl->ctors())
135 USRSet.insert(getUSRForDecl(CtorDecl));
137 USRSet.insert(getUSRForDecl(RecordDecl->getDestructor()));
138 USRSet.insert(getUSRForDecl(RecordDecl));
141 void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl) {
142 USRSet.insert(getUSRForDecl(MethodDecl));
143 // Recursively visit each OverridenMethod.
144 for (const auto &OverriddenMethod : MethodDecl->overridden_methods())
145 addUSRsOfOverridenFunctions(OverriddenMethod);
148 void addUSRsOfInstantiatedMethods(const CXXMethodDecl *MethodDecl) {
149 // For renaming a class template method, all references of the instantiated
150 // member methods should be renamed too, so add USRs of the instantiated
151 // methods to the USR set.
152 USRSet.insert(getUSRForDecl(MethodDecl));
153 if (const auto *FT = MethodDecl->getInstantiatedFromMemberFunction())
154 USRSet.insert(getUSRForDecl(FT));
155 for (const auto *Method : InstantiatedMethods) {
156 if (USRSet.find(getUSRForDecl(
157 Method->getInstantiatedFromMemberFunction())) != USRSet.end())
158 USRSet.insert(getUSRForDecl(Method));
162 bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) {
163 for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) {
164 if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end())
166 return checkIfOverriddenFunctionAscends(OverriddenMethod);
171 const Decl *FoundDecl;
173 std::set<std::string> USRSet;
174 std::vector<const CXXMethodDecl *> OverriddenMethods;
175 std::vector<const CXXMethodDecl *> InstantiatedMethods;
176 std::vector<const ClassTemplatePartialSpecializationDecl *> PartialSpecs;
180 std::vector<std::string> getUSRsForDeclaration(const NamedDecl *ND,
181 ASTContext &Context) {
182 AdditionalUSRFinder Finder(ND, Context);
183 return Finder.Find();
186 class NamedDeclFindingConsumer : public ASTConsumer {
188 NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets,
189 ArrayRef<std::string> QualifiedNames,
190 std::vector<std::string> &SpellingNames,
191 std::vector<std::vector<std::string>> &USRList,
192 bool Force, bool &ErrorOccurred)
193 : SymbolOffsets(SymbolOffsets), QualifiedNames(QualifiedNames),
194 SpellingNames(SpellingNames), USRList(USRList), Force(Force),
195 ErrorOccurred(ErrorOccurred) {}
198 bool FindSymbol(ASTContext &Context, const SourceManager &SourceMgr,
199 unsigned SymbolOffset, const std::string &QualifiedName) {
200 DiagnosticsEngine &Engine = Context.getDiagnostics();
201 const FileID MainFileID = SourceMgr.getMainFileID();
203 if (SymbolOffset >= SourceMgr.getFileIDSize(MainFileID)) {
204 ErrorOccurred = true;
205 unsigned InvalidOffset = Engine.getCustomDiagID(
206 DiagnosticsEngine::Error,
207 "SourceLocation in file %0 at offset %1 is invalid");
208 Engine.Report(SourceLocation(), InvalidOffset)
209 << SourceMgr.getFileEntryForID(MainFileID)->getName() << SymbolOffset;
213 const SourceLocation Point = SourceMgr.getLocForStartOfFile(MainFileID)
214 .getLocWithOffset(SymbolOffset);
215 const NamedDecl *FoundDecl = QualifiedName.empty()
216 ? getNamedDeclAt(Context, Point)
217 : getNamedDeclFor(Context, QualifiedName);
219 if (FoundDecl == nullptr) {
220 if (QualifiedName.empty()) {
221 FullSourceLoc FullLoc(Point, SourceMgr);
222 unsigned CouldNotFindSymbolAt = Engine.getCustomDiagID(
223 DiagnosticsEngine::Error,
224 "clang-rename could not find symbol (offset %0)");
225 Engine.Report(Point, CouldNotFindSymbolAt) << SymbolOffset;
226 ErrorOccurred = true;
231 SpellingNames.push_back(std::string());
232 USRList.push_back(std::vector<std::string>());
236 unsigned CouldNotFindSymbolNamed = Engine.getCustomDiagID(
237 DiagnosticsEngine::Error, "clang-rename could not find symbol %0");
238 Engine.Report(CouldNotFindSymbolNamed) << QualifiedName;
239 ErrorOccurred = true;
243 FoundDecl = getCanonicalSymbolDeclaration(FoundDecl);
244 SpellingNames.push_back(FoundDecl->getNameAsString());
245 AdditionalUSRFinder Finder(FoundDecl, Context);
246 USRList.push_back(Finder.Find());
250 void HandleTranslationUnit(ASTContext &Context) override {
251 const SourceManager &SourceMgr = Context.getSourceManager();
252 for (unsigned Offset : SymbolOffsets) {
253 if (!FindSymbol(Context, SourceMgr, Offset, ""))
256 for (const std::string &QualifiedName : QualifiedNames) {
257 if (!FindSymbol(Context, SourceMgr, 0, QualifiedName))
262 ArrayRef<unsigned> SymbolOffsets;
263 ArrayRef<std::string> QualifiedNames;
264 std::vector<std::string> &SpellingNames;
265 std::vector<std::vector<std::string>> &USRList;
270 std::unique_ptr<ASTConsumer> USRFindingAction::newASTConsumer() {
271 return std::make_unique<NamedDeclFindingConsumer>(
272 SymbolOffsets, QualifiedNames, SpellingNames, USRList, Force,
276 } // end namespace tooling
277 } // end namespace clang