Home | History | Annotate | Line # | Download | only in ASTMatchers
      1 //===--- ASTMatchFinder.cpp - Structural query framework ------------------===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 //  Implements an algorithm to efficiently search for matches on AST nodes.
     10 //  Uses memoization to support recursive matches like HasDescendant.
     11 //
     12 //  The general idea is to visit all AST nodes with a RecursiveASTVisitor,
     13 //  calling the Matches(...) method of each matcher we are running on each
     14 //  AST node. The matcher can recurse via the ASTMatchFinder interface.
     15 //
     16 //===----------------------------------------------------------------------===//
     17 
     18 #include "clang/ASTMatchers/ASTMatchFinder.h"
     19 #include "clang/AST/ASTConsumer.h"
     20 #include "clang/AST/ASTContext.h"
     21 #include "clang/AST/RecursiveASTVisitor.h"
     22 #include "llvm/ADT/DenseMap.h"
     23 #include "llvm/ADT/StringMap.h"
     24 #include "llvm/Support/Timer.h"
     25 #include <deque>
     26 #include <memory>
     27 #include <set>
     28 
     29 namespace clang {
     30 namespace ast_matchers {
     31 namespace internal {
     32 namespace {
     33 
     34 typedef MatchFinder::MatchCallback MatchCallback;
     35 
// Upper bound on the number of memoization entries kept before the cache is
// flushed. 10k was found experimentally to balance speed against memory use
// when running matchers that fire on every statement of a very large
// codebase.
//
// FIXME: Do some performance optimization in general and revisit this
// number; also, put up micro-benchmarks that we can optimize this on.
static constexpr unsigned MaxMemoizationEntries = 10000;
     45 
// Direction of a memoized recursive match. It is part of MatchKey, so results
// for the same matcher and node are cached separately per direction.
enum class MatchType { Ancestors, Descendants, Child };
     52 
     53 // We use memoization to avoid running the same matcher on the same
     54 // AST node twice.  This struct is the key for looking up match
     55 // result.  It consists of an ID of the MatcherInterface (for
     56 // identifying the matcher), a pointer to the AST node and the
     57 // bound nodes before the matcher was executed.
     58 //
     59 // We currently only memoize on nodes whose pointers identify the
     60 // nodes (\c Stmt and \c Decl, but not \c QualType or \c TypeLoc).
     61 // For \c QualType and \c TypeLoc it is possible to implement
     62 // generation of keys for each type.
     63 // FIXME: Benchmark whether memoization of non-pointer typed nodes
     64 // provides enough benefit for the additional amount of code.
     65 struct MatchKey {
     66   DynTypedMatcher::MatcherIDType MatcherID;
     67   DynTypedNode Node;
     68   BoundNodesTreeBuilder BoundNodes;
     69   TraversalKind Traversal = TK_AsIs;
     70   MatchType Type;
     71 
     72   bool operator<(const MatchKey &Other) const {
     73     return std::tie(Traversal, Type, MatcherID, Node, BoundNodes) <
     74            std::tie(Other.Traversal, Other.Type, Other.MatcherID, Other.Node,
     75                     Other.BoundNodes);
     76   }
     77 };
     78 
     79 // Used to store the result of a match and possibly bound nodes.
     80 struct MemoizedMatchResult {
     81   bool ResultOfMatch;
     82   BoundNodesTreeBuilder Nodes;
     83 };
     84 
     85 // A RecursiveASTVisitor that traverses all children or all descendants of
     86 // a node.
     87 class MatchChildASTVisitor
     88     : public RecursiveASTVisitor<MatchChildASTVisitor> {
     89 public:
     90   typedef RecursiveASTVisitor<MatchChildASTVisitor> VisitorBase;
     91 
     92   // Creates an AST visitor that matches 'matcher' on all children or
     93   // descendants of a traversed node. max_depth is the maximum depth
     94   // to traverse: use 1 for matching the children and INT_MAX for
     95   // matching the descendants.
     96   MatchChildASTVisitor(const DynTypedMatcher *Matcher, ASTMatchFinder *Finder,
     97                        BoundNodesTreeBuilder *Builder, int MaxDepth,
     98                        bool IgnoreImplicitChildren,
     99                        ASTMatchFinder::BindKind Bind)
    100       : Matcher(Matcher), Finder(Finder), Builder(Builder), CurrentDepth(0),
    101         MaxDepth(MaxDepth), IgnoreImplicitChildren(IgnoreImplicitChildren),
    102         Bind(Bind), Matches(false) {}
    103 
    104   // Returns true if a match is found in the subtree rooted at the
    105   // given AST node. This is done via a set of mutually recursive
    106   // functions. Here's how the recursion is done (the  *wildcard can
    107   // actually be Decl, Stmt, or Type):
    108   //
    109   //   - Traverse(node) calls BaseTraverse(node) when it needs
    110   //     to visit the descendants of node.
    111   //   - BaseTraverse(node) then calls (via VisitorBase::Traverse*(node))
    112   //     Traverse*(c) for each child c of 'node'.
    113   //   - Traverse*(c) in turn calls Traverse(c), completing the
    114   //     recursion.
    115   bool findMatch(const DynTypedNode &DynNode) {
    116     reset();
    117     if (const Decl *D = DynNode.get<Decl>())
    118       traverse(*D);
    119     else if (const Stmt *S = DynNode.get<Stmt>())
    120       traverse(*S);
    121     else if (const NestedNameSpecifier *NNS =
    122              DynNode.get<NestedNameSpecifier>())
    123       traverse(*NNS);
    124     else if (const NestedNameSpecifierLoc *NNSLoc =
    125              DynNode.get<NestedNameSpecifierLoc>())
    126       traverse(*NNSLoc);
    127     else if (const QualType *Q = DynNode.get<QualType>())
    128       traverse(*Q);
    129     else if (const TypeLoc *T = DynNode.get<TypeLoc>())
    130       traverse(*T);
    131     else if (const auto *C = DynNode.get<CXXCtorInitializer>())
    132       traverse(*C);
    133     else if (const TemplateArgumentLoc *TALoc =
    134                  DynNode.get<TemplateArgumentLoc>())
    135       traverse(*TALoc);
    136     // FIXME: Add other base types after adding tests.
    137 
    138     // It's OK to always overwrite the bound nodes, as if there was
    139     // no match in this recursive branch, the result set is empty
    140     // anyway.
    141     *Builder = ResultBindings;
    142 
    143     return Matches;
    144   }
    145 
    146   // The following are overriding methods from the base visitor class.
    147   // They are public only to allow CRTP to work. They are *not *part
    148   // of the public API of this class.
    149   bool TraverseDecl(Decl *DeclNode) {
    150 
    151     if (DeclNode && DeclNode->isImplicit() &&
    152         Finder->isTraversalIgnoringImplicitNodes())
    153       return baseTraverse(*DeclNode);
    154 
    155     ScopedIncrement ScopedDepth(&CurrentDepth);
    156     return (DeclNode == nullptr) || traverse(*DeclNode);
    157   }
    158 
    159   Stmt *getStmtToTraverse(Stmt *StmtNode) {
    160     Stmt *StmtToTraverse = StmtNode;
    161     if (auto *ExprNode = dyn_cast_or_null<Expr>(StmtNode)) {
    162       auto *LambdaNode = dyn_cast_or_null<LambdaExpr>(StmtNode);
    163       if (LambdaNode && Finder->isTraversalIgnoringImplicitNodes())
    164         StmtToTraverse = LambdaNode;
    165       else
    166         StmtToTraverse =
    167             Finder->getASTContext().getParentMapContext().traverseIgnored(
    168                 ExprNode);
    169     }
    170     return StmtToTraverse;
    171   }
    172 
    173   bool TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue = nullptr) {
    174     // If we need to keep track of the depth, we can't perform data recursion.
    175     if (CurrentDepth == 0 || (CurrentDepth <= MaxDepth && MaxDepth < INT_MAX))
    176       Queue = nullptr;
    177 
    178     ScopedIncrement ScopedDepth(&CurrentDepth);
    179     Stmt *StmtToTraverse = getStmtToTraverse(StmtNode);
    180     if (!StmtToTraverse)
    181       return true;
    182 
    183     if (IgnoreImplicitChildren && isa<CXXDefaultArgExpr>(StmtNode))
    184       return true;
    185 
    186     if (!match(*StmtToTraverse))
    187       return false;
    188     return VisitorBase::TraverseStmt(StmtToTraverse, Queue);
    189   }
    190   // We assume that the QualType and the contained type are on the same
    191   // hierarchy level. Thus, we try to match either of them.
    192   bool TraverseType(QualType TypeNode) {
    193     if (TypeNode.isNull())
    194       return true;
    195     ScopedIncrement ScopedDepth(&CurrentDepth);
    196     // Match the Type.
    197     if (!match(*TypeNode))
    198       return false;
    199     // The QualType is matched inside traverse.
    200     return traverse(TypeNode);
    201   }
    202   // We assume that the TypeLoc, contained QualType and contained Type all are
    203   // on the same hierarchy level. Thus, we try to match all of them.
    204   bool TraverseTypeLoc(TypeLoc TypeLocNode) {
    205     if (TypeLocNode.isNull())
    206       return true;
    207     ScopedIncrement ScopedDepth(&CurrentDepth);
    208     // Match the Type.
    209     if (!match(*TypeLocNode.getType()))
    210       return false;
    211     // Match the QualType.
    212     if (!match(TypeLocNode.getType()))
    213       return false;
    214     // The TypeLoc is matched inside traverse.
    215     return traverse(TypeLocNode);
    216   }
    217   bool TraverseNestedNameSpecifier(NestedNameSpecifier *NNS) {
    218     ScopedIncrement ScopedDepth(&CurrentDepth);
    219     return (NNS == nullptr) || traverse(*NNS);
    220   }
    221   bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS) {
    222     if (!NNS)
    223       return true;
    224     ScopedIncrement ScopedDepth(&CurrentDepth);
    225     if (!match(*NNS.getNestedNameSpecifier()))
    226       return false;
    227     return traverse(NNS);
    228   }
    229   bool TraverseConstructorInitializer(CXXCtorInitializer *CtorInit) {
    230     if (!CtorInit)
    231       return true;
    232     ScopedIncrement ScopedDepth(&CurrentDepth);
    233     return traverse(*CtorInit);
    234   }
    235   bool TraverseTemplateArgumentLoc(TemplateArgumentLoc TAL) {
    236     ScopedIncrement ScopedDepth(&CurrentDepth);
    237     return traverse(TAL);
    238   }
    239   bool TraverseCXXForRangeStmt(CXXForRangeStmt *Node) {
    240     if (!Finder->isTraversalIgnoringImplicitNodes())
    241       return VisitorBase::TraverseCXXForRangeStmt(Node);
    242     if (!Node)
    243       return true;
    244     ScopedIncrement ScopedDepth(&CurrentDepth);
    245     if (auto *Init = Node->getInit())
    246       if (!traverse(*Init))
    247         return false;
    248     if (!match(*Node->getLoopVariable()))
    249       return false;
    250     if (match(*Node->getRangeInit()))
    251       if (!VisitorBase::TraverseStmt(Node->getRangeInit()))
    252         return false;
    253     if (!match(*Node->getBody()))
    254       return false;
    255     return VisitorBase::TraverseStmt(Node->getBody());
    256   }
    257   bool TraverseCXXRewrittenBinaryOperator(CXXRewrittenBinaryOperator *Node) {
    258     if (!Finder->isTraversalIgnoringImplicitNodes())
    259       return VisitorBase::TraverseCXXRewrittenBinaryOperator(Node);
    260     if (!Node)
    261       return true;
    262     ScopedIncrement ScopedDepth(&CurrentDepth);
    263 
    264     return match(*Node->getLHS()) && match(*Node->getRHS());
    265   }
    266   bool TraverseLambdaExpr(LambdaExpr *Node) {
    267     if (!Finder->isTraversalIgnoringImplicitNodes())
    268       return VisitorBase::TraverseLambdaExpr(Node);
    269     if (!Node)
    270       return true;
    271     ScopedIncrement ScopedDepth(&CurrentDepth);
    272 
    273     for (unsigned I = 0, N = Node->capture_size(); I != N; ++I) {
    274       const auto *C = Node->capture_begin() + I;
    275       if (!C->isExplicit())
    276         continue;
    277       if (Node->isInitCapture(C) && !match(*C->getCapturedVar()))
    278         return false;
    279       if (!match(*Node->capture_init_begin()[I]))
    280         return false;
    281     }
    282 
    283     if (const auto *TPL = Node->getTemplateParameterList()) {
    284       for (const auto *TP : *TPL) {
    285         if (!match(*TP))
    286           return false;
    287       }
    288     }
    289 
    290     for (const auto *P : Node->getCallOperator()->parameters()) {
    291       if (!match(*P))
    292         return false;
    293     }
    294 
    295     if (!match(*Node->getBody()))
    296       return false;
    297 
    298     return VisitorBase::TraverseStmt(Node->getBody());
    299   }
    300 
    301   bool shouldVisitTemplateInstantiations() const { return true; }
    302   bool shouldVisitImplicitCode() const { return !IgnoreImplicitChildren; }
    303 
    304 private:
    305   // Used for updating the depth during traversal.
    306   struct ScopedIncrement {
    307     explicit ScopedIncrement(int *Depth) : Depth(Depth) { ++(*Depth); }
    308     ~ScopedIncrement() { --(*Depth); }
    309 
    310    private:
    311     int *Depth;
    312   };
    313 
    314   // Resets the state of this object.
    315   void reset() {
    316     Matches = false;
    317     CurrentDepth = 0;
    318   }
    319 
    320   // Forwards the call to the corresponding Traverse*() method in the
    321   // base visitor class.
    322   bool baseTraverse(const Decl &DeclNode) {
    323     return VisitorBase::TraverseDecl(const_cast<Decl*>(&DeclNode));
    324   }
    325   bool baseTraverse(const Stmt &StmtNode) {
    326     return VisitorBase::TraverseStmt(const_cast<Stmt*>(&StmtNode));
    327   }
    328   bool baseTraverse(QualType TypeNode) {
    329     return VisitorBase::TraverseType(TypeNode);
    330   }
    331   bool baseTraverse(TypeLoc TypeLocNode) {
    332     return VisitorBase::TraverseTypeLoc(TypeLocNode);
    333   }
    334   bool baseTraverse(const NestedNameSpecifier &NNS) {
    335     return VisitorBase::TraverseNestedNameSpecifier(
    336         const_cast<NestedNameSpecifier*>(&NNS));
    337   }
    338   bool baseTraverse(NestedNameSpecifierLoc NNS) {
    339     return VisitorBase::TraverseNestedNameSpecifierLoc(NNS);
    340   }
    341   bool baseTraverse(const CXXCtorInitializer &CtorInit) {
    342     return VisitorBase::TraverseConstructorInitializer(
    343         const_cast<CXXCtorInitializer *>(&CtorInit));
    344   }
    345   bool baseTraverse(TemplateArgumentLoc TAL) {
    346     return VisitorBase::TraverseTemplateArgumentLoc(TAL);
    347   }
    348 
    349   // Sets 'Matched' to true if 'Matcher' matches 'Node' and:
    350   //   0 < CurrentDepth <= MaxDepth.
    351   //
    352   // Returns 'true' if traversal should continue after this function
    353   // returns, i.e. if no match is found or 'Bind' is 'BK_All'.
    354   template <typename T>
    355   bool match(const T &Node) {
    356     if (CurrentDepth == 0 || CurrentDepth > MaxDepth) {
    357       return true;
    358     }
    359     if (Bind != ASTMatchFinder::BK_All) {
    360       BoundNodesTreeBuilder RecursiveBuilder(*Builder);
    361       if (Matcher->matches(DynTypedNode::create(Node), Finder,
    362                            &RecursiveBuilder)) {
    363         Matches = true;
    364         ResultBindings.addMatch(RecursiveBuilder);
    365         return false; // Abort as soon as a match is found.
    366       }
    367     } else {
    368       BoundNodesTreeBuilder RecursiveBuilder(*Builder);
    369       if (Matcher->matches(DynTypedNode::create(Node), Finder,
    370                            &RecursiveBuilder)) {
    371         // After the first match the matcher succeeds.
    372         Matches = true;
    373         ResultBindings.addMatch(RecursiveBuilder);
    374       }
    375     }
    376     return true;
    377   }
    378 
    379   // Traverses the subtree rooted at 'Node'; returns true if the
    380   // traversal should continue after this function returns.
    381   template <typename T>
    382   bool traverse(const T &Node) {
    383     static_assert(IsBaseType<T>::value,
    384                   "traverse can only be instantiated with base type");
    385     if (!match(Node))
    386       return false;
    387     return baseTraverse(Node);
    388   }
    389 
    390   const DynTypedMatcher *const Matcher;
    391   ASTMatchFinder *const Finder;
    392   BoundNodesTreeBuilder *const Builder;
    393   BoundNodesTreeBuilder ResultBindings;
    394   int CurrentDepth;
    395   const int MaxDepth;
    396   const bool IgnoreImplicitChildren;
    397   const ASTMatchFinder::BindKind Bind;
    398   bool Matches;
    399 };
    400 
    401 // Controls the outermost traversal of the AST and allows to match multiple
    402 // matchers.
    403 class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
    404                         public ASTMatchFinder {
    405 public:
    406   MatchASTVisitor(const MatchFinder::MatchersByType *Matchers,
    407                   const MatchFinder::MatchFinderOptions &Options)
    408       : Matchers(Matchers), Options(Options), ActiveASTContext(nullptr) {}
    409 
    410   ~MatchASTVisitor() override {
    411     if (Options.CheckProfiling) {
    412       Options.CheckProfiling->Records = std::move(TimeByBucket);
    413     }
    414   }
    415 
    416   void onStartOfTranslationUnit() {
    417     const bool EnableCheckProfiling = Options.CheckProfiling.hasValue();
    418     TimeBucketRegion Timer;
    419     for (MatchCallback *MC : Matchers->AllCallbacks) {
    420       if (EnableCheckProfiling)
    421         Timer.setBucket(&TimeByBucket[MC->getID()]);
    422       MC->onStartOfTranslationUnit();
    423     }
    424   }
    425 
    426   void onEndOfTranslationUnit() {
    427     const bool EnableCheckProfiling = Options.CheckProfiling.hasValue();
    428     TimeBucketRegion Timer;
    429     for (MatchCallback *MC : Matchers->AllCallbacks) {
    430       if (EnableCheckProfiling)
    431         Timer.setBucket(&TimeByBucket[MC->getID()]);
    432       MC->onEndOfTranslationUnit();
    433     }
    434   }
    435 
    436   void set_active_ast_context(ASTContext *NewActiveASTContext) {
    437     ActiveASTContext = NewActiveASTContext;
    438   }
    439 
    440   // The following Visit*() and Traverse*() functions "override"
    441   // methods in RecursiveASTVisitor.
    442 
    443   bool VisitTypedefNameDecl(TypedefNameDecl *DeclNode) {
    444     // When we see 'typedef A B', we add name 'B' to the set of names
    445     // A's canonical type maps to.  This is necessary for implementing
    446     // isDerivedFrom(x) properly, where x can be the name of the base
    447     // class or any of its aliases.
    448     //
    449     // In general, the is-alias-of (as defined by typedefs) relation
    450     // is tree-shaped, as you can typedef a type more than once.  For
    451     // example,
    452     //
    453     //   typedef A B;
    454     //   typedef A C;
    455     //   typedef C D;
    456     //   typedef C E;
    457     //
    458     // gives you
    459     //
    460     //   A
    461     //   |- B
    462     //   `- C
    463     //      |- D
    464     //      `- E
    465     //
    466     // It is wrong to assume that the relation is a chain.  A correct
    467     // implementation of isDerivedFrom() needs to recognize that B and
    468     // E are aliases, even though neither is a typedef of the other.
    469     // Therefore, we cannot simply walk through one typedef chain to
    470     // find out whether the type name matches.
    471     const Type *TypeNode = DeclNode->getUnderlyingType().getTypePtr();
    472     const Type *CanonicalType =  // root of the typedef tree
    473         ActiveASTContext->getCanonicalType(TypeNode);
    474     TypeAliases[CanonicalType].insert(DeclNode);
    475     return true;
    476   }
    477 
    478   bool VisitObjCCompatibleAliasDecl(ObjCCompatibleAliasDecl *CAD) {
    479     const ObjCInterfaceDecl *InterfaceDecl = CAD->getClassInterface();
    480     CompatibleAliases[InterfaceDecl].insert(CAD);
    481     return true;
    482   }
    483 
    484   bool TraverseDecl(Decl *DeclNode);
    485   bool TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue = nullptr);
    486   bool TraverseType(QualType TypeNode);
    487   bool TraverseTypeLoc(TypeLoc TypeNode);
    488   bool TraverseNestedNameSpecifier(NestedNameSpecifier *NNS);
    489   bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS);
    490   bool TraverseConstructorInitializer(CXXCtorInitializer *CtorInit);
    491   bool TraverseTemplateArgumentLoc(TemplateArgumentLoc TAL);
    492 
    493   bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue) {
    494     if (auto *RF = dyn_cast<CXXForRangeStmt>(S)) {
    495       {
    496         ASTNodeNotAsIsSourceScope RAII(this, true);
    497         TraverseStmt(RF->getInit());
    498         // Don't traverse under the loop variable
    499         match(*RF->getLoopVariable());
    500         TraverseStmt(RF->getRangeInit());
    501       }
    502       {
    503         ASTNodeNotSpelledInSourceScope RAII(this, true);
    504         for (auto *SubStmt : RF->children()) {
    505           if (SubStmt != RF->getBody())
    506             TraverseStmt(SubStmt);
    507         }
    508       }
    509       TraverseStmt(RF->getBody());
    510       return true;
    511     } else if (auto *RBO = dyn_cast<CXXRewrittenBinaryOperator>(S)) {
    512       {
    513         ASTNodeNotAsIsSourceScope RAII(this, true);
    514         TraverseStmt(const_cast<Expr *>(RBO->getLHS()));
    515         TraverseStmt(const_cast<Expr *>(RBO->getRHS()));
    516       }
    517       {
    518         ASTNodeNotSpelledInSourceScope RAII(this, true);
    519         for (auto *SubStmt : RBO->children()) {
    520           TraverseStmt(SubStmt);
    521         }
    522       }
    523       return true;
    524     } else if (auto *LE = dyn_cast<LambdaExpr>(S)) {
    525       for (auto I : llvm::zip(LE->captures(), LE->capture_inits())) {
    526         auto C = std::get<0>(I);
    527         ASTNodeNotSpelledInSourceScope RAII(
    528             this, TraversingASTNodeNotSpelledInSource || !C.isExplicit());
    529         TraverseLambdaCapture(LE, &C, std::get<1>(I));
    530       }
    531 
    532       {
    533         ASTNodeNotSpelledInSourceScope RAII(this, true);
    534         TraverseDecl(LE->getLambdaClass());
    535       }
    536       {
    537         ASTNodeNotAsIsSourceScope RAII(this, true);
    538 
    539         // We need to poke around to find the bits that might be explicitly
    540         // written.
    541         TypeLoc TL = LE->getCallOperator()->getTypeSourceInfo()->getTypeLoc();
    542         FunctionProtoTypeLoc Proto = TL.getAsAdjusted<FunctionProtoTypeLoc>();
    543 
    544         if (auto *TPL = LE->getTemplateParameterList()) {
    545           for (NamedDecl *D : *TPL) {
    546             TraverseDecl(D);
    547           }
    548           if (Expr *RequiresClause = TPL->getRequiresClause()) {
    549             TraverseStmt(RequiresClause);
    550           }
    551         }
    552 
    553         if (LE->hasExplicitParameters()) {
    554           // Visit parameters.
    555           for (ParmVarDecl *Param : Proto.getParams())
    556             TraverseDecl(Param);
    557         }
    558 
    559         const auto *T = Proto.getTypePtr();
    560         for (const auto &E : T->exceptions())
    561           TraverseType(E);
    562 
    563         if (Expr *NE = T->getNoexceptExpr())
    564           TraverseStmt(NE, Queue);
    565 
    566         if (LE->hasExplicitResultType())
    567           TraverseTypeLoc(Proto.getReturnLoc());
    568         TraverseStmt(LE->getTrailingRequiresClause());
    569       }
    570 
    571       TraverseStmt(LE->getBody());
    572       return true;
    573     }
    574     return RecursiveASTVisitor<MatchASTVisitor>::dataTraverseNode(S, Queue);
    575   }
    576 
    577   // Matches children or descendants of 'Node' with 'BaseMatcher'.
    578   bool memoizedMatchesRecursively(const DynTypedNode &Node, ASTContext &Ctx,
    579                                   const DynTypedMatcher &Matcher,
    580                                   BoundNodesTreeBuilder *Builder, int MaxDepth,
    581                                   BindKind Bind) {
    582     // For AST-nodes that don't have an identity, we can't memoize.
    583     if (!Node.getMemoizationData() || !Builder->isComparable())
    584       return matchesRecursively(Node, Matcher, Builder, MaxDepth, Bind);
    585 
    586     MatchKey Key;
    587     Key.MatcherID = Matcher.getID();
    588     Key.Node = Node;
    589     // Note that we key on the bindings *before* the match.
    590     Key.BoundNodes = *Builder;
    591     Key.Traversal = Ctx.getParentMapContext().getTraversalKind();
    592     // Memoize result even doing a single-level match, it might be expensive.
    593     Key.Type = MaxDepth == 1 ? MatchType::Child : MatchType::Descendants;
    594     MemoizationMap::iterator I = ResultCache.find(Key);
    595     if (I != ResultCache.end()) {
    596       *Builder = I->second.Nodes;
    597       return I->second.ResultOfMatch;
    598     }
    599 
    600     MemoizedMatchResult Result;
    601     Result.Nodes = *Builder;
    602     Result.ResultOfMatch =
    603         matchesRecursively(Node, Matcher, &Result.Nodes, MaxDepth, Bind);
    604 
    605     MemoizedMatchResult &CachedResult = ResultCache[Key];
    606     CachedResult = std::move(Result);
    607 
    608     *Builder = CachedResult.Nodes;
    609     return CachedResult.ResultOfMatch;
    610   }
    611 
    612   // Matches children or descendants of 'Node' with 'BaseMatcher'.
    613   bool matchesRecursively(const DynTypedNode &Node,
    614                           const DynTypedMatcher &Matcher,
    615                           BoundNodesTreeBuilder *Builder, int MaxDepth,
    616                           BindKind Bind) {
    617     bool ScopedTraversal = TraversingASTNodeNotSpelledInSource ||
    618                            TraversingASTChildrenNotSpelledInSource;
    619 
    620     bool IgnoreImplicitChildren = false;
    621 
    622     if (isTraversalIgnoringImplicitNodes()) {
    623       IgnoreImplicitChildren = true;
    624     }
    625 
    626     ASTNodeNotSpelledInSourceScope RAII(this, ScopedTraversal);
    627 
    628     MatchChildASTVisitor Visitor(&Matcher, this, Builder, MaxDepth,
    629                                  IgnoreImplicitChildren, Bind);
    630     return Visitor.findMatch(Node);
    631   }
    632 
    633   bool classIsDerivedFrom(const CXXRecordDecl *Declaration,
    634                           const Matcher<NamedDecl> &Base,
    635                           BoundNodesTreeBuilder *Builder,
    636                           bool Directly) override;
    637 
    638   bool objcClassIsDerivedFrom(const ObjCInterfaceDecl *Declaration,
    639                               const Matcher<NamedDecl> &Base,
    640                               BoundNodesTreeBuilder *Builder,
    641                               bool Directly) override;
    642 
    643   // Implements ASTMatchFinder::matchesChildOf.
    644   bool matchesChildOf(const DynTypedNode &Node, ASTContext &Ctx,
    645                       const DynTypedMatcher &Matcher,
    646                       BoundNodesTreeBuilder *Builder, BindKind Bind) override {
    647     if (ResultCache.size() > MaxMemoizationEntries)
    648       ResultCache.clear();
    649     return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, 1, Bind);
    650   }
    651   // Implements ASTMatchFinder::matchesDescendantOf.
    652   bool matchesDescendantOf(const DynTypedNode &Node, ASTContext &Ctx,
    653                            const DynTypedMatcher &Matcher,
    654                            BoundNodesTreeBuilder *Builder,
    655                            BindKind Bind) override {
    656     if (ResultCache.size() > MaxMemoizationEntries)
    657       ResultCache.clear();
    658     return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, INT_MAX,
    659                                       Bind);
    660   }
    661   // Implements ASTMatchFinder::matchesAncestorOf.
    662   bool matchesAncestorOf(const DynTypedNode &Node, ASTContext &Ctx,
    663                          const DynTypedMatcher &Matcher,
    664                          BoundNodesTreeBuilder *Builder,
    665                          AncestorMatchMode MatchMode) override {
    666     // Reset the cache outside of the recursive call to make sure we
    667     // don't invalidate any iterators.
    668     if (ResultCache.size() > MaxMemoizationEntries)
    669       ResultCache.clear();
    670     if (MatchMode == AncestorMatchMode::AMM_ParentOnly)
    671       return matchesParentOf(Node, Matcher, Builder);
    672     return matchesAnyAncestorOf(Node, Ctx, Matcher, Builder);
    673   }
    674 
    675   // Matches all registered matchers on the given node and calls the
    676   // result callback for every node that matches.
    677   void match(const DynTypedNode &Node) {
    678     // FIXME: Improve this with a switch or a visitor pattern.
    679     if (auto *N = Node.get<Decl>()) {
    680       match(*N);
    681     } else if (auto *N = Node.get<Stmt>()) {
    682       match(*N);
    683     } else if (auto *N = Node.get<Type>()) {
    684       match(*N);
    685     } else if (auto *N = Node.get<QualType>()) {
    686       match(*N);
    687     } else if (auto *N = Node.get<NestedNameSpecifier>()) {
    688       match(*N);
    689     } else if (auto *N = Node.get<NestedNameSpecifierLoc>()) {
    690       match(*N);
    691     } else if (auto *N = Node.get<TypeLoc>()) {
    692       match(*N);
    693     } else if (auto *N = Node.get<CXXCtorInitializer>()) {
    694       match(*N);
    695     } else if (auto *N = Node.get<TemplateArgumentLoc>()) {
    696       match(*N);
    697     }
    698   }
    699 
    700   template <typename T> void match(const T &Node) {
    701     matchDispatch(&Node);
    702   }
    703 
    704   // Implements ASTMatchFinder::getASTContext.
    705   ASTContext &getASTContext() const override { return *ActiveASTContext; }
    706 
    707   bool shouldVisitTemplateInstantiations() const { return true; }
    708   bool shouldVisitImplicitCode() const { return true; }
    709 
    710   // We visit the lambda body explicitly, so instruct the RAV
    711   // to not visit it on our behalf too.
    712   bool shouldVisitLambdaBody() const { return false; }
    713 
    714   bool IsMatchingInASTNodeNotSpelledInSource() const override {
    715     return TraversingASTNodeNotSpelledInSource;
    716   }
    717   bool isMatchingChildrenNotSpelledInSource() const override {
    718     return TraversingASTChildrenNotSpelledInSource;
    719   }
    720   void setMatchingChildrenNotSpelledInSource(bool Set) override {
    721     TraversingASTChildrenNotSpelledInSource = Set;
    722   }
    723 
    724   bool IsMatchingInASTNodeNotAsIs() const override {
    725     return TraversingASTNodeNotAsIs;
    726   }
    727 
    728   bool TraverseTemplateInstantiations(ClassTemplateDecl *D) {
    729     ASTNodeNotSpelledInSourceScope RAII(this, true);
    730     return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations(
    731         D);
    732   }
    733 
    734   bool TraverseTemplateInstantiations(VarTemplateDecl *D) {
    735     ASTNodeNotSpelledInSourceScope RAII(this, true);
    736     return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations(
    737         D);
    738   }
    739 
    740   bool TraverseTemplateInstantiations(FunctionTemplateDecl *D) {
    741     ASTNodeNotSpelledInSourceScope RAII(this, true);
    742     return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations(
    743         D);
    744   }
    745 
private:
  // True while traversing nodes that are not spelled in the source, e.g.
  // template instantiations or implicit declarations. Read through
  // IsMatchingInASTNodeNotSpelledInSource(); saved and restored by
  // ASTNodeNotSpelledInSourceScope.
  bool TraversingASTNodeNotSpelledInSource = false;
  // Read through IsMatchingInASTNodeNotAsIs(); saved and restored by
  // ASTNodeNotAsIsSourceScope.
  bool TraversingASTNodeNotAsIs = false;
  // True while traversing children that are not spelled in the source (e.g.
  // members of explicit instantiations, see TraverseDecl). Exposed through
  // isMatchingChildrenNotSpelledInSource()/setMatchingChildrenNotSpelledInSource().
  bool TraversingASTChildrenNotSpelledInSource = false;
    750 
    751   struct ASTNodeNotSpelledInSourceScope {
    752     ASTNodeNotSpelledInSourceScope(MatchASTVisitor *V, bool B)
    753         : MV(V), MB(V->TraversingASTNodeNotSpelledInSource) {
    754       V->TraversingASTNodeNotSpelledInSource = B;
    755     }
    756     ~ASTNodeNotSpelledInSourceScope() {
    757       MV->TraversingASTNodeNotSpelledInSource = MB;
    758     }
    759 
    760   private:
    761     MatchASTVisitor *MV;
    762     bool MB;
    763   };
    764 
    765   struct ASTNodeNotAsIsSourceScope {
    766     ASTNodeNotAsIsSourceScope(MatchASTVisitor *V, bool B)
    767         : MV(V), MB(V->TraversingASTNodeNotAsIs) {
    768       V->TraversingASTNodeNotAsIs = B;
    769     }
    770     ~ASTNodeNotAsIsSourceScope() { MV->TraversingASTNodeNotAsIs = MB; }
    771 
    772   private:
    773     MatchASTVisitor *MV;
    774     bool MB;
    775   };
    776 
    777   class TimeBucketRegion {
    778   public:
    779     TimeBucketRegion() : Bucket(nullptr) {}
    780     ~TimeBucketRegion() { setBucket(nullptr); }
    781 
    782     /// Start timing for \p NewBucket.
    783     ///
    784     /// If there was a bucket already set, it will finish the timing for that
    785     /// other bucket.
    786     /// \p NewBucket will be timed until the next call to \c setBucket() or
    787     /// until the \c TimeBucketRegion is destroyed.
    788     /// If \p NewBucket is the same as the currently timed bucket, this call
    789     /// does nothing.
    790     void setBucket(llvm::TimeRecord *NewBucket) {
    791       if (Bucket != NewBucket) {
    792         auto Now = llvm::TimeRecord::getCurrentTime(true);
    793         if (Bucket)
    794           *Bucket += Now;
    795         if (NewBucket)
    796           *NewBucket -= Now;
    797         Bucket = NewBucket;
    798       }
    799     }
    800 
    801   private:
    802     llvm::TimeRecord *Bucket;
    803   };
    804 
    805   /// Runs all the \p Matchers on \p Node.
    806   ///
    807   /// Used by \c matchDispatch() below.
    808   template <typename T, typename MC>
    809   void matchWithoutFilter(const T &Node, const MC &Matchers) {
    810     const bool EnableCheckProfiling = Options.CheckProfiling.hasValue();
    811     TimeBucketRegion Timer;
    812     for (const auto &MP : Matchers) {
    813       if (EnableCheckProfiling)
    814         Timer.setBucket(&TimeByBucket[MP.second->getID()]);
    815       BoundNodesTreeBuilder Builder;
    816       if (MP.first.matches(Node, this, &Builder)) {
    817         MatchVisitor Visitor(ActiveASTContext, MP.second);
    818         Builder.visitMatches(&Visitor);
    819       }
    820     }
    821   }
    822 
    823   void matchWithFilter(const DynTypedNode &DynNode) {
    824     auto Kind = DynNode.getNodeKind();
    825     auto it = MatcherFiltersMap.find(Kind);
    826     const auto &Filter =
    827         it != MatcherFiltersMap.end() ? it->second : getFilterForKind(Kind);
    828 
    829     if (Filter.empty())
    830       return;
    831 
    832     const bool EnableCheckProfiling = Options.CheckProfiling.hasValue();
    833     TimeBucketRegion Timer;
    834     auto &Matchers = this->Matchers->DeclOrStmt;
    835     for (unsigned short I : Filter) {
    836       auto &MP = Matchers[I];
    837       if (EnableCheckProfiling)
    838         Timer.setBucket(&TimeByBucket[MP.second->getID()]);
    839       BoundNodesTreeBuilder Builder;
    840 
    841       {
    842         TraversalKindScope RAII(getASTContext(), MP.first.getTraversalKind());
    843         if (getASTContext().getParentMapContext().traverseIgnored(DynNode) !=
    844             DynNode)
    845           continue;
    846       }
    847 
    848       if (MP.first.matches(DynNode, this, &Builder)) {
    849         MatchVisitor Visitor(ActiveASTContext, MP.second);
    850         Builder.visitMatches(&Visitor);
    851       }
    852     }
    853   }
    854 
    855   const std::vector<unsigned short> &getFilterForKind(ASTNodeKind Kind) {
    856     auto &Filter = MatcherFiltersMap[Kind];
    857     auto &Matchers = this->Matchers->DeclOrStmt;
    858     assert((Matchers.size() < USHRT_MAX) && "Too many matchers.");
    859     for (unsigned I = 0, E = Matchers.size(); I != E; ++I) {
    860       if (Matchers[I].first.canMatchNodesOfKind(Kind)) {
    861         Filter.push_back(I);
    862       }
    863     }
    864     return Filter;
    865   }
    866 
    867   /// @{
    868   /// Overloads to pair the different node types to their matchers.
    869   void matchDispatch(const Decl *Node) {
    870     return matchWithFilter(DynTypedNode::create(*Node));
    871   }
    872   void matchDispatch(const Stmt *Node) {
    873     return matchWithFilter(DynTypedNode::create(*Node));
    874   }
    875 
    876   void matchDispatch(const Type *Node) {
    877     matchWithoutFilter(QualType(Node, 0), Matchers->Type);
    878   }
    879   void matchDispatch(const TypeLoc *Node) {
    880     matchWithoutFilter(*Node, Matchers->TypeLoc);
    881   }
    882   void matchDispatch(const QualType *Node) {
    883     matchWithoutFilter(*Node, Matchers->Type);
    884   }
    885   void matchDispatch(const NestedNameSpecifier *Node) {
    886     matchWithoutFilter(*Node, Matchers->NestedNameSpecifier);
    887   }
    888   void matchDispatch(const NestedNameSpecifierLoc *Node) {
    889     matchWithoutFilter(*Node, Matchers->NestedNameSpecifierLoc);
    890   }
    891   void matchDispatch(const CXXCtorInitializer *Node) {
    892     matchWithoutFilter(*Node, Matchers->CtorInit);
    893   }
    894   void matchDispatch(const TemplateArgumentLoc *Node) {
    895     matchWithoutFilter(*Node, Matchers->TemplateArgumentLoc);
    896   }
    897   void matchDispatch(const void *) { /* Do nothing. */ }
    898   /// @}
    899 
    900   // Returns whether a direct parent of \p Node matches \p Matcher.
    901   // Unlike matchesAnyAncestorOf there's no memoization: it doesn't save much.
    902   bool matchesParentOf(const DynTypedNode &Node, const DynTypedMatcher &Matcher,
    903                        BoundNodesTreeBuilder *Builder) {
    904     for (const auto &Parent : ActiveASTContext->getParents(Node)) {
    905       BoundNodesTreeBuilder BuilderCopy = *Builder;
    906       if (Matcher.matches(Parent, this, &BuilderCopy)) {
    907         *Builder = std::move(BuilderCopy);
    908         return true;
    909       }
    910     }
    911     return false;
    912   }
    913 
  // Returns whether an ancestor of \p Node matches \p Matcher.
  //
  // The order of matching (which can lead to different nodes being bound in
  // case there are multiple matches) is breadth first search.
  //
  // To allow memoization in the very common case of having deeply nested
  // expressions inside a template function, we first walk up the AST, memoizing
  // the result of the match along the way, as long as there is only a single
  // parent.
  //
  // Once there are multiple parents, the breadth first search order does not
  // allow simple memoization on the ancestors. Thus, we only memoize as long
  // as there is a single parent.
  //
  // We avoid a recursive implementation to prevent excessive stack use on
  // very deep ASTs (similarly to RecursiveASTVisitor's data recursion).
  //
  // Every match attempt runs on a copy of \p Builder. Only a successful copy
  // (or a cached result) is written back, so a failed search adds no
  // ancestor bindings. Cache keys are recorded only while
  // Builder->isComparable() holds.
  bool matchesAnyAncestorOf(DynTypedNode Node, ASTContext &Ctx,
                            const DynTypedMatcher &Matcher,
                            BoundNodesTreeBuilder *Builder) {

    // Memoization keys that can be updated with the result.
    // These are the memoizable nodes in the chain of unique parents, which
    // terminates when a node has multiple parents, or matches, or is the root.
    std::vector<MatchKey> Keys;
    // When returning, update the memoization cache.
    // Every recorded key gets the same final result and bindings.
    auto Finish = [&](bool Matched) {
      for (const auto &Key : Keys) {
        MemoizedMatchResult &CachedResult = ResultCache[Key];
        CachedResult.ResultOfMatch = Matched;
        CachedResult.Nodes = *Builder;
      }
      return Matched;
    };

    // Loop while there's a single parent and we want to attempt memoization.
    DynTypedNodeList Parents{ArrayRef<DynTypedNode>()}; // after loop: size != 1
    for (;;) {
      // A cache key only makes sense if memoization is possible.
      // The key describes "ancestors of the current Node" under the current
      // matcher, bindings and traversal kind.
      if (Builder->isComparable()) {
        Keys.emplace_back();
        Keys.back().MatcherID = Matcher.getID();
        Keys.back().Node = Node;
        Keys.back().BoundNodes = *Builder;
        Keys.back().Traversal = Ctx.getParentMapContext().getTraversalKind();
        Keys.back().Type = MatchType::Ancestors;

        // Check the cache.
        MemoizationMap::iterator I = ResultCache.find(Keys.back());
        if (I != ResultCache.end()) {
          Keys.pop_back(); // Don't populate the cache for the matching node!
          *Builder = I->second.Nodes;
          return Finish(I->second.ResultOfMatch);
        }
      }

      Parents = ActiveASTContext->getParents(Node);
      // Either no parents or multiple parents: leave chain+memoize mode and
      // enter bfs+forgetful mode.
      // The key just recorded for Node stays in Keys and receives the
      // outcome of the search below.
      if (Parents.size() != 1)
        break;

      // Check the next parent.
      Node = *Parents.begin();
      BoundNodesTreeBuilder BuilderCopy = *Builder;
      if (Matcher.matches(Node, this, &BuilderCopy)) {
        *Builder = std::move(BuilderCopy);
        return Finish(true);
      }
    }
    // We reached the end of the chain.

    if (Parents.empty()) {
      // Nodes may have no parents if:
      //  a) the node is the TranslationUnitDecl
      //  b) we have a limited traversal scope that excludes the parent edges
      //  c) there is a bug in the AST, and the node is not reachable
      // Usually the traversal scope is the whole AST, which precludes b.
      // Bugs are common enough that it's worthwhile asserting when we can.
#ifndef NDEBUG
      if (!Node.get<TranslationUnitDecl>() &&
          /* Traversal scope is full AST if any of the bounds are the TU */
          llvm::any_of(ActiveASTContext->getTraversalScope(), [](Decl *D) {
            return D->getKind() == Decl::TranslationUnit;
          })) {
        llvm::errs() << "Tried to match orphan node:\n";
        Node.dump(llvm::errs(), *ActiveASTContext);
        llvm_unreachable("Parent map should be complete!");
      }
#endif
    } else {
      assert(Parents.size() > 1);
      // BFS starting from the parents not yet considered.
      // Memoization of newly visited nodes is not possible (but we still update
      // results for the elements in the chain we found above).
      // NOTE(review): the initial Parents are not inserted into Visited, so a
      // node that is both a direct parent and a higher ancestor may be tried
      // twice. This only wastes work: a first successful try already returned.
      std::deque<DynTypedNode> Queue(Parents.begin(), Parents.end());
      llvm::DenseSet<const void *> Visited;
      while (!Queue.empty()) {
        BoundNodesTreeBuilder BuilderCopy = *Builder;
        if (Matcher.matches(Queue.front(), this, &BuilderCopy)) {
          *Builder = std::move(BuilderCopy);
          return Finish(true);
        }
        for (const auto &Parent : ActiveASTContext->getParents(Queue.front())) {
          // Make sure we do not visit the same node twice.
          // Otherwise, we'll visit the common ancestors as often as there
          // are splits on the way down.
          if (Visited.insert(Parent.getMemoizationData()).second)
            Queue.push_back(Parent);
        }
        Queue.pop_front();
      }
    }
    return Finish(false);
  }
   1028 
   1029   // Implements a BoundNodesTree::Visitor that calls a MatchCallback with
   1030   // the aggregated bound nodes for each match.
   1031   class MatchVisitor : public BoundNodesTreeBuilder::Visitor {
   1032   public:
   1033     MatchVisitor(ASTContext* Context,
   1034                  MatchFinder::MatchCallback* Callback)
   1035       : Context(Context),
   1036         Callback(Callback) {}
   1037 
   1038     void visitMatch(const BoundNodes& BoundNodesView) override {
   1039       TraversalKindScope RAII(*Context, Callback->getCheckTraversalKind());
   1040       Callback->run(MatchFinder::MatchResult(BoundNodesView, Context));
   1041     }
   1042 
   1043   private:
   1044     ASTContext* Context;
   1045     MatchFinder::MatchCallback* Callback;
   1046   };
   1047 
   1048   // Returns true if 'TypeNode' has an alias that matches the given matcher.
   1049   bool typeHasMatchingAlias(const Type *TypeNode,
   1050                             const Matcher<NamedDecl> &Matcher,
   1051                             BoundNodesTreeBuilder *Builder) {
   1052     const Type *const CanonicalType =
   1053       ActiveASTContext->getCanonicalType(TypeNode);
   1054     auto Aliases = TypeAliases.find(CanonicalType);
   1055     if (Aliases == TypeAliases.end())
   1056       return false;
   1057     for (const TypedefNameDecl *Alias : Aliases->second) {
   1058       BoundNodesTreeBuilder Result(*Builder);
   1059       if (Matcher.matches(*Alias, this, &Result)) {
   1060         *Builder = std::move(Result);
   1061         return true;
   1062       }
   1063     }
   1064     return false;
   1065   }
   1066 
   1067   bool
   1068   objcClassHasMatchingCompatibilityAlias(const ObjCInterfaceDecl *InterfaceDecl,
   1069                                          const Matcher<NamedDecl> &Matcher,
   1070                                          BoundNodesTreeBuilder *Builder) {
   1071     auto Aliases = CompatibleAliases.find(InterfaceDecl);
   1072     if (Aliases == CompatibleAliases.end())
   1073       return false;
   1074     for (const ObjCCompatibleAliasDecl *Alias : Aliases->second) {
   1075       BoundNodesTreeBuilder Result(*Builder);
   1076       if (Matcher.matches(*Alias, this, &Result)) {
   1077         *Builder = std::move(Result);
   1078         return true;
   1079       }
   1080     }
   1081     return false;
   1082   }
   1083 
  /// Bucket to record map.
  ///
  /// Used to get the appropriate bucket for each matcher.
  /// Keyed by the callback's getID(). It is written only when
  /// Options.CheckProfiling has a value.
  llvm::StringMap<llvm::TimeRecord> TimeByBucket;

  // The registered matchers, grouped by node type. Presumably points at the
  // owning MatchFinder's table (MatchFinder::match passes &Matchers).
  const MatchFinder::MatchersByType *Matchers;

  /// Filtered list of matcher indices for each matcher kind.
  ///
  /// \c Decl and \c Stmt toplevel matchers usually apply to a specific node
  /// kind (and derived kinds) so it is a waste to try every matcher on every
  /// node.
  /// We precalculate a list of matchers that pass the toplevel restrict check.
  /// Entries are created lazily by getFilterForKind().
  llvm::DenseMap<ASTNodeKind, std::vector<unsigned short>> MatcherFiltersMap;

  // Options of the MatchFinder running this visitor (e.g. CheckProfiling).
  const MatchFinder::MatchFinderOptions &Options;
  // The context currently being matched; returned by getASTContext().
  ASTContext *ActiveASTContext;

  // Maps a canonical type to its TypedefDecls.
  // Consulted by typeHasMatchingAlias().
  llvm::DenseMap<const Type*, std::set<const TypedefNameDecl*> > TypeAliases;

  // Maps an Objective-C interface to its ObjCCompatibleAliasDecls.
  // Consulted by objcClassHasMatchingCompatibilityAlias().
  llvm::DenseMap<const ObjCInterfaceDecl *,
                 llvm::SmallPtrSet<const ObjCCompatibleAliasDecl *, 2>>
      CompatibleAliases;

  // Maps (matcher, node) -> the match result for memoization.
  // Keys also include the bound nodes, traversal kind and match type (see
  // matchesAnyAncestorOf).
  typedef std::map<MatchKey, MemoizedMatchResult> MemoizationMap;
  MemoizationMap ResultCache;
   1113 };
   1114 
   1115 static CXXRecordDecl *
   1116 getAsCXXRecordDeclOrPrimaryTemplate(const Type *TypeNode) {
   1117   if (auto *RD = TypeNode->getAsCXXRecordDecl())
   1118     return RD;
   1119 
   1120   // Find the innermost TemplateSpecializationType that isn't an alias template.
   1121   auto *TemplateType = TypeNode->getAs<TemplateSpecializationType>();
   1122   while (TemplateType && TemplateType->isTypeAlias())
   1123     TemplateType =
   1124         TemplateType->getAliasedType()->getAs<TemplateSpecializationType>();
   1125 
   1126   // If this is the name of a (dependent) template specialization, use the
   1127   // definition of the template, even though it might be specialized later.
   1128   if (TemplateType)
   1129     if (auto *ClassTemplate = dyn_cast_or_null<ClassTemplateDecl>(
   1130           TemplateType->getTemplateName().getAsTemplateDecl()))
   1131       return ClassTemplate->getTemplatedDecl();
   1132 
   1133   return nullptr;
   1134 }
   1135 
   1136 // Returns true if the given C++ class is directly or indirectly derived
   1137 // from a base type with the given name.  A class is not considered to be
   1138 // derived from itself.
   1139 bool MatchASTVisitor::classIsDerivedFrom(const CXXRecordDecl *Declaration,
   1140                                          const Matcher<NamedDecl> &Base,
   1141                                          BoundNodesTreeBuilder *Builder,
   1142                                          bool Directly) {
   1143   if (!Declaration->hasDefinition())
   1144     return false;
   1145   for (const auto &It : Declaration->bases()) {
   1146     const Type *TypeNode = It.getType().getTypePtr();
   1147 
   1148     if (typeHasMatchingAlias(TypeNode, Base, Builder))
   1149       return true;
   1150 
   1151     // FIXME: Going to the primary template here isn't really correct, but
   1152     // unfortunately we accept a Decl matcher for the base class not a Type
   1153     // matcher, so it's the best thing we can do with our current interface.
   1154     CXXRecordDecl *ClassDecl = getAsCXXRecordDeclOrPrimaryTemplate(TypeNode);
   1155     if (!ClassDecl)
   1156       continue;
   1157     if (ClassDecl == Declaration) {
   1158       // This can happen for recursive template definitions.
   1159       continue;
   1160     }
   1161     BoundNodesTreeBuilder Result(*Builder);
   1162     if (Base.matches(*ClassDecl, this, &Result)) {
   1163       *Builder = std::move(Result);
   1164       return true;
   1165     }
   1166     if (!Directly && classIsDerivedFrom(ClassDecl, Base, Builder, Directly))
   1167       return true;
   1168   }
   1169   return false;
   1170 }
   1171 
   1172 // Returns true if the given Objective-C class is directly or indirectly
   1173 // derived from a matching base class. A class is not considered to be derived
   1174 // from itself.
   1175 bool MatchASTVisitor::objcClassIsDerivedFrom(
   1176     const ObjCInterfaceDecl *Declaration, const Matcher<NamedDecl> &Base,
   1177     BoundNodesTreeBuilder *Builder, bool Directly) {
   1178   // Check if any of the superclasses of the class match.
   1179   for (const ObjCInterfaceDecl *ClassDecl = Declaration->getSuperClass();
   1180        ClassDecl != nullptr; ClassDecl = ClassDecl->getSuperClass()) {
   1181     // Check if there are any matching compatibility aliases.
   1182     if (objcClassHasMatchingCompatibilityAlias(ClassDecl, Base, Builder))
   1183       return true;
   1184 
   1185     // Check if there are any matching type aliases.
   1186     const Type *TypeNode = ClassDecl->getTypeForDecl();
   1187     if (typeHasMatchingAlias(TypeNode, Base, Builder))
   1188       return true;
   1189 
   1190     if (Base.matches(*ClassDecl, this, Builder))
   1191       return true;
   1192 
   1193     // Not `return false` as a temporary workaround for PR43879.
   1194     if (Directly)
   1195       break;
   1196   }
   1197 
   1198   return false;
   1199 }
   1200 
   1201 bool MatchASTVisitor::TraverseDecl(Decl *DeclNode) {
   1202   if (!DeclNode) {
   1203     return true;
   1204   }
   1205 
   1206   bool ScopedTraversal =
   1207       TraversingASTNodeNotSpelledInSource || DeclNode->isImplicit();
   1208   bool ScopedChildren = TraversingASTChildrenNotSpelledInSource;
   1209 
   1210   if (const auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(DeclNode)) {
   1211     auto SK = CTSD->getSpecializationKind();
   1212     if (SK == TSK_ExplicitInstantiationDeclaration ||
   1213         SK == TSK_ExplicitInstantiationDefinition)
   1214       ScopedChildren = true;
   1215   } else if (const auto *FD = dyn_cast<FunctionDecl>(DeclNode)) {
   1216     if (FD->isDefaulted())
   1217       ScopedChildren = true;
   1218     if (FD->isTemplateInstantiation())
   1219       ScopedTraversal = true;
   1220   } else if (isa<BindingDecl>(DeclNode)) {
   1221     ScopedChildren = true;
   1222   }
   1223 
   1224   ASTNodeNotSpelledInSourceScope RAII1(this, ScopedTraversal);
   1225   ASTChildrenNotSpelledInSourceScope RAII2(this, ScopedChildren);
   1226 
   1227   match(*DeclNode);
   1228   return RecursiveASTVisitor<MatchASTVisitor>::TraverseDecl(DeclNode);
   1229 }
   1230 
   1231 bool MatchASTVisitor::TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue) {
   1232   if (!StmtNode) {
   1233     return true;
   1234   }
   1235   bool ScopedTraversal = TraversingASTNodeNotSpelledInSource ||
   1236                          TraversingASTChildrenNotSpelledInSource;
   1237 
   1238   ASTNodeNotSpelledInSourceScope RAII(this, ScopedTraversal);
   1239   match(*StmtNode);
   1240   return RecursiveASTVisitor<MatchASTVisitor>::TraverseStmt(StmtNode, Queue);
   1241 }
   1242 
   1243 bool MatchASTVisitor::TraverseType(QualType TypeNode) {
   1244   match(TypeNode);
   1245   return RecursiveASTVisitor<MatchASTVisitor>::TraverseType(TypeNode);
   1246 }
   1247 
   1248 bool MatchASTVisitor::TraverseTypeLoc(TypeLoc TypeLocNode) {
   1249   // The RecursiveASTVisitor only visits types if they're not within TypeLocs.
   1250   // We still want to find those types via matchers, so we match them here. Note
   1251   // that the TypeLocs are structurally a shadow-hierarchy to the expressed
   1252   // type, so we visit all involved parts of a compound type when matching on
   1253   // each TypeLoc.
   1254   match(TypeLocNode);
   1255   match(TypeLocNode.getType());
   1256   return RecursiveASTVisitor<MatchASTVisitor>::TraverseTypeLoc(TypeLocNode);
   1257 }
   1258 
   1259 bool MatchASTVisitor::TraverseNestedNameSpecifier(NestedNameSpecifier *NNS) {
   1260   match(*NNS);
   1261   return RecursiveASTVisitor<MatchASTVisitor>::TraverseNestedNameSpecifier(NNS);
   1262 }
   1263 
   1264 bool MatchASTVisitor::TraverseNestedNameSpecifierLoc(
   1265     NestedNameSpecifierLoc NNS) {
   1266   if (!NNS)
   1267     return true;
   1268 
   1269   match(NNS);
   1270 
   1271   // We only match the nested name specifier here (as opposed to traversing it)
   1272   // because the traversal is already done in the parallel "Loc"-hierarchy.
   1273   if (NNS.hasQualifier())
   1274     match(*NNS.getNestedNameSpecifier());
   1275   return
   1276       RecursiveASTVisitor<MatchASTVisitor>::TraverseNestedNameSpecifierLoc(NNS);
   1277 }
   1278 
   1279 bool MatchASTVisitor::TraverseConstructorInitializer(
   1280     CXXCtorInitializer *CtorInit) {
   1281   if (!CtorInit)
   1282     return true;
   1283 
   1284   bool ScopedTraversal = TraversingASTNodeNotSpelledInSource ||
   1285                          TraversingASTChildrenNotSpelledInSource;
   1286 
   1287   if (!CtorInit->isWritten())
   1288     ScopedTraversal = true;
   1289 
   1290   ASTNodeNotSpelledInSourceScope RAII1(this, ScopedTraversal);
   1291 
   1292   match(*CtorInit);
   1293 
   1294   return RecursiveASTVisitor<MatchASTVisitor>::TraverseConstructorInitializer(
   1295       CtorInit);
   1296 }
   1297 
   1298 bool MatchASTVisitor::TraverseTemplateArgumentLoc(TemplateArgumentLoc Loc) {
   1299   match(Loc);
   1300   return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateArgumentLoc(Loc);
   1301 }
   1302 
   1303 class MatchASTConsumer : public ASTConsumer {
   1304 public:
   1305   MatchASTConsumer(MatchFinder *Finder,
   1306                    MatchFinder::ParsingDoneTestCallback *ParsingDone)
   1307       : Finder(Finder), ParsingDone(ParsingDone) {}
   1308 
   1309 private:
   1310   void HandleTranslationUnit(ASTContext &Context) override {
   1311     if (ParsingDone != nullptr) {
   1312       ParsingDone->run();
   1313     }
   1314     Finder->matchAST(Context);
   1315   }
   1316 
   1317   MatchFinder *Finder;
   1318   MatchFinder::ParsingDoneTestCallback *ParsingDone;
   1319 };
   1320 
   1321 } // end namespace
   1322 } // end namespace internal
   1323 
   1324 MatchFinder::MatchResult::MatchResult(const BoundNodes &Nodes,
   1325                                       ASTContext *Context)
   1326   : Nodes(Nodes), Context(Context),
   1327     SourceManager(&Context->getSourceManager()) {}
   1328 
   1329 MatchFinder::MatchCallback::~MatchCallback() {}
   1330 MatchFinder::ParsingDoneTestCallback::~ParsingDoneTestCallback() {}
   1331 
   1332 MatchFinder::MatchFinder(MatchFinderOptions Options)
   1333     : Options(std::move(Options)), ParsingDone(nullptr) {}
   1334 
   1335 MatchFinder::~MatchFinder() {}
   1336 
   1337 void MatchFinder::addMatcher(const DeclarationMatcher &NodeMatch,
   1338                              MatchCallback *Action) {
   1339   llvm::Optional<TraversalKind> TK;
   1340   if (Action)
   1341     TK = Action->getCheckTraversalKind();
   1342   if (TK)
   1343     Matchers.DeclOrStmt.emplace_back(traverse(*TK, NodeMatch), Action);
   1344   else
   1345     Matchers.DeclOrStmt.emplace_back(NodeMatch, Action);
   1346   Matchers.AllCallbacks.insert(Action);
   1347 }
   1348 
   1349 void MatchFinder::addMatcher(const TypeMatcher &NodeMatch,
   1350                              MatchCallback *Action) {
   1351   Matchers.Type.emplace_back(NodeMatch, Action);
   1352   Matchers.AllCallbacks.insert(Action);
   1353 }
   1354 
   1355 void MatchFinder::addMatcher(const StatementMatcher &NodeMatch,
   1356                              MatchCallback *Action) {
   1357   llvm::Optional<TraversalKind> TK;
   1358   if (Action)
   1359     TK = Action->getCheckTraversalKind();
   1360   if (TK)
   1361     Matchers.DeclOrStmt.emplace_back(traverse(*TK, NodeMatch), Action);
   1362   else
   1363     Matchers.DeclOrStmt.emplace_back(NodeMatch, Action);
   1364   Matchers.AllCallbacks.insert(Action);
   1365 }
   1366 
   1367 void MatchFinder::addMatcher(const NestedNameSpecifierMatcher &NodeMatch,
   1368                              MatchCallback *Action) {
   1369   Matchers.NestedNameSpecifier.emplace_back(NodeMatch, Action);
   1370   Matchers.AllCallbacks.insert(Action);
   1371 }
   1372 
   1373 void MatchFinder::addMatcher(const NestedNameSpecifierLocMatcher &NodeMatch,
   1374                              MatchCallback *Action) {
   1375   Matchers.NestedNameSpecifierLoc.emplace_back(NodeMatch, Action);
   1376   Matchers.AllCallbacks.insert(Action);
   1377 }
   1378 
   1379 void MatchFinder::addMatcher(const TypeLocMatcher &NodeMatch,
   1380                              MatchCallback *Action) {
   1381   Matchers.TypeLoc.emplace_back(NodeMatch, Action);
   1382   Matchers.AllCallbacks.insert(Action);
   1383 }
   1384 
   1385 void MatchFinder::addMatcher(const CXXCtorInitializerMatcher &NodeMatch,
   1386                              MatchCallback *Action) {
   1387   Matchers.CtorInit.emplace_back(NodeMatch, Action);
   1388   Matchers.AllCallbacks.insert(Action);
   1389 }
   1390 
   1391 void MatchFinder::addMatcher(const TemplateArgumentLocMatcher &NodeMatch,
   1392                              MatchCallback *Action) {
   1393   Matchers.TemplateArgumentLoc.emplace_back(NodeMatch, Action);
   1394   Matchers.AllCallbacks.insert(Action);
   1395 }
   1396 
   1397 bool MatchFinder::addDynamicMatcher(const internal::DynTypedMatcher &NodeMatch,
   1398                                     MatchCallback *Action) {
   1399   if (NodeMatch.canConvertTo<Decl>()) {
   1400     addMatcher(NodeMatch.convertTo<Decl>(), Action);
   1401     return true;
   1402   } else if (NodeMatch.canConvertTo<QualType>()) {
   1403     addMatcher(NodeMatch.convertTo<QualType>(), Action);
   1404     return true;
   1405   } else if (NodeMatch.canConvertTo<Stmt>()) {
   1406     addMatcher(NodeMatch.convertTo<Stmt>(), Action);
   1407     return true;
   1408   } else if (NodeMatch.canConvertTo<NestedNameSpecifier>()) {
   1409     addMatcher(NodeMatch.convertTo<NestedNameSpecifier>(), Action);
   1410     return true;
   1411   } else if (NodeMatch.canConvertTo<NestedNameSpecifierLoc>()) {
   1412     addMatcher(NodeMatch.convertTo<NestedNameSpecifierLoc>(), Action);
   1413     return true;
   1414   } else if (NodeMatch.canConvertTo<TypeLoc>()) {
   1415     addMatcher(NodeMatch.convertTo<TypeLoc>(), Action);
   1416     return true;
   1417   } else if (NodeMatch.canConvertTo<CXXCtorInitializer>()) {
   1418     addMatcher(NodeMatch.convertTo<CXXCtorInitializer>(), Action);
   1419     return true;
   1420   } else if (NodeMatch.canConvertTo<TemplateArgumentLoc>()) {
   1421     addMatcher(NodeMatch.convertTo<TemplateArgumentLoc>(), Action);
   1422     return true;
   1423   }
   1424   return false;
   1425 }
   1426 
   1427 std::unique_ptr<ASTConsumer> MatchFinder::newASTConsumer() {
   1428   return std::make_unique<internal::MatchASTConsumer>(this, ParsingDone);
   1429 }
   1430 
   1431 void MatchFinder::match(const clang::DynTypedNode &Node, ASTContext &Context) {
   1432   internal::MatchASTVisitor Visitor(&Matchers, Options);
   1433   Visitor.set_active_ast_context(&Context);
   1434   Visitor.match(Node);
   1435 }
   1436 
   1437 void MatchFinder::matchAST(ASTContext &Context) {
   1438   internal::MatchASTVisitor Visitor(&Matchers, Options);
   1439   Visitor.set_active_ast_context(&Context);
   1440   Visitor.onStartOfTranslationUnit();
   1441   Visitor.TraverseAST(Context);
   1442   Visitor.onEndOfTranslationUnit();
   1443 }
   1444 
   1445 void MatchFinder::registerTestCallbackAfterParsing(
   1446     MatchFinder::ParsingDoneTestCallback *NewParsingDone) {
   1447   ParsingDone = NewParsingDone;
   1448 }
   1449 
   1450 StringRef MatchFinder::MatchCallback::getID() const { return "<unknown>"; }
   1451 
   1452 llvm::Optional<TraversalKind>
   1453 MatchFinder::MatchCallback::getCheckTraversalKind() const {
   1454   return llvm::None;
   1455 }
   1456 
   1457 } // end namespace ast_matchers
   1458 } // end namespace clang
   1459