Home | History | Annotate | Line # | Download | only in Tooling
      1      1.1  joerg //===--- RefactoringCallbacks.cpp - Structural query framework ------------===//
      2      1.1  joerg //
      3      1.1  joerg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4      1.1  joerg // See https://llvm.org/LICENSE.txt for license information.
      5      1.1  joerg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6      1.1  joerg //
      7      1.1  joerg //===----------------------------------------------------------------------===//
      8      1.1  joerg //
      9      1.1  joerg //
     10      1.1  joerg //===----------------------------------------------------------------------===//
     11      1.1  joerg #include "clang/Tooling/RefactoringCallbacks.h"
     12      1.1  joerg #include "clang/ASTMatchers/ASTMatchFinder.h"
     13      1.1  joerg #include "clang/Basic/SourceLocation.h"
     14      1.1  joerg #include "clang/Lex/Lexer.h"
     15      1.1  joerg 
     16      1.1  joerg using llvm::StringError;
     17      1.1  joerg using llvm::make_error;
     18      1.1  joerg 
     19      1.1  joerg namespace clang {
     20      1.1  joerg namespace tooling {
     21      1.1  joerg 
     22      1.1  joerg RefactoringCallback::RefactoringCallback() {}
     23      1.1  joerg tooling::Replacements &RefactoringCallback::getReplacements() {
     24      1.1  joerg   return Replace;
     25      1.1  joerg }
     26      1.1  joerg 
     27      1.1  joerg ASTMatchRefactorer::ASTMatchRefactorer(
     28      1.1  joerg     std::map<std::string, Replacements> &FileToReplaces)
     29      1.1  joerg     : FileToReplaces(FileToReplaces) {}
     30      1.1  joerg 
     31      1.1  joerg void ASTMatchRefactorer::addDynamicMatcher(
     32      1.1  joerg     const ast_matchers::internal::DynTypedMatcher &Matcher,
     33      1.1  joerg     RefactoringCallback *Callback) {
     34      1.1  joerg   MatchFinder.addDynamicMatcher(Matcher, Callback);
     35      1.1  joerg   Callbacks.push_back(Callback);
     36      1.1  joerg }
     37      1.1  joerg 
     38      1.1  joerg class RefactoringASTConsumer : public ASTConsumer {
     39      1.1  joerg public:
     40      1.1  joerg   explicit RefactoringASTConsumer(ASTMatchRefactorer &Refactoring)
     41      1.1  joerg       : Refactoring(Refactoring) {}
     42      1.1  joerg 
     43      1.1  joerg   void HandleTranslationUnit(ASTContext &Context) override {
     44      1.1  joerg     // The ASTMatchRefactorer is re-used between translation units.
     45      1.1  joerg     // Clear the matchers so that each Replacement is only emitted once.
     46      1.1  joerg     for (const auto &Callback : Refactoring.Callbacks) {
     47      1.1  joerg       Callback->getReplacements().clear();
     48      1.1  joerg     }
     49      1.1  joerg     Refactoring.MatchFinder.matchAST(Context);
     50      1.1  joerg     for (const auto &Callback : Refactoring.Callbacks) {
     51      1.1  joerg       for (const auto &Replacement : Callback->getReplacements()) {
     52      1.1  joerg         llvm::Error Err =
     53  1.1.1.2  joerg             Refactoring.FileToReplaces[std::string(Replacement.getFilePath())]
     54  1.1.1.2  joerg                 .add(Replacement);
     55      1.1  joerg         if (Err) {
     56      1.1  joerg           llvm::errs() << "Skipping replacement " << Replacement.toString()
     57      1.1  joerg                        << " due to this error:\n"
     58      1.1  joerg                        << toString(std::move(Err)) << "\n";
     59      1.1  joerg         }
     60      1.1  joerg       }
     61      1.1  joerg     }
     62      1.1  joerg   }
     63      1.1  joerg 
     64      1.1  joerg private:
     65      1.1  joerg   ASTMatchRefactorer &Refactoring;
     66      1.1  joerg };
     67      1.1  joerg 
     68      1.1  joerg std::unique_ptr<ASTConsumer> ASTMatchRefactorer::newASTConsumer() {
     69      1.1  joerg   return std::make_unique<RefactoringASTConsumer>(*this);
     70      1.1  joerg }
     71      1.1  joerg 
     72      1.1  joerg static Replacement replaceStmtWithText(SourceManager &Sources, const Stmt &From,
     73      1.1  joerg                                        StringRef Text) {
     74      1.1  joerg   return tooling::Replacement(
     75      1.1  joerg       Sources, CharSourceRange::getTokenRange(From.getSourceRange()), Text);
     76      1.1  joerg }
     77      1.1  joerg static Replacement replaceStmtWithStmt(SourceManager &Sources, const Stmt &From,
     78      1.1  joerg                                        const Stmt &To) {
     79      1.1  joerg   return replaceStmtWithText(
     80      1.1  joerg       Sources, From,
     81      1.1  joerg       Lexer::getSourceText(CharSourceRange::getTokenRange(To.getSourceRange()),
     82      1.1  joerg                            Sources, LangOptions()));
     83      1.1  joerg }
     84      1.1  joerg 
     85      1.1  joerg ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText)
     86  1.1.1.2  joerg     : FromId(std::string(FromId)), ToText(std::string(ToText)) {}
     87      1.1  joerg 
     88      1.1  joerg void ReplaceStmtWithText::run(
     89      1.1  joerg     const ast_matchers::MatchFinder::MatchResult &Result) {
     90      1.1  joerg   if (const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId)) {
     91      1.1  joerg     auto Err = Replace.add(tooling::Replacement(
     92      1.1  joerg         *Result.SourceManager,
     93      1.1  joerg         CharSourceRange::getTokenRange(FromMatch->getSourceRange()), ToText));
     94      1.1  joerg     // FIXME: better error handling. For now, just print error message in the
     95      1.1  joerg     // release version.
     96      1.1  joerg     if (Err) {
     97      1.1  joerg       llvm::errs() << llvm::toString(std::move(Err)) << "\n";
     98      1.1  joerg       assert(false);
     99      1.1  joerg     }
    100      1.1  joerg   }
    101      1.1  joerg }
    102      1.1  joerg 
    103      1.1  joerg ReplaceStmtWithStmt::ReplaceStmtWithStmt(StringRef FromId, StringRef ToId)
    104  1.1.1.2  joerg     : FromId(std::string(FromId)), ToId(std::string(ToId)) {}
    105      1.1  joerg 
    106      1.1  joerg void ReplaceStmtWithStmt::run(
    107      1.1  joerg     const ast_matchers::MatchFinder::MatchResult &Result) {
    108      1.1  joerg   const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId);
    109      1.1  joerg   const Stmt *ToMatch = Result.Nodes.getNodeAs<Stmt>(ToId);
    110      1.1  joerg   if (FromMatch && ToMatch) {
    111      1.1  joerg     auto Err = Replace.add(
    112      1.1  joerg         replaceStmtWithStmt(*Result.SourceManager, *FromMatch, *ToMatch));
    113      1.1  joerg     // FIXME: better error handling. For now, just print error message in the
    114      1.1  joerg     // release version.
    115      1.1  joerg     if (Err) {
    116      1.1  joerg       llvm::errs() << llvm::toString(std::move(Err)) << "\n";
    117      1.1  joerg       assert(false);
    118      1.1  joerg     }
    119      1.1  joerg   }
    120      1.1  joerg }
    121      1.1  joerg 
    122      1.1  joerg ReplaceIfStmtWithItsBody::ReplaceIfStmtWithItsBody(StringRef Id,
    123      1.1  joerg                                                    bool PickTrueBranch)
    124  1.1.1.2  joerg     : Id(std::string(Id)), PickTrueBranch(PickTrueBranch) {}
    125      1.1  joerg 
    126      1.1  joerg void ReplaceIfStmtWithItsBody::run(
    127      1.1  joerg     const ast_matchers::MatchFinder::MatchResult &Result) {
    128      1.1  joerg   if (const IfStmt *Node = Result.Nodes.getNodeAs<IfStmt>(Id)) {
    129      1.1  joerg     const Stmt *Body = PickTrueBranch ? Node->getThen() : Node->getElse();
    130      1.1  joerg     if (Body) {
    131      1.1  joerg       auto Err =
    132      1.1  joerg           Replace.add(replaceStmtWithStmt(*Result.SourceManager, *Node, *Body));
    133      1.1  joerg       // FIXME: better error handling. For now, just print error message in the
    134      1.1  joerg       // release version.
    135      1.1  joerg       if (Err) {
    136      1.1  joerg         llvm::errs() << llvm::toString(std::move(Err)) << "\n";
    137      1.1  joerg         assert(false);
    138      1.1  joerg       }
    139      1.1  joerg     } else if (!PickTrueBranch) {
    140      1.1  joerg       // If we want to use the 'else'-branch, but it doesn't exist, delete
    141      1.1  joerg       // the whole 'if'.
    142      1.1  joerg       auto Err =
    143      1.1  joerg           Replace.add(replaceStmtWithText(*Result.SourceManager, *Node, ""));
    144      1.1  joerg       // FIXME: better error handling. For now, just print error message in the
    145      1.1  joerg       // release version.
    146      1.1  joerg       if (Err) {
    147      1.1  joerg         llvm::errs() << llvm::toString(std::move(Err)) << "\n";
    148      1.1  joerg         assert(false);
    149      1.1  joerg       }
    150      1.1  joerg     }
    151      1.1  joerg   }
    152      1.1  joerg }
    153      1.1  joerg 
    154      1.1  joerg ReplaceNodeWithTemplate::ReplaceNodeWithTemplate(
    155      1.1  joerg     llvm::StringRef FromId, std::vector<TemplateElement> Template)
    156  1.1.1.2  joerg     : FromId(std::string(FromId)), Template(std::move(Template)) {}
    157      1.1  joerg 
    158      1.1  joerg llvm::Expected<std::unique_ptr<ReplaceNodeWithTemplate>>
    159      1.1  joerg ReplaceNodeWithTemplate::create(StringRef FromId, StringRef ToTemplate) {
    160      1.1  joerg   std::vector<TemplateElement> ParsedTemplate;
    161      1.1  joerg   for (size_t Index = 0; Index < ToTemplate.size();) {
    162      1.1  joerg     if (ToTemplate[Index] == '$') {
    163      1.1  joerg       if (ToTemplate.substr(Index, 2) == "$$") {
    164      1.1  joerg         Index += 2;
    165      1.1  joerg         ParsedTemplate.push_back(
    166      1.1  joerg             TemplateElement{TemplateElement::Literal, "$"});
    167      1.1  joerg       } else if (ToTemplate.substr(Index, 2) == "${") {
    168      1.1  joerg         size_t EndOfIdentifier = ToTemplate.find("}", Index);
    169      1.1  joerg         if (EndOfIdentifier == std::string::npos) {
    170      1.1  joerg           return make_error<StringError>(
    171      1.1  joerg               "Unterminated ${...} in replacement template near " +
    172      1.1  joerg                   ToTemplate.substr(Index),
    173      1.1  joerg               llvm::inconvertibleErrorCode());
    174      1.1  joerg         }
    175  1.1.1.2  joerg         std::string SourceNodeName = std::string(
    176  1.1.1.2  joerg             ToTemplate.substr(Index + 2, EndOfIdentifier - Index - 2));
    177      1.1  joerg         ParsedTemplate.push_back(
    178      1.1  joerg             TemplateElement{TemplateElement::Identifier, SourceNodeName});
    179      1.1  joerg         Index = EndOfIdentifier + 1;
    180      1.1  joerg       } else {
    181      1.1  joerg         return make_error<StringError>(
    182      1.1  joerg             "Invalid $ in replacement template near " +
    183      1.1  joerg                 ToTemplate.substr(Index),
    184      1.1  joerg             llvm::inconvertibleErrorCode());
    185      1.1  joerg       }
    186      1.1  joerg     } else {
    187      1.1  joerg       size_t NextIndex = ToTemplate.find('$', Index + 1);
    188  1.1.1.2  joerg       ParsedTemplate.push_back(TemplateElement{
    189  1.1.1.2  joerg           TemplateElement::Literal,
    190  1.1.1.2  joerg           std::string(ToTemplate.substr(Index, NextIndex - Index))});
    191      1.1  joerg       Index = NextIndex;
    192      1.1  joerg     }
    193      1.1  joerg   }
    194      1.1  joerg   return std::unique_ptr<ReplaceNodeWithTemplate>(
    195      1.1  joerg       new ReplaceNodeWithTemplate(FromId, std::move(ParsedTemplate)));
    196      1.1  joerg }
    197      1.1  joerg 
    198      1.1  joerg void ReplaceNodeWithTemplate::run(
    199      1.1  joerg     const ast_matchers::MatchFinder::MatchResult &Result) {
    200      1.1  joerg   const auto &NodeMap = Result.Nodes.getMap();
    201      1.1  joerg 
    202      1.1  joerg   std::string ToText;
    203      1.1  joerg   for (const auto &Element : Template) {
    204      1.1  joerg     switch (Element.Type) {
    205      1.1  joerg     case TemplateElement::Literal:
    206      1.1  joerg       ToText += Element.Value;
    207      1.1  joerg       break;
    208      1.1  joerg     case TemplateElement::Identifier: {
    209      1.1  joerg       auto NodeIter = NodeMap.find(Element.Value);
    210      1.1  joerg       if (NodeIter == NodeMap.end()) {
    211      1.1  joerg         llvm::errs() << "Node " << Element.Value
    212      1.1  joerg                      << " used in replacement template not bound in Matcher \n";
    213      1.1  joerg         llvm::report_fatal_error("Unbound node in replacement template.");
    214      1.1  joerg       }
    215      1.1  joerg       CharSourceRange Source =
    216      1.1  joerg           CharSourceRange::getTokenRange(NodeIter->second.getSourceRange());
    217      1.1  joerg       ToText += Lexer::getSourceText(Source, *Result.SourceManager,
    218      1.1  joerg                                      Result.Context->getLangOpts());
    219      1.1  joerg       break;
    220      1.1  joerg     }
    221      1.1  joerg     }
    222      1.1  joerg   }
    223      1.1  joerg   if (NodeMap.count(FromId) == 0) {
    224      1.1  joerg     llvm::errs() << "Node to be replaced " << FromId
    225      1.1  joerg                  << " not bound in query.\n";
    226      1.1  joerg     llvm::report_fatal_error("FromId node not bound in MatchResult");
    227      1.1  joerg   }
    228      1.1  joerg   auto Replacement =
    229      1.1  joerg       tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId), ToText,
    230      1.1  joerg                            Result.Context->getLangOpts());
    231      1.1  joerg   llvm::Error Err = Replace.add(Replacement);
    232      1.1  joerg   if (Err) {
    233      1.1  joerg     llvm::errs() << "Query and replace failed in " << Replacement.getFilePath()
    234      1.1  joerg                  << "! " << llvm::toString(std::move(Err)) << "\n";
    235      1.1  joerg     llvm::report_fatal_error("Replacement failed");
    236      1.1  joerg   }
    237      1.1  joerg }
    238      1.1  joerg 
    239      1.1  joerg } // end namespace tooling
    240      1.1  joerg } // end namespace clang
    241