1 //===--- USRFindingAction.cpp - Clang refactoring library -----------------===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
11 /// \brief Provides an action to find USR for the symbol at <offset>, as well as
12 /// all additional USRs.
14 //===----------------------------------------------------------------------===//
16 #include "clang/Tooling/Refactoring/Rename/USRFindingAction.h"
17 #include "clang/AST/AST.h"
18 #include "clang/AST/ASTConsumer.h"
19 #include "clang/AST/ASTContext.h"
20 #include "clang/AST/Decl.h"
21 #include "clang/AST/RecursiveASTVisitor.h"
22 #include "clang/Basic/FileManager.h"
23 #include "clang/Frontend/CompilerInstance.h"
24 #include "clang/Frontend/FrontendAction.h"
25 #include "clang/Lex/Lexer.h"
26 #include "clang/Lex/Preprocessor.h"
27 #include "clang/Tooling/CommonOptionsParser.h"
28 #include "clang/Tooling/Refactoring.h"
29 #include "clang/Tooling/Refactoring/Rename/USRFinder.h"
30 #include "clang/Tooling/Tooling.h"
42 const NamedDecl *getCanonicalSymbolDeclaration(const NamedDecl *FoundDecl) {
43 // If FoundDecl is a constructor or destructor, we want to instead take
44 // the Decl of the corresponding class.
45 if (const auto *CtorDecl = dyn_cast<CXXConstructorDecl>(FoundDecl))
46 FoundDecl = CtorDecl->getParent();
47 else if (const auto *DtorDecl = dyn_cast<CXXDestructorDecl>(FoundDecl))
48 FoundDecl = DtorDecl->getParent();
49 // FIXME: (Alex L): Canonicalize implicit template instantions, just like
50 // the indexer does it.
52 // Note: please update the declaration's doc comment every time the
53 // canonicalization rules are changed.
58 // \brief NamedDeclFindingConsumer should delegate finding USRs of given Decl to
59 // AdditionalUSRFinder. AdditionalUSRFinder adds USRs of ctor and dtor if given
60 // Decl refers to class and adds USRs of all overridden methods if Decl refers
62 class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> {
64 AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context)
65 : FoundDecl(FoundDecl), Context(Context) {}
67 std::vector<std::string> Find() {
68 // Fill OverriddenMethods and PartialSpecs storages.
69 TraverseDecl(Context.getTranslationUnitDecl());
70 if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) {
71 addUSRsOfOverridenFunctions(MethodDecl);
72 for (const auto &OverriddenMethod : OverriddenMethods) {
73 if (checkIfOverriddenFunctionAscends(OverriddenMethod))
74 USRSet.insert(getUSRForDecl(OverriddenMethod));
76 addUSRsOfInstantiatedMethods(MethodDecl);
77 } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) {
78 handleCXXRecordDecl(RecordDecl);
79 } else if (const auto *TemplateDecl =
80 dyn_cast<ClassTemplateDecl>(FoundDecl)) {
81 handleClassTemplateDecl(TemplateDecl);
83 USRSet.insert(getUSRForDecl(FoundDecl));
85 return std::vector<std::string>(USRSet.begin(), USRSet.end());
88 bool shouldVisitTemplateInstantiations() const { return true; }
90 bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) {
91 if (MethodDecl->isVirtual())
92 OverriddenMethods.push_back(MethodDecl);
93 if (MethodDecl->getInstantiatedFromMemberFunction())
94 InstantiatedMethods.push_back(MethodDecl);
98 bool VisitClassTemplatePartialSpecializationDecl(
99 const ClassTemplatePartialSpecializationDecl *PartialSpec) {
100 PartialSpecs.push_back(PartialSpec);
105 void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) {
106 RecordDecl = RecordDecl->getDefinition();
107 if (const auto *ClassTemplateSpecDecl =
108 dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl))
109 handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate());
110 addUSRsOfCtorDtors(RecordDecl);
113 void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) {
114 for (const auto *Specialization : TemplateDecl->specializations())
115 addUSRsOfCtorDtors(Specialization);
117 for (const auto *PartialSpec : PartialSpecs) {
118 if (PartialSpec->getSpecializedTemplate() == TemplateDecl)
119 addUSRsOfCtorDtors(PartialSpec);
121 addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl());
124 void addUSRsOfCtorDtors(const CXXRecordDecl *RecordDecl) {
125 RecordDecl = RecordDecl->getDefinition();
127 // Skip if the CXXRecordDecl doesn't have definition.
131 for (const auto *CtorDecl : RecordDecl->ctors())
132 USRSet.insert(getUSRForDecl(CtorDecl));
134 USRSet.insert(getUSRForDecl(RecordDecl->getDestructor()));
135 USRSet.insert(getUSRForDecl(RecordDecl));
138 void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl) {
139 USRSet.insert(getUSRForDecl(MethodDecl));
140 // Recursively visit each OverridenMethod.
141 for (const auto &OverriddenMethod : MethodDecl->overridden_methods())
142 addUSRsOfOverridenFunctions(OverriddenMethod);
145 void addUSRsOfInstantiatedMethods(const CXXMethodDecl *MethodDecl) {
146 // For renaming a class template method, all references of the instantiated
147 // member methods should be renamed too, so add USRs of the instantiated
148 // methods to the USR set.
149 USRSet.insert(getUSRForDecl(MethodDecl));
150 if (const auto *FT = MethodDecl->getInstantiatedFromMemberFunction())
151 USRSet.insert(getUSRForDecl(FT));
152 for (const auto *Method : InstantiatedMethods) {
153 if (USRSet.find(getUSRForDecl(
154 Method->getInstantiatedFromMemberFunction())) != USRSet.end())
155 USRSet.insert(getUSRForDecl(Method));
159 bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) {
160 for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) {
161 if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end())
163 return checkIfOverriddenFunctionAscends(OverriddenMethod);
168 const Decl *FoundDecl;
170 std::set<std::string> USRSet;
171 std::vector<const CXXMethodDecl *> OverriddenMethods;
172 std::vector<const CXXMethodDecl *> InstantiatedMethods;
173 std::vector<const ClassTemplatePartialSpecializationDecl *> PartialSpecs;
177 std::vector<std::string> getUSRsForDeclaration(const NamedDecl *ND,
178 ASTContext &Context) {
179 AdditionalUSRFinder Finder(ND, Context);
180 return Finder.Find();
183 class NamedDeclFindingConsumer : public ASTConsumer {
185 NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets,
186 ArrayRef<std::string> QualifiedNames,
187 std::vector<std::string> &SpellingNames,
188 std::vector<std::vector<std::string>> &USRList,
189 bool Force, bool &ErrorOccurred)
190 : SymbolOffsets(SymbolOffsets), QualifiedNames(QualifiedNames),
191 SpellingNames(SpellingNames), USRList(USRList), Force(Force),
192 ErrorOccurred(ErrorOccurred) {}
195 bool FindSymbol(ASTContext &Context, const SourceManager &SourceMgr,
196 unsigned SymbolOffset, const std::string &QualifiedName) {
197 DiagnosticsEngine &Engine = Context.getDiagnostics();
198 const FileID MainFileID = SourceMgr.getMainFileID();
200 if (SymbolOffset >= SourceMgr.getFileIDSize(MainFileID)) {
201 ErrorOccurred = true;
202 unsigned InvalidOffset = Engine.getCustomDiagID(
203 DiagnosticsEngine::Error,
204 "SourceLocation in file %0 at offset %1 is invalid");
205 Engine.Report(SourceLocation(), InvalidOffset)
206 << SourceMgr.getFileEntryForID(MainFileID)->getName() << SymbolOffset;
210 const SourceLocation Point = SourceMgr.getLocForStartOfFile(MainFileID)
211 .getLocWithOffset(SymbolOffset);
212 const NamedDecl *FoundDecl = QualifiedName.empty()
213 ? getNamedDeclAt(Context, Point)
214 : getNamedDeclFor(Context, QualifiedName);
216 if (FoundDecl == nullptr) {
217 if (QualifiedName.empty()) {
218 FullSourceLoc FullLoc(Point, SourceMgr);
219 unsigned CouldNotFindSymbolAt = Engine.getCustomDiagID(
220 DiagnosticsEngine::Error,
221 "clang-rename could not find symbol (offset %0)");
222 Engine.Report(Point, CouldNotFindSymbolAt) << SymbolOffset;
223 ErrorOccurred = true;
228 SpellingNames.push_back(std::string());
229 USRList.push_back(std::vector<std::string>());
233 unsigned CouldNotFindSymbolNamed = Engine.getCustomDiagID(
234 DiagnosticsEngine::Error, "clang-rename could not find symbol %0");
235 Engine.Report(CouldNotFindSymbolNamed) << QualifiedName;
236 ErrorOccurred = true;
240 FoundDecl = getCanonicalSymbolDeclaration(FoundDecl);
241 SpellingNames.push_back(FoundDecl->getNameAsString());
242 AdditionalUSRFinder Finder(FoundDecl, Context);
243 USRList.push_back(Finder.Find());
247 void HandleTranslationUnit(ASTContext &Context) override {
248 const SourceManager &SourceMgr = Context.getSourceManager();
249 for (unsigned Offset : SymbolOffsets) {
250 if (!FindSymbol(Context, SourceMgr, Offset, ""))
253 for (const std::string &QualifiedName : QualifiedNames) {
254 if (!FindSymbol(Context, SourceMgr, 0, QualifiedName))
259 ArrayRef<unsigned> SymbolOffsets;
260 ArrayRef<std::string> QualifiedNames;
261 std::vector<std::string> &SpellingNames;
262 std::vector<std::vector<std::string>> &USRList;
267 std::unique_ptr<ASTConsumer> USRFindingAction::newASTConsumer() {
268 return llvm::make_unique<NamedDeclFindingConsumer>(
269 SymbolOffsets, QualifiedNames, SpellingNames, USRList, Force,
273 } // end namespace tooling
274 } // end namespace clang