1 1.1 joerg //===--- RefactoringCallbacks.cpp - Structural query framework ------------===// 2 1.1 joerg // 3 1.1 joerg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 1.1 joerg // See https://llvm.org/LICENSE.txt for license information. 5 1.1 joerg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 1.1 joerg // 7 1.1 joerg //===----------------------------------------------------------------------===// 8 1.1 joerg // 9 1.1 joerg // 10 1.1 joerg //===----------------------------------------------------------------------===// 11 1.1 joerg #include "clang/Tooling/RefactoringCallbacks.h" 12 1.1 joerg #include "clang/ASTMatchers/ASTMatchFinder.h" 13 1.1 joerg #include "clang/Basic/SourceLocation.h" 14 1.1 joerg #include "clang/Lex/Lexer.h" 15 1.1 joerg 16 1.1 joerg using llvm::StringError; 17 1.1 joerg using llvm::make_error; 18 1.1 joerg 19 1.1 joerg namespace clang { 20 1.1 joerg namespace tooling { 21 1.1 joerg 22 1.1 joerg RefactoringCallback::RefactoringCallback() {} 23 1.1 joerg tooling::Replacements &RefactoringCallback::getReplacements() { 24 1.1 joerg return Replace; 25 1.1 joerg } 26 1.1 joerg 27 1.1 joerg ASTMatchRefactorer::ASTMatchRefactorer( 28 1.1 joerg std::map<std::string, Replacements> &FileToReplaces) 29 1.1 joerg : FileToReplaces(FileToReplaces) {} 30 1.1 joerg 31 1.1 joerg void ASTMatchRefactorer::addDynamicMatcher( 32 1.1 joerg const ast_matchers::internal::DynTypedMatcher &Matcher, 33 1.1 joerg RefactoringCallback *Callback) { 34 1.1 joerg MatchFinder.addDynamicMatcher(Matcher, Callback); 35 1.1 joerg Callbacks.push_back(Callback); 36 1.1 joerg } 37 1.1 joerg 38 1.1 joerg class RefactoringASTConsumer : public ASTConsumer { 39 1.1 joerg public: 40 1.1 joerg explicit RefactoringASTConsumer(ASTMatchRefactorer &Refactoring) 41 1.1 joerg : Refactoring(Refactoring) {} 42 1.1 joerg 43 1.1 joerg void HandleTranslationUnit(ASTContext &Context) override { 44 1.1 joerg // The ASTMatchRefactorer is re-used between translation units. 45 1.1 joerg // Clear the matchers so that each Replacement is only emitted once. 46 1.1 joerg for (const auto &Callback : Refactoring.Callbacks) { 47 1.1 joerg Callback->getReplacements().clear(); 48 1.1 joerg } 49 1.1 joerg Refactoring.MatchFinder.matchAST(Context); 50 1.1 joerg for (const auto &Callback : Refactoring.Callbacks) { 51 1.1 joerg for (const auto &Replacement : Callback->getReplacements()) { 52 1.1 joerg llvm::Error Err = 53 1.1.1.2 joerg Refactoring.FileToReplaces[std::string(Replacement.getFilePath())] 54 1.1.1.2 joerg .add(Replacement); 55 1.1 joerg if (Err) { 56 1.1 joerg llvm::errs() << "Skipping replacement " << Replacement.toString() 57 1.1 joerg << " due to this error:\n" 58 1.1 joerg << toString(std::move(Err)) << "\n"; 59 1.1 joerg } 60 1.1 joerg } 61 1.1 joerg } 62 1.1 joerg } 63 1.1 joerg 64 1.1 joerg private: 65 1.1 joerg ASTMatchRefactorer &Refactoring; 66 1.1 joerg }; 67 1.1 joerg 68 1.1 joerg std::unique_ptr<ASTConsumer> ASTMatchRefactorer::newASTConsumer() { 69 1.1 joerg return std::make_unique<RefactoringASTConsumer>(*this); 70 1.1 joerg } 71 1.1 joerg 72 1.1 joerg static Replacement replaceStmtWithText(SourceManager &Sources, const Stmt &From, 73 1.1 joerg StringRef Text) { 74 1.1 joerg return tooling::Replacement( 75 1.1 joerg Sources, CharSourceRange::getTokenRange(From.getSourceRange()), Text); 76 1.1 joerg } 77 1.1 joerg static Replacement replaceStmtWithStmt(SourceManager &Sources, const Stmt &From, 78 1.1 joerg const Stmt &To) { 79 1.1 joerg return replaceStmtWithText( 80 1.1 joerg Sources, From, 81 1.1 joerg Lexer::getSourceText(CharSourceRange::getTokenRange(To.getSourceRange()), 82 1.1 joerg Sources, LangOptions())); 83 1.1 joerg } 84 1.1 joerg 85 1.1 joerg ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText) 86 1.1.1.2 joerg : FromId(std::string(FromId)), ToText(std::string(ToText)) {} 87 1.1 joerg 88 1.1 joerg void ReplaceStmtWithText::run( 89 1.1 joerg const ast_matchers::MatchFinder::MatchResult &Result) { 90 1.1 joerg if (const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId)) { 91 1.1 joerg auto Err = Replace.add(tooling::Replacement( 92 1.1 joerg *Result.SourceManager, 93 1.1 joerg CharSourceRange::getTokenRange(FromMatch->getSourceRange()), ToText)); 94 1.1 joerg // FIXME: better error handling. For now, just print error message in the 95 1.1 joerg // release version. 96 1.1 joerg if (Err) { 97 1.1 joerg llvm::errs() << llvm::toString(std::move(Err)) << "\n"; 98 1.1 joerg assert(false); 99 1.1 joerg } 100 1.1 joerg } 101 1.1 joerg } 102 1.1 joerg 103 1.1 joerg ReplaceStmtWithStmt::ReplaceStmtWithStmt(StringRef FromId, StringRef ToId) 104 1.1.1.2 joerg : FromId(std::string(FromId)), ToId(std::string(ToId)) {} 105 1.1 joerg 106 1.1 joerg void ReplaceStmtWithStmt::run( 107 1.1 joerg const ast_matchers::MatchFinder::MatchResult &Result) { 108 1.1 joerg const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId); 109 1.1 joerg const Stmt *ToMatch = Result.Nodes.getNodeAs<Stmt>(ToId); 110 1.1 joerg if (FromMatch && ToMatch) { 111 1.1 joerg auto Err = Replace.add( 112 1.1 joerg replaceStmtWithStmt(*Result.SourceManager, *FromMatch, *ToMatch)); 113 1.1 joerg // FIXME: better error handling. For now, just print error message in the 114 1.1 joerg // release version. 115 1.1 joerg if (Err) { 116 1.1 joerg llvm::errs() << llvm::toString(std::move(Err)) << "\n"; 117 1.1 joerg assert(false); 118 1.1 joerg } 119 1.1 joerg } 120 1.1 joerg } 121 1.1 joerg 122 1.1 joerg ReplaceIfStmtWithItsBody::ReplaceIfStmtWithItsBody(StringRef Id, 123 1.1 joerg bool PickTrueBranch) 124 1.1.1.2 joerg : Id(std::string(Id)), PickTrueBranch(PickTrueBranch) {} 125 1.1 joerg 126 1.1 joerg void ReplaceIfStmtWithItsBody::run( 127 1.1 joerg const ast_matchers::MatchFinder::MatchResult &Result) { 128 1.1 joerg if (const IfStmt *Node = Result.Nodes.getNodeAs<IfStmt>(Id)) { 129 1.1 joerg const Stmt *Body = PickTrueBranch ? Node->getThen() : Node->getElse(); 130 1.1 joerg if (Body) { 131 1.1 joerg auto Err = 132 1.1 joerg Replace.add(replaceStmtWithStmt(*Result.SourceManager, *Node, *Body)); 133 1.1 joerg // FIXME: better error handling. For now, just print error message in the 134 1.1 joerg // release version. 135 1.1 joerg if (Err) { 136 1.1 joerg llvm::errs() << llvm::toString(std::move(Err)) << "\n"; 137 1.1 joerg assert(false); 138 1.1 joerg } 139 1.1 joerg } else if (!PickTrueBranch) { 140 1.1 joerg // If we want to use the 'else'-branch, but it doesn't exist, delete 141 1.1 joerg // the whole 'if'. 142 1.1 joerg auto Err = 143 1.1 joerg Replace.add(replaceStmtWithText(*Result.SourceManager, *Node, "")); 144 1.1 joerg // FIXME: better error handling. For now, just print error message in the 145 1.1 joerg // release version. 146 1.1 joerg if (Err) { 147 1.1 joerg llvm::errs() << llvm::toString(std::move(Err)) << "\n"; 148 1.1 joerg assert(false); 149 1.1 joerg } 150 1.1 joerg } 151 1.1 joerg } 152 1.1 joerg } 153 1.1 joerg 154 1.1 joerg ReplaceNodeWithTemplate::ReplaceNodeWithTemplate( 155 1.1 joerg llvm::StringRef FromId, std::vector<TemplateElement> Template) 156 1.1.1.2 joerg : FromId(std::string(FromId)), Template(std::move(Template)) {} 157 1.1 joerg 158 1.1 joerg llvm::Expected<std::unique_ptr<ReplaceNodeWithTemplate>> 159 1.1 joerg ReplaceNodeWithTemplate::create(StringRef FromId, StringRef ToTemplate) { 160 1.1 joerg std::vector<TemplateElement> ParsedTemplate; 161 1.1 joerg for (size_t Index = 0; Index < ToTemplate.size();) { 162 1.1 joerg if (ToTemplate[Index] == '$') { 163 1.1 joerg if (ToTemplate.substr(Index, 2) == "$$") { 164 1.1 joerg Index += 2; 165 1.1 joerg ParsedTemplate.push_back( 166 1.1 joerg TemplateElement{TemplateElement::Literal, "$"}); 167 1.1 joerg } else if (ToTemplate.substr(Index, 2) == "${") { 168 1.1 joerg size_t EndOfIdentifier = ToTemplate.find("}", Index); 169 1.1 joerg if (EndOfIdentifier == std::string::npos) { 170 1.1 joerg return make_error<StringError>( 171 1.1 joerg "Unterminated ${...} in replacement template near " + 172 1.1 joerg ToTemplate.substr(Index), 173 1.1 joerg llvm::inconvertibleErrorCode()); 174 1.1 joerg } 175 1.1.1.2 joerg std::string SourceNodeName = std::string( 176 1.1.1.2 joerg ToTemplate.substr(Index + 2, EndOfIdentifier - Index - 2)); 177 1.1 joerg ParsedTemplate.push_back( 178 1.1 joerg TemplateElement{TemplateElement::Identifier, SourceNodeName}); 179 1.1 joerg Index = EndOfIdentifier + 1; 180 1.1 joerg } else { 181 1.1 joerg return make_error<StringError>( 182 1.1 joerg "Invalid $ in replacement template near " + 183 1.1 joerg ToTemplate.substr(Index), 184 1.1 joerg llvm::inconvertibleErrorCode()); 185 1.1 joerg } 186 1.1 joerg } else { 187 1.1 joerg size_t NextIndex = ToTemplate.find('$', Index + 1); 188 1.1.1.2 joerg ParsedTemplate.push_back(TemplateElement{ 189 1.1.1.2 joerg TemplateElement::Literal, 190 1.1.1.2 joerg std::string(ToTemplate.substr(Index, NextIndex - Index))}); 191 1.1 joerg Index = NextIndex; 192 1.1 joerg } 193 1.1 joerg } 194 1.1 joerg return std::unique_ptr<ReplaceNodeWithTemplate>( 195 1.1 joerg new ReplaceNodeWithTemplate(FromId, std::move(ParsedTemplate))); 196 1.1 joerg } 197 1.1 joerg 198 1.1 joerg void ReplaceNodeWithTemplate::run( 199 1.1 joerg const ast_matchers::MatchFinder::MatchResult &Result) { 200 1.1 joerg const auto &NodeMap = Result.Nodes.getMap(); 201 1.1 joerg 202 1.1 joerg std::string ToText; 203 1.1 joerg for (const auto &Element : Template) { 204 1.1 joerg switch (Element.Type) { 205 1.1 joerg case TemplateElement::Literal: 206 1.1 joerg ToText += Element.Value; 207 1.1 joerg break; 208 1.1 joerg case TemplateElement::Identifier: { 209 1.1 joerg auto NodeIter = NodeMap.find(Element.Value); 210 1.1 joerg if (NodeIter == NodeMap.end()) { 211 1.1 joerg llvm::errs() << "Node " << Element.Value 212 1.1 joerg << " used in replacement template not bound in Matcher \n"; 213 1.1 joerg llvm::report_fatal_error("Unbound node in replacement template."); 214 1.1 joerg } 215 1.1 joerg CharSourceRange Source = 216 1.1 joerg CharSourceRange::getTokenRange(NodeIter->second.getSourceRange()); 217 1.1 joerg ToText += Lexer::getSourceText(Source, *Result.SourceManager, 218 1.1 joerg Result.Context->getLangOpts()); 219 1.1 joerg break; 220 1.1 joerg } 221 1.1 joerg } 222 1.1 joerg } 223 1.1 joerg if (NodeMap.count(FromId) == 0) { 224 1.1 joerg llvm::errs() << "Node to be replaced " << FromId 225 1.1 joerg << " not bound in query.\n"; 226 1.1 joerg llvm::report_fatal_error("FromId node not bound in MatchResult"); 227 1.1 joerg } 228 1.1 joerg auto Replacement = 229 1.1 joerg tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId), ToText, 230 1.1 joerg Result.Context->getLangOpts()); 231 1.1 joerg llvm::Error Err = Replace.add(Replacement); 232 1.1 joerg if (Err) { 233 1.1 joerg llvm::errs() << "Query and replace failed in " << Replacement.getFilePath() 234 1.1 joerg << "! " << llvm::toString(std::move(Err)) << "\n"; 235 1.1 joerg llvm::report_fatal_error("Replacement failed"); 236 1.1 joerg } 237 1.1 joerg } 238 1.1 joerg 239 1.1 joerg } // end namespace tooling 240 1.1 joerg } // end namespace clang 241