Home | History | Annotate | Line # | Download | only in Refactoring
      1 //===--- RecursiveSymbolVisitor.h - Clang refactoring library -------------===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 ///
      9 /// \file
     10 /// A wrapper class around \c RecursiveASTVisitor that visits each
     11 /// occurrence of a named symbol.
     12 ///
     13 //===----------------------------------------------------------------------===//
     14 
     15 #ifndef LLVM_CLANG_TOOLING_REFACTOR_RECURSIVE_SYMBOL_VISITOR_H
     16 #define LLVM_CLANG_TOOLING_REFACTOR_RECURSIVE_SYMBOL_VISITOR_H
     17 
     18 #include "clang/AST/AST.h"
     19 #include "clang/AST/RecursiveASTVisitor.h"
     20 #include "clang/Lex/Lexer.h"
     21 
     22 namespace clang {
     23 namespace tooling {
     24 
     25 /// Traverses the AST and visits the occurrence of each named symbol in the
     26 /// given nodes.
     27 template <typename T>
     28 class RecursiveSymbolVisitor
     29     : public RecursiveASTVisitor<RecursiveSymbolVisitor<T>> {
     30   using BaseType = RecursiveASTVisitor<RecursiveSymbolVisitor<T>>;
     31 
     32 public:
     33   RecursiveSymbolVisitor(const SourceManager &SM, const LangOptions &LangOpts)
     34       : SM(SM), LangOpts(LangOpts) {}
     35 
     36   bool visitSymbolOccurrence(const NamedDecl *ND,
     37                              ArrayRef<SourceRange> NameRanges) {
     38     return true;
     39   }
     40 
     41   // Declaration visitors:
     42 
     43   bool VisitNamedDecl(const NamedDecl *D) {
     44     return isa<CXXConversionDecl>(D) ? true : visit(D, D->getLocation());
     45   }
     46 
     47   bool VisitCXXConstructorDecl(const CXXConstructorDecl *CD) {
     48     for (const auto *Initializer : CD->inits()) {
     49       // Ignore implicit initializers.
     50       if (!Initializer->isWritten())
     51         continue;
     52       if (const FieldDecl *FD = Initializer->getMember()) {
     53         if (!visit(FD, Initializer->getSourceLocation(),
     54                    Lexer::getLocForEndOfToken(Initializer->getSourceLocation(),
     55                                               0, SM, LangOpts)))
     56           return false;
     57       }
     58     }
     59     return true;
     60   }
     61 
     62   // Expression visitors:
     63 
     64   bool VisitDeclRefExpr(const DeclRefExpr *Expr) {
     65     return visit(Expr->getFoundDecl(), Expr->getLocation());
     66   }
     67 
     68   bool VisitMemberExpr(const MemberExpr *Expr) {
     69     return visit(Expr->getFoundDecl().getDecl(), Expr->getMemberLoc());
     70   }
     71 
     72   bool VisitOffsetOfExpr(const OffsetOfExpr *S) {
     73     for (unsigned I = 0, E = S->getNumComponents(); I != E; ++I) {
     74       const OffsetOfNode &Component = S->getComponent(I);
     75       if (Component.getKind() == OffsetOfNode::Field) {
     76         if (!visit(Component.getField(), Component.getEndLoc()))
     77           return false;
     78       }
     79       // FIXME: Try to resolve dependent field references.
     80     }
     81     return true;
     82   }
     83 
     84   // Other visitors:
     85 
     86   bool VisitTypeLoc(const TypeLoc Loc) {
     87     const SourceLocation TypeBeginLoc = Loc.getBeginLoc();
     88     const SourceLocation TypeEndLoc =
     89         Lexer::getLocForEndOfToken(TypeBeginLoc, 0, SM, LangOpts);
     90     if (const auto *TemplateTypeParm =
     91             dyn_cast<TemplateTypeParmType>(Loc.getType())) {
     92       if (!visit(TemplateTypeParm->getDecl(), TypeBeginLoc, TypeEndLoc))
     93         return false;
     94     }
     95     if (const auto *TemplateSpecType =
     96             dyn_cast<TemplateSpecializationType>(Loc.getType())) {
     97       if (!visit(TemplateSpecType->getTemplateName().getAsTemplateDecl(),
     98                  TypeBeginLoc, TypeEndLoc))
     99         return false;
    100     }
    101     if (const Type *TP = Loc.getTypePtr()) {
    102       if (TP->getTypeClass() == clang::Type::Record)
    103         return visit(TP->getAsCXXRecordDecl(), TypeBeginLoc, TypeEndLoc);
    104     }
    105     return true;
    106   }
    107 
    108   bool VisitTypedefTypeLoc(TypedefTypeLoc TL) {
    109     const SourceLocation TypeEndLoc =
    110         Lexer::getLocForEndOfToken(TL.getBeginLoc(), 0, SM, LangOpts);
    111     return visit(TL.getTypedefNameDecl(), TL.getBeginLoc(), TypeEndLoc);
    112   }
    113 
    114   bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS) {
    115     // The base visitor will visit NNSL prefixes, so we should only look at
    116     // the current NNS.
    117     if (NNS) {
    118       const NamespaceDecl *ND = NNS.getNestedNameSpecifier()->getAsNamespace();
    119       if (!visit(ND, NNS.getLocalBeginLoc(), NNS.getLocalEndLoc()))
    120         return false;
    121     }
    122     return BaseType::TraverseNestedNameSpecifierLoc(NNS);
    123   }
    124 
    125   bool VisitDesignatedInitExpr(const DesignatedInitExpr *E) {
    126     for (const DesignatedInitExpr::Designator &D : E->designators()) {
    127       if (D.isFieldDesignator() && D.getField()) {
    128         const FieldDecl *Decl = D.getField();
    129         if (!visit(Decl, D.getFieldLoc(), D.getFieldLoc()))
    130           return false;
    131       }
    132     }
    133     return true;
    134   }
    135 
    136 private:
    137   const SourceManager &SM;
    138   const LangOptions &LangOpts;
    139 
    140   bool visit(const NamedDecl *ND, SourceLocation BeginLoc,
    141              SourceLocation EndLoc) {
    142     return static_cast<T *>(this)->visitSymbolOccurrence(
    143         ND, SourceRange(BeginLoc, EndLoc));
    144   }
    145   bool visit(const NamedDecl *ND, SourceLocation Loc) {
    146     return visit(ND, Loc, Lexer::getLocForEndOfToken(Loc, 0, SM, LangOpts));
    147   }
    148 };
    149 
    150 } // end namespace tooling
    151 } // end namespace clang
    152 
    153 #endif // LLVM_CLANG_TOOLING_REFACTOR_RECURSIVE_SYMBOL_VISITOR_H
    154