Home | History | Annotate | Line # | Download | only in Analysis
      1 //===- Consumed.cpp -------------------------------------------------------===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // An intra-procedural analysis for checking consumed properties.  This is based,
     10 // in part, on research on linear types.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #include "clang/Analysis/Analyses/Consumed.h"
     15 #include "clang/AST/Attr.h"
     16 #include "clang/AST/Decl.h"
     17 #include "clang/AST/DeclCXX.h"
     18 #include "clang/AST/Expr.h"
     19 #include "clang/AST/ExprCXX.h"
     20 #include "clang/AST/Stmt.h"
     21 #include "clang/AST/StmtVisitor.h"
     22 #include "clang/AST/Type.h"
     23 #include "clang/Analysis/Analyses/PostOrderCFGView.h"
     24 #include "clang/Analysis/AnalysisDeclContext.h"
     25 #include "clang/Analysis/CFG.h"
     26 #include "clang/Basic/LLVM.h"
     27 #include "clang/Basic/OperatorKinds.h"
     28 #include "clang/Basic/SourceLocation.h"
     29 #include "llvm/ADT/DenseMap.h"
     30 #include "llvm/ADT/Optional.h"
     31 #include "llvm/ADT/STLExtras.h"
     32 #include "llvm/ADT/StringRef.h"
     33 #include "llvm/Support/Casting.h"
     34 #include "llvm/Support/ErrorHandling.h"
     35 #include <cassert>
     36 #include <memory>
     37 #include <utility>
     38 
     39 // TODO: Adjust states of args to constructors in the same way that arguments to
     40 //       function calls are handled.
     41 // TODO: Use information from tests in for- and while-loop conditional.
     42 // TODO: Add notes about the actual and expected state for
     43 // TODO: Correctly identify unreachable blocks when chaining boolean operators.
     44 // TODO: Adjust the parser and AttributesList class to support lists of
     45 //       identifiers.
     46 // TODO: Warn about unreachable code.
     47 // TODO: Switch to using a bitmap to track unreachable blocks.
     48 // TODO: Handle variable definitions, e.g. bool valid = x.isValid();
     49 //       if (valid) ...; (Deferred)
     50 // TODO: Take notes on state transitions to provide better warning messages.
     51 //       (Deferred)
     52 // TODO: Test nested conditionals: 1) checking the same value multiple times,
     53 //       and 2) checking different values. (Deferred)
     54 
     55 using namespace clang;
     56 using namespace consumed;
     57 
     58 // Key method definition
     59 ConsumedWarningsHandlerBase::~ConsumedWarningsHandlerBase() = default;
     60 
     61 static SourceLocation getFirstStmtLoc(const CFGBlock *Block) {
     62   // Find the source location of the first statement in the block, if the block
     63   // is not empty.
     64   for (const auto &B : *Block)
     65     if (Optional<CFGStmt> CS = B.getAs<CFGStmt>())
     66       return CS->getStmt()->getBeginLoc();
     67 
     68   // Block is empty.
     69   // If we have one successor, return the first statement in that block
     70   if (Block->succ_size() == 1 && *Block->succ_begin())
     71     return getFirstStmtLoc(*Block->succ_begin());
     72 
     73   return {};
     74 }
     75 
     76 static SourceLocation getLastStmtLoc(const CFGBlock *Block) {
     77   // Find the source location of the last statement in the block, if the block
     78   // is not empty.
     79   if (const Stmt *StmtNode = Block->getTerminatorStmt()) {
     80     return StmtNode->getBeginLoc();
     81   } else {
     82     for (CFGBlock::const_reverse_iterator BI = Block->rbegin(),
     83          BE = Block->rend(); BI != BE; ++BI) {
     84       if (Optional<CFGStmt> CS = BI->getAs<CFGStmt>())
     85         return CS->getStmt()->getBeginLoc();
     86     }
     87   }
     88 
     89   // If we have one successor, return the first statement in that block
     90   SourceLocation Loc;
     91   if (Block->succ_size() == 1 && *Block->succ_begin())
     92     Loc = getFirstStmtLoc(*Block->succ_begin());
     93   if (Loc.isValid())
     94     return Loc;
     95 
     96   // If we have one predecessor, return the last statement in that block
     97   if (Block->pred_size() == 1 && *Block->pred_begin())
     98     return getLastStmtLoc(*Block->pred_begin());
     99 
    100   return Loc;
    101 }
    102 
    103 static ConsumedState invertConsumedUnconsumed(ConsumedState State) {
    104   switch (State) {
    105   case CS_Unconsumed:
    106     return CS_Consumed;
    107   case CS_Consumed:
    108     return CS_Unconsumed;
    109   case CS_None:
    110     return CS_None;
    111   case CS_Unknown:
    112     return CS_Unknown;
    113   }
    114   llvm_unreachable("invalid enum");
    115 }
    116 
    117 static bool isCallableInState(const CallableWhenAttr *CWAttr,
    118                               ConsumedState State) {
    119   for (const auto &S : CWAttr->callableStates()) {
    120     ConsumedState MappedAttrState = CS_None;
    121 
    122     switch (S) {
    123     case CallableWhenAttr::Unknown:
    124       MappedAttrState = CS_Unknown;
    125       break;
    126 
    127     case CallableWhenAttr::Unconsumed:
    128       MappedAttrState = CS_Unconsumed;
    129       break;
    130 
    131     case CallableWhenAttr::Consumed:
    132       MappedAttrState = CS_Consumed;
    133       break;
    134     }
    135 
    136     if (MappedAttrState == State)
    137       return true;
    138   }
    139 
    140   return false;
    141 }
    142 
    143 static bool isConsumableType(const QualType &QT) {
    144   if (QT->isPointerType() || QT->isReferenceType())
    145     return false;
    146 
    147   if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
    148     return RD->hasAttr<ConsumableAttr>();
    149 
    150   return false;
    151 }
    152 
    153 static bool isAutoCastType(const QualType &QT) {
    154   if (QT->isPointerType() || QT->isReferenceType())
    155     return false;
    156 
    157   if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
    158     return RD->hasAttr<ConsumableAutoCastAttr>();
    159 
    160   return false;
    161 }
    162 
    163 static bool isSetOnReadPtrType(const QualType &QT) {
    164   if (const CXXRecordDecl *RD = QT->getPointeeCXXRecordDecl())
    165     return RD->hasAttr<ConsumableSetOnReadAttr>();
    166   return false;
    167 }
    168 
    169 static bool isKnownState(ConsumedState State) {
    170   switch (State) {
    171   case CS_Unconsumed:
    172   case CS_Consumed:
    173     return true;
    174   case CS_None:
    175   case CS_Unknown:
    176     return false;
    177   }
    178   llvm_unreachable("invalid enum");
    179 }
    180 
    181 static bool isRValueRef(QualType ParamType) {
    182   return ParamType->isRValueReferenceType();
    183 }
    184 
    185 static bool isTestingFunction(const FunctionDecl *FunDecl) {
    186   return FunDecl->hasAttr<TestTypestateAttr>();
    187 }
    188 
    189 static bool isPointerOrRef(QualType ParamType) {
    190   return ParamType->isPointerType() || ParamType->isReferenceType();
    191 }
    192 
    193 static ConsumedState mapConsumableAttrState(const QualType QT) {
    194   assert(isConsumableType(QT));
    195 
    196   const ConsumableAttr *CAttr =
    197       QT->getAsCXXRecordDecl()->getAttr<ConsumableAttr>();
    198 
    199   switch (CAttr->getDefaultState()) {
    200   case ConsumableAttr::Unknown:
    201     return CS_Unknown;
    202   case ConsumableAttr::Unconsumed:
    203     return CS_Unconsumed;
    204   case ConsumableAttr::Consumed:
    205     return CS_Consumed;
    206   }
    207   llvm_unreachable("invalid enum");
    208 }
    209 
    210 static ConsumedState
    211 mapParamTypestateAttrState(const ParamTypestateAttr *PTAttr) {
    212   switch (PTAttr->getParamState()) {
    213   case ParamTypestateAttr::Unknown:
    214     return CS_Unknown;
    215   case ParamTypestateAttr::Unconsumed:
    216     return CS_Unconsumed;
    217   case ParamTypestateAttr::Consumed:
    218     return CS_Consumed;
    219   }
    220   llvm_unreachable("invalid_enum");
    221 }
    222 
    223 static ConsumedState
    224 mapReturnTypestateAttrState(const ReturnTypestateAttr *RTSAttr) {
    225   switch (RTSAttr->getState()) {
    226   case ReturnTypestateAttr::Unknown:
    227     return CS_Unknown;
    228   case ReturnTypestateAttr::Unconsumed:
    229     return CS_Unconsumed;
    230   case ReturnTypestateAttr::Consumed:
    231     return CS_Consumed;
    232   }
    233   llvm_unreachable("invalid enum");
    234 }
    235 
    236 static ConsumedState mapSetTypestateAttrState(const SetTypestateAttr *STAttr) {
    237   switch (STAttr->getNewState()) {
    238   case SetTypestateAttr::Unknown:
    239     return CS_Unknown;
    240   case SetTypestateAttr::Unconsumed:
    241     return CS_Unconsumed;
    242   case SetTypestateAttr::Consumed:
    243     return CS_Consumed;
    244   }
    245   llvm_unreachable("invalid_enum");
    246 }
    247 
    248 static StringRef stateToString(ConsumedState State) {
    249   switch (State) {
    250   case consumed::CS_None:
    251     return "none";
    252 
    253   case consumed::CS_Unknown:
    254     return "unknown";
    255 
    256   case consumed::CS_Unconsumed:
    257     return "unconsumed";
    258 
    259   case consumed::CS_Consumed:
    260     return "consumed";
    261   }
    262   llvm_unreachable("invalid enum");
    263 }
    264 
    265 static ConsumedState testsFor(const FunctionDecl *FunDecl) {
    266   assert(isTestingFunction(FunDecl));
    267   switch (FunDecl->getAttr<TestTypestateAttr>()->getTestState()) {
    268   case TestTypestateAttr::Unconsumed:
    269     return CS_Unconsumed;
    270   case TestTypestateAttr::Consumed:
    271     return CS_Consumed;
    272   }
    273   llvm_unreachable("invalid enum");
    274 }
    275 
    276 namespace {
    277 
    278 struct VarTestResult {
    279   const VarDecl *Var;
    280   ConsumedState TestsFor;
    281 };
    282 
    283 } // namespace
    284 
    285 namespace clang {
    286 namespace consumed {
    287 
    288 enum EffectiveOp {
    289   EO_And,
    290   EO_Or
    291 };
    292 
    293 class PropagationInfo {
    294   enum {
    295     IT_None,
    296     IT_State,
    297     IT_VarTest,
    298     IT_BinTest,
    299     IT_Var,
    300     IT_Tmp
    301   } InfoType = IT_None;
    302 
    303   struct BinTestTy {
    304     const BinaryOperator *Source;
    305     EffectiveOp EOp;
    306     VarTestResult LTest;
    307     VarTestResult RTest;
    308   };
    309 
    310   union {
    311     ConsumedState State;
    312     VarTestResult VarTest;
    313     const VarDecl *Var;
    314     const CXXBindTemporaryExpr *Tmp;
    315     BinTestTy BinTest;
    316   };
    317 
    318 public:
    319   PropagationInfo() = default;
    320   PropagationInfo(const VarTestResult &VarTest)
    321       : InfoType(IT_VarTest), VarTest(VarTest) {}
    322 
    323   PropagationInfo(const VarDecl *Var, ConsumedState TestsFor)
    324       : InfoType(IT_VarTest) {
    325     VarTest.Var      = Var;
    326     VarTest.TestsFor = TestsFor;
    327   }
    328 
    329   PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
    330                   const VarTestResult &LTest, const VarTestResult &RTest)
    331       : InfoType(IT_BinTest) {
    332     BinTest.Source  = Source;
    333     BinTest.EOp     = EOp;
    334     BinTest.LTest   = LTest;
    335     BinTest.RTest   = RTest;
    336   }
    337 
    338   PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
    339                   const VarDecl *LVar, ConsumedState LTestsFor,
    340                   const VarDecl *RVar, ConsumedState RTestsFor)
    341       : InfoType(IT_BinTest) {
    342     BinTest.Source         = Source;
    343     BinTest.EOp            = EOp;
    344     BinTest.LTest.Var      = LVar;
    345     BinTest.LTest.TestsFor = LTestsFor;
    346     BinTest.RTest.Var      = RVar;
    347     BinTest.RTest.TestsFor = RTestsFor;
    348   }
    349 
    350   PropagationInfo(ConsumedState State)
    351       : InfoType(IT_State), State(State) {}
    352   PropagationInfo(const VarDecl *Var) : InfoType(IT_Var), Var(Var) {}
    353   PropagationInfo(const CXXBindTemporaryExpr *Tmp)
    354       : InfoType(IT_Tmp), Tmp(Tmp) {}
    355 
    356   const ConsumedState &getState() const {
    357     assert(InfoType == IT_State);
    358     return State;
    359   }
    360 
    361   const VarTestResult &getVarTest() const {
    362     assert(InfoType == IT_VarTest);
    363     return VarTest;
    364   }
    365 
    366   const VarTestResult &getLTest() const {
    367     assert(InfoType == IT_BinTest);
    368     return BinTest.LTest;
    369   }
    370 
    371   const VarTestResult &getRTest() const {
    372     assert(InfoType == IT_BinTest);
    373     return BinTest.RTest;
    374   }
    375 
    376   const VarDecl *getVar() const {
    377     assert(InfoType == IT_Var);
    378     return Var;
    379   }
    380 
    381   const CXXBindTemporaryExpr *getTmp() const {
    382     assert(InfoType == IT_Tmp);
    383     return Tmp;
    384   }
    385 
    386   ConsumedState getAsState(const ConsumedStateMap *StateMap) const {
    387     assert(isVar() || isTmp() || isState());
    388 
    389     if (isVar())
    390       return StateMap->getState(Var);
    391     else if (isTmp())
    392       return StateMap->getState(Tmp);
    393     else if (isState())
    394       return State;
    395     else
    396       return CS_None;
    397   }
    398 
    399   EffectiveOp testEffectiveOp() const {
    400     assert(InfoType == IT_BinTest);
    401     return BinTest.EOp;
    402   }
    403 
    404   const BinaryOperator * testSourceNode() const {
    405     assert(InfoType == IT_BinTest);
    406     return BinTest.Source;
    407   }
    408 
    409   bool isValid() const { return InfoType != IT_None; }
    410   bool isState() const { return InfoType == IT_State; }
    411   bool isVarTest() const { return InfoType == IT_VarTest; }
    412   bool isBinTest() const { return InfoType == IT_BinTest; }
    413   bool isVar() const { return InfoType == IT_Var; }
    414   bool isTmp() const { return InfoType == IT_Tmp; }
    415 
    416   bool isTest() const {
    417     return InfoType == IT_VarTest || InfoType == IT_BinTest;
    418   }
    419 
    420   bool isPointerToValue() const {
    421     return InfoType == IT_Var || InfoType == IT_Tmp;
    422   }
    423 
    424   PropagationInfo invertTest() const {
    425     assert(InfoType == IT_VarTest || InfoType == IT_BinTest);
    426 
    427     if (InfoType == IT_VarTest) {
    428       return PropagationInfo(VarTest.Var,
    429                              invertConsumedUnconsumed(VarTest.TestsFor));
    430 
    431     } else if (InfoType == IT_BinTest) {
    432       return PropagationInfo(BinTest.Source,
    433         BinTest.EOp == EO_And ? EO_Or : EO_And,
    434         BinTest.LTest.Var, invertConsumedUnconsumed(BinTest.LTest.TestsFor),
    435         BinTest.RTest.Var, invertConsumedUnconsumed(BinTest.RTest.TestsFor));
    436     } else {
    437       return {};
    438     }
    439   }
    440 };
    441 
    442 } // namespace consumed
    443 } // namespace clang
    444 
    445 static void
    446 setStateForVarOrTmp(ConsumedStateMap *StateMap, const PropagationInfo &PInfo,
    447                     ConsumedState State) {
    448   assert(PInfo.isVar() || PInfo.isTmp());
    449 
    450   if (PInfo.isVar())
    451     StateMap->setState(PInfo.getVar(), State);
    452   else
    453     StateMap->setState(PInfo.getTmp(), State);
    454 }
    455 
    456 namespace clang {
    457 namespace consumed {
    458 
    459 class ConsumedStmtVisitor : public ConstStmtVisitor<ConsumedStmtVisitor> {
    460   using MapType = llvm::DenseMap<const Stmt *, PropagationInfo>;
    461   using PairType= std::pair<const Stmt *, PropagationInfo>;
    462   using InfoEntry = MapType::iterator;
    463   using ConstInfoEntry = MapType::const_iterator;
    464 
    465   ConsumedAnalyzer &Analyzer;
    466   ConsumedStateMap *StateMap;
    467   MapType PropagationMap;
    468 
    469   InfoEntry findInfo(const Expr *E) {
    470     if (const auto Cleanups = dyn_cast<ExprWithCleanups>(E))
    471       if (!Cleanups->cleanupsHaveSideEffects())
    472         E = Cleanups->getSubExpr();
    473     return PropagationMap.find(E->IgnoreParens());
    474   }
    475 
    476   ConstInfoEntry findInfo(const Expr *E) const {
    477     if (const auto Cleanups = dyn_cast<ExprWithCleanups>(E))
    478       if (!Cleanups->cleanupsHaveSideEffects())
    479         E = Cleanups->getSubExpr();
    480     return PropagationMap.find(E->IgnoreParens());
    481   }
    482 
    483   void insertInfo(const Expr *E, const PropagationInfo &PI) {
    484     PropagationMap.insert(PairType(E->IgnoreParens(), PI));
    485   }
    486 
    487   void forwardInfo(const Expr *From, const Expr *To);
    488   void copyInfo(const Expr *From, const Expr *To, ConsumedState CS);
    489   ConsumedState getInfo(const Expr *From);
    490   void setInfo(const Expr *To, ConsumedState NS);
    491   void propagateReturnType(const Expr *Call, const FunctionDecl *Fun);
    492 
    493 public:
    494   void checkCallability(const PropagationInfo &PInfo,
    495                         const FunctionDecl *FunDecl,
    496                         SourceLocation BlameLoc);
    497   bool handleCall(const CallExpr *Call, const Expr *ObjArg,
    498                   const FunctionDecl *FunD);
    499 
    500   void VisitBinaryOperator(const BinaryOperator *BinOp);
    501   void VisitCallExpr(const CallExpr *Call);
    502   void VisitCastExpr(const CastExpr *Cast);
    503   void VisitCXXBindTemporaryExpr(const CXXBindTemporaryExpr *Temp);
    504   void VisitCXXConstructExpr(const CXXConstructExpr *Call);
    505   void VisitCXXMemberCallExpr(const CXXMemberCallExpr *Call);
    506   void VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *Call);
    507   void VisitDeclRefExpr(const DeclRefExpr *DeclRef);
    508   void VisitDeclStmt(const DeclStmt *DelcS);
    509   void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *Temp);
    510   void VisitMemberExpr(const MemberExpr *MExpr);
    511   void VisitParmVarDecl(const ParmVarDecl *Param);
    512   void VisitReturnStmt(const ReturnStmt *Ret);
    513   void VisitUnaryOperator(const UnaryOperator *UOp);
    514   void VisitVarDecl(const VarDecl *Var);
    515 
    516   ConsumedStmtVisitor(ConsumedAnalyzer &Analyzer, ConsumedStateMap *StateMap)
    517       : Analyzer(Analyzer), StateMap(StateMap) {}
    518 
    519   PropagationInfo getInfo(const Expr *StmtNode) const {
    520     ConstInfoEntry Entry = findInfo(StmtNode);
    521 
    522     if (Entry != PropagationMap.end())
    523       return Entry->second;
    524     else
    525       return {};
    526   }
    527 
    528   void reset(ConsumedStateMap *NewStateMap) {
    529     StateMap = NewStateMap;
    530   }
    531 };
    532 
    533 } // namespace consumed
    534 } // namespace clang
    535 
    536 void ConsumedStmtVisitor::forwardInfo(const Expr *From, const Expr *To) {
    537   InfoEntry Entry = findInfo(From);
    538   if (Entry != PropagationMap.end())
    539     insertInfo(To, Entry->second);
    540 }
    541 
    542 // Create a new state for To, which is initialized to the state of From.
    543 // If NS is not CS_None, sets the state of From to NS.
    544 void ConsumedStmtVisitor::copyInfo(const Expr *From, const Expr *To,
    545                                    ConsumedState NS) {
    546   InfoEntry Entry = findInfo(From);
    547   if (Entry != PropagationMap.end()) {
    548     PropagationInfo& PInfo = Entry->second;
    549     ConsumedState CS = PInfo.getAsState(StateMap);
    550     if (CS != CS_None)
    551       insertInfo(To, PropagationInfo(CS));
    552     if (NS != CS_None && PInfo.isPointerToValue())
    553       setStateForVarOrTmp(StateMap, PInfo, NS);
    554   }
    555 }
    556 
    557 // Get the ConsumedState for From
    558 ConsumedState ConsumedStmtVisitor::getInfo(const Expr *From) {
    559   InfoEntry Entry = findInfo(From);
    560   if (Entry != PropagationMap.end()) {
    561     PropagationInfo& PInfo = Entry->second;
    562     return PInfo.getAsState(StateMap);
    563   }
    564   return CS_None;
    565 }
    566 
    567 // If we already have info for To then update it, otherwise create a new entry.
    568 void ConsumedStmtVisitor::setInfo(const Expr *To, ConsumedState NS) {
    569   InfoEntry Entry = findInfo(To);
    570   if (Entry != PropagationMap.end()) {
    571     PropagationInfo& PInfo = Entry->second;
    572     if (PInfo.isPointerToValue())
    573       setStateForVarOrTmp(StateMap, PInfo, NS);
    574   } else if (NS != CS_None) {
    575      insertInfo(To, PropagationInfo(NS));
    576   }
    577 }
    578 
    579 void ConsumedStmtVisitor::checkCallability(const PropagationInfo &PInfo,
    580                                            const FunctionDecl *FunDecl,
    581                                            SourceLocation BlameLoc) {
    582   assert(!PInfo.isTest());
    583 
    584   const CallableWhenAttr *CWAttr = FunDecl->getAttr<CallableWhenAttr>();
    585   if (!CWAttr)
    586     return;
    587 
    588   if (PInfo.isVar()) {
    589     ConsumedState VarState = StateMap->getState(PInfo.getVar());
    590 
    591     if (VarState == CS_None || isCallableInState(CWAttr, VarState))
    592       return;
    593 
    594     Analyzer.WarningsHandler.warnUseInInvalidState(
    595       FunDecl->getNameAsString(), PInfo.getVar()->getNameAsString(),
    596       stateToString(VarState), BlameLoc);
    597   } else {
    598     ConsumedState TmpState = PInfo.getAsState(StateMap);
    599 
    600     if (TmpState == CS_None || isCallableInState(CWAttr, TmpState))
    601       return;
    602 
    603     Analyzer.WarningsHandler.warnUseOfTempInInvalidState(
    604       FunDecl->getNameAsString(), stateToString(TmpState), BlameLoc);
    605   }
    606 }
    607 
    608 // Factors out common behavior for function, method, and operator calls.
    609 // Check parameters and set parameter state if necessary.
    610 // Returns true if the state of ObjArg is set, or false otherwise.
    611 bool ConsumedStmtVisitor::handleCall(const CallExpr *Call, const Expr *ObjArg,
    612                                      const FunctionDecl *FunD) {
    613   unsigned Offset = 0;
    614   if (isa<CXXOperatorCallExpr>(Call) && isa<CXXMethodDecl>(FunD))
    615     Offset = 1;  // first argument is 'this'
    616 
    617   // check explicit parameters
    618   for (unsigned Index = Offset; Index < Call->getNumArgs(); ++Index) {
    619     // Skip variable argument lists.
    620     if (Index - Offset >= FunD->getNumParams())
    621       break;
    622 
    623     const ParmVarDecl *Param = FunD->getParamDecl(Index - Offset);
    624     QualType ParamType = Param->getType();
    625 
    626     InfoEntry Entry = findInfo(Call->getArg(Index));
    627 
    628     if (Entry == PropagationMap.end() || Entry->second.isTest())
    629       continue;
    630     PropagationInfo PInfo = Entry->second;
    631 
    632     // Check that the parameter is in the correct state.
    633     if (ParamTypestateAttr *PTA = Param->getAttr<ParamTypestateAttr>()) {
    634       ConsumedState ParamState = PInfo.getAsState(StateMap);
    635       ConsumedState ExpectedState = mapParamTypestateAttrState(PTA);
    636 
    637       if (ParamState != ExpectedState)
    638         Analyzer.WarningsHandler.warnParamTypestateMismatch(
    639           Call->getArg(Index)->getExprLoc(),
    640           stateToString(ExpectedState), stateToString(ParamState));
    641     }
    642 
    643     if (!(Entry->second.isVar() || Entry->second.isTmp()))
    644       continue;
    645 
    646     // Adjust state on the caller side.
    647     if (ReturnTypestateAttr *RT = Param->getAttr<ReturnTypestateAttr>())
    648       setStateForVarOrTmp(StateMap, PInfo, mapReturnTypestateAttrState(RT));
    649     else if (isRValueRef(ParamType) || isConsumableType(ParamType))
    650       setStateForVarOrTmp(StateMap, PInfo, consumed::CS_Consumed);
    651     else if (isPointerOrRef(ParamType) &&
    652              (!ParamType->getPointeeType().isConstQualified() ||
    653               isSetOnReadPtrType(ParamType)))
    654       setStateForVarOrTmp(StateMap, PInfo, consumed::CS_Unknown);
    655   }
    656 
    657   if (!ObjArg)
    658     return false;
    659 
    660   // check implicit 'self' parameter, if present
    661   InfoEntry Entry = findInfo(ObjArg);
    662   if (Entry != PropagationMap.end()) {
    663     PropagationInfo PInfo = Entry->second;
    664     checkCallability(PInfo, FunD, Call->getExprLoc());
    665 
    666     if (SetTypestateAttr *STA = FunD->getAttr<SetTypestateAttr>()) {
    667       if (PInfo.isVar()) {
    668         StateMap->setState(PInfo.getVar(), mapSetTypestateAttrState(STA));
    669         return true;
    670       }
    671       else if (PInfo.isTmp()) {
    672         StateMap->setState(PInfo.getTmp(), mapSetTypestateAttrState(STA));
    673         return true;
    674       }
    675     }
    676     else if (isTestingFunction(FunD) && PInfo.isVar()) {
    677       PropagationMap.insert(PairType(Call,
    678         PropagationInfo(PInfo.getVar(), testsFor(FunD))));
    679     }
    680   }
    681   return false;
    682 }
    683 
    684 void ConsumedStmtVisitor::propagateReturnType(const Expr *Call,
    685                                               const FunctionDecl *Fun) {
    686   QualType RetType = Fun->getCallResultType();
    687   if (RetType->isReferenceType())
    688     RetType = RetType->getPointeeType();
    689 
    690   if (isConsumableType(RetType)) {
    691     ConsumedState ReturnState;
    692     if (ReturnTypestateAttr *RTA = Fun->getAttr<ReturnTypestateAttr>())
    693       ReturnState = mapReturnTypestateAttrState(RTA);
    694     else
    695       ReturnState = mapConsumableAttrState(RetType);
    696 
    697     PropagationMap.insert(PairType(Call, PropagationInfo(ReturnState)));
    698   }
    699 }
    700 
    701 void ConsumedStmtVisitor::VisitBinaryOperator(const BinaryOperator *BinOp) {
    702   switch (BinOp->getOpcode()) {
    703   case BO_LAnd:
    704   case BO_LOr : {
    705     InfoEntry LEntry = findInfo(BinOp->getLHS()),
    706               REntry = findInfo(BinOp->getRHS());
    707 
    708     VarTestResult LTest, RTest;
    709 
    710     if (LEntry != PropagationMap.end() && LEntry->second.isVarTest()) {
    711       LTest = LEntry->second.getVarTest();
    712     } else {
    713       LTest.Var      = nullptr;
    714       LTest.TestsFor = CS_None;
    715     }
    716 
    717     if (REntry != PropagationMap.end() && REntry->second.isVarTest()) {
    718       RTest = REntry->second.getVarTest();
    719     } else {
    720       RTest.Var      = nullptr;
    721       RTest.TestsFor = CS_None;
    722     }
    723 
    724     if (!(LTest.Var == nullptr && RTest.Var == nullptr))
    725       PropagationMap.insert(PairType(BinOp, PropagationInfo(BinOp,
    726         static_cast<EffectiveOp>(BinOp->getOpcode() == BO_LOr), LTest, RTest)));
    727     break;
    728   }
    729 
    730   case BO_PtrMemD:
    731   case BO_PtrMemI:
    732     forwardInfo(BinOp->getLHS(), BinOp);
    733     break;
    734 
    735   default:
    736     break;
    737   }
    738 }
    739 
    740 void ConsumedStmtVisitor::VisitCallExpr(const CallExpr *Call) {
    741   const FunctionDecl *FunDecl = Call->getDirectCallee();
    742   if (!FunDecl)
    743     return;
    744 
    745   // Special case for the std::move function.
    746   // TODO: Make this more specific. (Deferred)
    747   if (Call->isCallToStdMove()) {
    748     copyInfo(Call->getArg(0), Call, CS_Consumed);
    749     return;
    750   }
    751 
    752   handleCall(Call, nullptr, FunDecl);
    753   propagateReturnType(Call, FunDecl);
    754 }
    755 
    756 void ConsumedStmtVisitor::VisitCastExpr(const CastExpr *Cast) {
    757   forwardInfo(Cast->getSubExpr(), Cast);
    758 }
    759 
    760 void ConsumedStmtVisitor::VisitCXXBindTemporaryExpr(
    761   const CXXBindTemporaryExpr *Temp) {
    762 
    763   InfoEntry Entry = findInfo(Temp->getSubExpr());
    764 
    765   if (Entry != PropagationMap.end() && !Entry->second.isTest()) {
    766     StateMap->setState(Temp, Entry->second.getAsState(StateMap));
    767     PropagationMap.insert(PairType(Temp, PropagationInfo(Temp)));
    768   }
    769 }
    770 
    771 void ConsumedStmtVisitor::VisitCXXConstructExpr(const CXXConstructExpr *Call) {
    772   CXXConstructorDecl *Constructor = Call->getConstructor();
    773 
    774   QualType ThisType = Constructor->getThisType()->getPointeeType();
    775 
    776   if (!isConsumableType(ThisType))
    777     return;
    778 
    779   // FIXME: What should happen if someone annotates the move constructor?
    780   if (ReturnTypestateAttr *RTA = Constructor->getAttr<ReturnTypestateAttr>()) {
    781     // TODO: Adjust state of args appropriately.
    782     ConsumedState RetState = mapReturnTypestateAttrState(RTA);
    783     PropagationMap.insert(PairType(Call, PropagationInfo(RetState)));
    784   } else if (Constructor->isDefaultConstructor()) {
    785     PropagationMap.insert(PairType(Call,
    786       PropagationInfo(consumed::CS_Consumed)));
    787   } else if (Constructor->isMoveConstructor()) {
    788     copyInfo(Call->getArg(0), Call, CS_Consumed);
    789   } else if (Constructor->isCopyConstructor()) {
    790     // Copy state from arg.  If setStateOnRead then set arg to CS_Unknown.
    791     ConsumedState NS =
    792       isSetOnReadPtrType(Constructor->getThisType()) ?
    793       CS_Unknown : CS_None;
    794     copyInfo(Call->getArg(0), Call, NS);
    795   } else {
    796     // TODO: Adjust state of args appropriately.
    797     ConsumedState RetState = mapConsumableAttrState(ThisType);
    798     PropagationMap.insert(PairType(Call, PropagationInfo(RetState)));
    799   }
    800 }
    801 
    802 void ConsumedStmtVisitor::VisitCXXMemberCallExpr(
    803     const CXXMemberCallExpr *Call) {
    804   CXXMethodDecl* MD = Call->getMethodDecl();
    805   if (!MD)
    806     return;
    807 
    808   handleCall(Call, Call->getImplicitObjectArgument(), MD);
    809   propagateReturnType(Call, MD);
    810 }
    811 
    812 void ConsumedStmtVisitor::VisitCXXOperatorCallExpr(
    813     const CXXOperatorCallExpr *Call) {
    814   const auto *FunDecl = dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee());
    815   if (!FunDecl) return;
    816 
    817   if (Call->getOperator() == OO_Equal) {
    818     ConsumedState CS = getInfo(Call->getArg(1));
    819     if (!handleCall(Call, Call->getArg(0), FunDecl))
    820       setInfo(Call->getArg(0), CS);
    821     return;
    822   }
    823 
    824   if (const auto *MCall = dyn_cast<CXXMemberCallExpr>(Call))
    825     handleCall(MCall, MCall->getImplicitObjectArgument(), FunDecl);
    826   else
    827     handleCall(Call, Call->getArg(0), FunDecl);
    828 
    829   propagateReturnType(Call, FunDecl);
    830 }
    831 
    832 void ConsumedStmtVisitor::VisitDeclRefExpr(const DeclRefExpr *DeclRef) {
    833   if (const auto *Var = dyn_cast_or_null<VarDecl>(DeclRef->getDecl()))
    834     if (StateMap->getState(Var) != consumed::CS_None)
    835       PropagationMap.insert(PairType(DeclRef, PropagationInfo(Var)));
    836 }
    837 
    838 void ConsumedStmtVisitor::VisitDeclStmt(const DeclStmt *DeclS) {
    839   for (const auto *DI : DeclS->decls())
    840     if (isa<VarDecl>(DI))
    841       VisitVarDecl(cast<VarDecl>(DI));
    842 
    843   if (DeclS->isSingleDecl())
    844     if (const auto *Var = dyn_cast_or_null<VarDecl>(DeclS->getSingleDecl()))
    845       PropagationMap.insert(PairType(DeclS, PropagationInfo(Var)));
    846 }
    847 
    848 void ConsumedStmtVisitor::VisitMaterializeTemporaryExpr(
    849   const MaterializeTemporaryExpr *Temp) {
    850   forwardInfo(Temp->getSubExpr(), Temp);
    851 }
    852 
    853 void ConsumedStmtVisitor::VisitMemberExpr(const MemberExpr *MExpr) {
    854   forwardInfo(MExpr->getBase(), MExpr);
    855 }
    856 
    857 void ConsumedStmtVisitor::VisitParmVarDecl(const ParmVarDecl *Param) {
    858   QualType ParamType = Param->getType();
    859   ConsumedState ParamState = consumed::CS_None;
    860 
    861   if (const ParamTypestateAttr *PTA = Param->getAttr<ParamTypestateAttr>())
    862     ParamState = mapParamTypestateAttrState(PTA);
    863   else if (isConsumableType(ParamType))
    864     ParamState = mapConsumableAttrState(ParamType);
    865   else if (isRValueRef(ParamType) &&
    866            isConsumableType(ParamType->getPointeeType()))
    867     ParamState = mapConsumableAttrState(ParamType->getPointeeType());
    868   else if (ParamType->isReferenceType() &&
    869            isConsumableType(ParamType->getPointeeType()))
    870     ParamState = consumed::CS_Unknown;
    871 
    872   if (ParamState != CS_None)
    873     StateMap->setState(Param, ParamState);
    874 }
    875 
    876 void ConsumedStmtVisitor::VisitReturnStmt(const ReturnStmt *Ret) {
    877   ConsumedState ExpectedState = Analyzer.getExpectedReturnState();
    878 
    879   if (ExpectedState != CS_None) {
    880     InfoEntry Entry = findInfo(Ret->getRetValue());
    881 
    882     if (Entry != PropagationMap.end()) {
    883       ConsumedState RetState = Entry->second.getAsState(StateMap);
    884 
    885       if (RetState != ExpectedState)
    886         Analyzer.WarningsHandler.warnReturnTypestateMismatch(
    887           Ret->getReturnLoc(), stateToString(ExpectedState),
    888           stateToString(RetState));
    889     }
    890   }
    891 
    892   StateMap->checkParamsForReturnTypestate(Ret->getBeginLoc(),
    893                                           Analyzer.WarningsHandler);
    894 }
    895 
    896 void ConsumedStmtVisitor::VisitUnaryOperator(const UnaryOperator *UOp) {
    897   InfoEntry Entry = findInfo(UOp->getSubExpr());
    898   if (Entry == PropagationMap.end()) return;
    899 
    900   switch (UOp->getOpcode()) {
    901   case UO_AddrOf:
    902     PropagationMap.insert(PairType(UOp, Entry->second));
    903     break;
    904 
    905   case UO_LNot:
    906     if (Entry->second.isTest())
    907       PropagationMap.insert(PairType(UOp, Entry->second.invertTest()));
    908     break;
    909 
    910   default:
    911     break;
    912   }
    913 }
    914 
    915 // TODO: See if I need to check for reference types here.
    916 void ConsumedStmtVisitor::VisitVarDecl(const VarDecl *Var) {
    917   if (isConsumableType(Var->getType())) {
    918     if (Var->hasInit()) {
    919       MapType::iterator VIT = findInfo(Var->getInit()->IgnoreImplicit());
    920       if (VIT != PropagationMap.end()) {
    921         PropagationInfo PInfo = VIT->second;
    922         ConsumedState St = PInfo.getAsState(StateMap);
    923 
    924         if (St != consumed::CS_None) {
    925           StateMap->setState(Var, St);
    926           return;
    927         }
    928       }
    929     }
    930     // Otherwise
    931     StateMap->setState(Var, consumed::CS_Unknown);
    932   }
    933 }
    934 
    935 static void splitVarStateForIf(const IfStmt *IfNode, const VarTestResult &Test,
    936                                ConsumedStateMap *ThenStates,
    937                                ConsumedStateMap *ElseStates) {
    938   ConsumedState VarState = ThenStates->getState(Test.Var);
    939 
    940   if (VarState == CS_Unknown) {
    941     ThenStates->setState(Test.Var, Test.TestsFor);
    942     ElseStates->setState(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
    943   } else if (VarState == invertConsumedUnconsumed(Test.TestsFor)) {
    944     ThenStates->markUnreachable();
    945   } else if (VarState == Test.TestsFor) {
    946     ElseStates->markUnreachable();
    947   }
    948 }
    949 
/// Refine variable state along the two branches of an `if` whose condition is
/// a binary test combining up to two variable tests (LTest and RTest).
///
/// Current states are read from ThenStates.  splitState passes the current
/// state map as ThenStates and a copy of it as ElseStates.
///
/// When PInfo.testEffectiveOp() is EO_And:
///  - the then branch may assume each tested state;
///  - a contradicted test makes the then branch unreachable.
/// Otherwise (presumably EO_Or):
///  - the else branch may assume each inverted state;
///  - a satisfied test makes the else branch unreachable.
static void splitVarStateForIfBinOp(const PropagationInfo &PInfo,
                                    ConsumedStateMap *ThenStates,
                                    ConsumedStateMap *ElseStates) {
  const VarTestResult &LTest = PInfo.getLTest(),
                      &RTest = PInfo.getRTest();

  // CS_None when the corresponding side has no variable test.
  ConsumedState LState = LTest.Var ? ThenStates->getState(LTest.Var) : CS_None,
                RState = RTest.Var ? ThenStates->getState(RTest.Var) : CS_None;

  if (LTest.Var) {
    if (PInfo.testEffectiveOp() == EO_And) {
      if (LState == CS_Unknown) {
        ThenStates->setState(LTest.Var, LTest.TestsFor);
      } else if (LState == invertConsumedUnconsumed(LTest.TestsFor)) {
        // Left operand is known false, so the conjunction cannot hold.
        ThenStates->markUnreachable();
      } else if (LState == LTest.TestsFor && isKnownState(RState)) {
        // Left operand is known true; the right operand decides the outcome.
        if (RState == RTest.TestsFor)
          ElseStates->markUnreachable();
        else
          ThenStates->markUnreachable();
      }
    } else {
      if (LState == CS_Unknown) {
        ElseStates->setState(LTest.Var,
                             invertConsumedUnconsumed(LTest.TestsFor));
      } else if (LState == LTest.TestsFor) {
        // Left operand is known true, so the disjunction cannot fail.
        ElseStates->markUnreachable();
      } else if (LState == invertConsumedUnconsumed(LTest.TestsFor) &&
                 isKnownState(RState)) {
        // Left operand is known false; the right operand decides the outcome.
        if (RState == RTest.TestsFor)
          ElseStates->markUnreachable();
        else
          ThenStates->markUnreachable();
      }
    }
  }

  if (RTest.Var) {
    if (PInfo.testEffectiveOp() == EO_And) {
      if (RState == CS_Unknown)
        ThenStates->setState(RTest.Var, RTest.TestsFor);
      else if (RState == invertConsumedUnconsumed(RTest.TestsFor))
        ThenStates->markUnreachable();
    } else {
      if (RState == CS_Unknown)
        ElseStates->setState(RTest.Var,
                             invertConsumedUnconsumed(RTest.TestsFor));
      else if (RState == RTest.TestsFor)
        ElseStates->markUnreachable();
    }
  }
}
   1002 
   1003 bool ConsumedBlockInfo::allBackEdgesVisited(const CFGBlock *CurrBlock,
   1004                                             const CFGBlock *TargetBlock) {
   1005   assert(CurrBlock && "Block pointer must not be NULL");
   1006   assert(TargetBlock && "TargetBlock pointer must not be NULL");
   1007 
   1008   unsigned int CurrBlockOrder = VisitOrder[CurrBlock->getBlockID()];
   1009   for (CFGBlock::const_pred_iterator PI = TargetBlock->pred_begin(),
   1010        PE = TargetBlock->pred_end(); PI != PE; ++PI) {
   1011     if (*PI && CurrBlockOrder < VisitOrder[(*PI)->getBlockID()] )
   1012       return false;
   1013   }
   1014   return true;
   1015 }
   1016 
   1017 void ConsumedBlockInfo::addInfo(
   1018     const CFGBlock *Block, ConsumedStateMap *StateMap,
   1019     std::unique_ptr<ConsumedStateMap> &OwnedStateMap) {
   1020   assert(Block && "Block pointer must not be NULL");
   1021 
   1022   auto &Entry = StateMapsArray[Block->getBlockID()];
   1023 
   1024   if (Entry) {
   1025     Entry->intersect(*StateMap);
   1026   } else if (OwnedStateMap)
   1027     Entry = std::move(OwnedStateMap);
   1028   else
   1029     Entry = std::make_unique<ConsumedStateMap>(*StateMap);
   1030 }
   1031 
   1032 void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
   1033                                 std::unique_ptr<ConsumedStateMap> StateMap) {
   1034   assert(Block && "Block pointer must not be NULL");
   1035 
   1036   auto &Entry = StateMapsArray[Block->getBlockID()];
   1037 
   1038   if (Entry) {
   1039     Entry->intersect(*StateMap);
   1040   } else {
   1041     Entry = std::move(StateMap);
   1042   }
   1043 }
   1044 
   1045 ConsumedStateMap* ConsumedBlockInfo::borrowInfo(const CFGBlock *Block) {
   1046   assert(Block && "Block pointer must not be NULL");
   1047   assert(StateMapsArray[Block->getBlockID()] && "Block has no block info");
   1048 
   1049   return StateMapsArray[Block->getBlockID()].get();
   1050 }
   1051 
   1052 void ConsumedBlockInfo::discardInfo(const CFGBlock *Block) {
   1053   StateMapsArray[Block->getBlockID()] = nullptr;
   1054 }
   1055 
   1056 std::unique_ptr<ConsumedStateMap>
   1057 ConsumedBlockInfo::getInfo(const CFGBlock *Block) {
   1058   assert(Block && "Block pointer must not be NULL");
   1059 
   1060   auto &Entry = StateMapsArray[Block->getBlockID()];
   1061   return isBackEdgeTarget(Block) ? std::make_unique<ConsumedStateMap>(*Entry)
   1062                                  : std::move(Entry);
   1063 }
   1064 
   1065 bool ConsumedBlockInfo::isBackEdge(const CFGBlock *From, const CFGBlock *To) {
   1066   assert(From && "From block must not be NULL");
   1067   assert(To   && "From block must not be NULL");
   1068 
   1069   return VisitOrder[From->getBlockID()] > VisitOrder[To->getBlockID()];
   1070 }
   1071 
   1072 bool ConsumedBlockInfo::isBackEdgeTarget(const CFGBlock *Block) {
   1073   assert(Block && "Block pointer must not be NULL");
   1074 
   1075   // Anything with less than two predecessors can't be the target of a back
   1076   // edge.
   1077   if (Block->pred_size() < 2)
   1078     return false;
   1079 
   1080   unsigned int BlockVisitOrder = VisitOrder[Block->getBlockID()];
   1081   for (CFGBlock::const_pred_iterator PI = Block->pred_begin(),
   1082        PE = Block->pred_end(); PI != PE; ++PI) {
   1083     if (*PI && BlockVisitOrder < VisitOrder[(*PI)->getBlockID()])
   1084       return true;
   1085   }
   1086   return false;
   1087 }
   1088 
   1089 void ConsumedStateMap::checkParamsForReturnTypestate(SourceLocation BlameLoc,
   1090   ConsumedWarningsHandlerBase &WarningsHandler) const {
   1091 
   1092   for (const auto &DM : VarMap) {
   1093     if (isa<ParmVarDecl>(DM.first)) {
   1094       const auto *Param = cast<ParmVarDecl>(DM.first);
   1095       const ReturnTypestateAttr *RTA = Param->getAttr<ReturnTypestateAttr>();
   1096 
   1097       if (!RTA)
   1098         continue;
   1099 
   1100       ConsumedState ExpectedState = mapReturnTypestateAttrState(RTA);
   1101       if (DM.second != ExpectedState)
   1102         WarningsHandler.warnParamReturnTypestateMismatch(BlameLoc,
   1103           Param->getNameAsString(), stateToString(ExpectedState),
   1104           stateToString(DM.second));
   1105     }
   1106   }
   1107 }
   1108 
   1109 void ConsumedStateMap::clearTemporaries() {
   1110   TmpMap.clear();
   1111 }
   1112 
   1113 ConsumedState ConsumedStateMap::getState(const VarDecl *Var) const {
   1114   VarMapType::const_iterator Entry = VarMap.find(Var);
   1115 
   1116   if (Entry != VarMap.end())
   1117     return Entry->second;
   1118 
   1119   return CS_None;
   1120 }
   1121 
   1122 ConsumedState
   1123 ConsumedStateMap::getState(const CXXBindTemporaryExpr *Tmp) const {
   1124   TmpMapType::const_iterator Entry = TmpMap.find(Tmp);
   1125 
   1126   if (Entry != TmpMap.end())
   1127     return Entry->second;
   1128 
   1129   return CS_None;
   1130 }
   1131 
   1132 void ConsumedStateMap::intersect(const ConsumedStateMap &Other) {
   1133   ConsumedState LocalState;
   1134 
   1135   if (this->From && this->From == Other.From && !Other.Reachable) {
   1136     this->markUnreachable();
   1137     return;
   1138   }
   1139 
   1140   for (const auto &DM : Other.VarMap) {
   1141     LocalState = this->getState(DM.first);
   1142 
   1143     if (LocalState == CS_None)
   1144       continue;
   1145 
   1146     if (LocalState != DM.second)
   1147      VarMap[DM.first] = CS_Unknown;
   1148   }
   1149 }
   1150 
   1151 void ConsumedStateMap::intersectAtLoopHead(const CFGBlock *LoopHead,
   1152   const CFGBlock *LoopBack, const ConsumedStateMap *LoopBackStates,
   1153   ConsumedWarningsHandlerBase &WarningsHandler) {
   1154 
   1155   ConsumedState LocalState;
   1156   SourceLocation BlameLoc = getLastStmtLoc(LoopBack);
   1157 
   1158   for (const auto &DM : LoopBackStates->VarMap) {
   1159     LocalState = this->getState(DM.first);
   1160 
   1161     if (LocalState == CS_None)
   1162       continue;
   1163 
   1164     if (LocalState != DM.second) {
   1165       VarMap[DM.first] = CS_Unknown;
   1166       WarningsHandler.warnLoopStateMismatch(BlameLoc,
   1167                                             DM.first->getNameAsString());
   1168     }
   1169   }
   1170 }
   1171 
   1172 void ConsumedStateMap::markUnreachable() {
   1173   this->Reachable = false;
   1174   VarMap.clear();
   1175   TmpMap.clear();
   1176 }
   1177 
   1178 void ConsumedStateMap::setState(const VarDecl *Var, ConsumedState State) {
   1179   VarMap[Var] = State;
   1180 }
   1181 
   1182 void ConsumedStateMap::setState(const CXXBindTemporaryExpr *Tmp,
   1183                                 ConsumedState State) {
   1184   TmpMap[Tmp] = State;
   1185 }
   1186 
   1187 void ConsumedStateMap::remove(const CXXBindTemporaryExpr *Tmp) {
   1188   TmpMap.erase(Tmp);
   1189 }
   1190 
   1191 bool ConsumedStateMap::operator!=(const ConsumedStateMap *Other) const {
   1192   for (const auto &DM : Other->VarMap)
   1193     if (this->getState(DM.first) != DM.second)
   1194       return true;
   1195   return false;
   1196 }
   1197 
   1198 void ConsumedAnalyzer::determineExpectedReturnState(AnalysisDeclContext &AC,
   1199                                                     const FunctionDecl *D) {
   1200   QualType ReturnType;
   1201   if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(D)) {
   1202     ReturnType = Constructor->getThisType()->getPointeeType();
   1203   } else
   1204     ReturnType = D->getCallResultType();
   1205 
   1206   if (const ReturnTypestateAttr *RTSAttr = D->getAttr<ReturnTypestateAttr>()) {
   1207     const CXXRecordDecl *RD = ReturnType->getAsCXXRecordDecl();
   1208     if (!RD || !RD->hasAttr<ConsumableAttr>()) {
   1209       // FIXME: This should be removed when template instantiation propagates
   1210       //        attributes at template specialization definition, not
   1211       //        declaration. When it is removed the test needs to be enabled
   1212       //        in SemaDeclAttr.cpp.
   1213       WarningsHandler.warnReturnTypestateForUnconsumableType(
   1214           RTSAttr->getLocation(), ReturnType.getAsString());
   1215       ExpectedReturnState = CS_None;
   1216     } else
   1217       ExpectedReturnState = mapReturnTypestateAttrState(RTSAttr);
   1218   } else if (isConsumableType(ReturnType)) {
   1219     if (isAutoCastType(ReturnType))   // We can auto-cast the state to the
   1220       ExpectedReturnState = CS_None;  // expected state.
   1221     else
   1222       ExpectedReturnState = mapConsumableAttrState(ReturnType);
   1223   }
   1224   else
   1225     ExpectedReturnState = CS_None;
   1226 }
   1227 
   1228 bool ConsumedAnalyzer::splitState(const CFGBlock *CurrBlock,
   1229                                   const ConsumedStmtVisitor &Visitor) {
   1230   std::unique_ptr<ConsumedStateMap> FalseStates(
   1231       new ConsumedStateMap(*CurrStates));
   1232   PropagationInfo PInfo;
   1233 
   1234   if (const auto *IfNode =
   1235           dyn_cast_or_null<IfStmt>(CurrBlock->getTerminator().getStmt())) {
   1236     const Expr *Cond = IfNode->getCond();
   1237 
   1238     PInfo = Visitor.getInfo(Cond);
   1239     if (!PInfo.isValid() && isa<BinaryOperator>(Cond))
   1240       PInfo = Visitor.getInfo(cast<BinaryOperator>(Cond)->getRHS());
   1241 
   1242     if (PInfo.isVarTest()) {
   1243       CurrStates->setSource(Cond);
   1244       FalseStates->setSource(Cond);
   1245       splitVarStateForIf(IfNode, PInfo.getVarTest(), CurrStates.get(),
   1246                          FalseStates.get());
   1247     } else if (PInfo.isBinTest()) {
   1248       CurrStates->setSource(PInfo.testSourceNode());
   1249       FalseStates->setSource(PInfo.testSourceNode());
   1250       splitVarStateForIfBinOp(PInfo, CurrStates.get(), FalseStates.get());
   1251     } else {
   1252       return false;
   1253     }
   1254   } else if (const auto *BinOp =
   1255        dyn_cast_or_null<BinaryOperator>(CurrBlock->getTerminator().getStmt())) {
   1256     PInfo = Visitor.getInfo(BinOp->getLHS());
   1257     if (!PInfo.isVarTest()) {
   1258       if ((BinOp = dyn_cast_or_null<BinaryOperator>(BinOp->getLHS()))) {
   1259         PInfo = Visitor.getInfo(BinOp->getRHS());
   1260 
   1261         if (!PInfo.isVarTest())
   1262           return false;
   1263       } else {
   1264         return false;
   1265       }
   1266     }
   1267 
   1268     CurrStates->setSource(BinOp);
   1269     FalseStates->setSource(BinOp);
   1270 
   1271     const VarTestResult &Test = PInfo.getVarTest();
   1272     ConsumedState VarState = CurrStates->getState(Test.Var);
   1273 
   1274     if (BinOp->getOpcode() == BO_LAnd) {
   1275       if (VarState == CS_Unknown)
   1276         CurrStates->setState(Test.Var, Test.TestsFor);
   1277       else if (VarState == invertConsumedUnconsumed(Test.TestsFor))
   1278         CurrStates->markUnreachable();
   1279 
   1280     } else if (BinOp->getOpcode() == BO_LOr) {
   1281       if (VarState == CS_Unknown)
   1282         FalseStates->setState(Test.Var,
   1283                               invertConsumedUnconsumed(Test.TestsFor));
   1284       else if (VarState == Test.TestsFor)
   1285         FalseStates->markUnreachable();
   1286     }
   1287   } else {
   1288     return false;
   1289   }
   1290 
   1291   CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin();
   1292 
   1293   if (*SI)
   1294     BlockInfo.addInfo(*SI, std::move(CurrStates));
   1295   else
   1296     CurrStates = nullptr;
   1297 
   1298   if (*++SI)
   1299     BlockInfo.addInfo(*SI, std::move(FalseStates));
   1300 
   1301   return true;
   1302 }
   1303 
/// Run the consumed analysis over the function described by \p AC.
///
/// Does nothing for declarations that are not functions, or when no CFG is
/// available.  Otherwise:
///  1. Computes the expected return state.
///  2. Seeds trackable parameters into an initial state map.
///  3. Visits CFG blocks in PostOrderCFGView order.  Each block's statements
///     and implicit destructor calls are checked.
///  4. Passes the resulting state to successors.  A recognized conditional is
///     split (splitState); back edges are merged into the loop head
///     (intersectAtLoopHead).
///  5. Emits the collected diagnostics at the end.
void ConsumedAnalyzer::run(AnalysisDeclContext &AC) {
  const auto *D = dyn_cast_or_null<FunctionDecl>(AC.getDecl());
  if (!D)
    return;

  CFG *CFGraph = AC.getCFG();
  if (!CFGraph)
    return;

  determineExpectedReturnState(AC, D);

  PostOrderCFGView *SortedGraph = AC.getAnalysis<PostOrderCFGView>();
  // AC.getCFG()->viewCFG(LangOptions());

  BlockInfo = ConsumedBlockInfo(CFGraph->getNumBlockIDs(), SortedGraph);

  CurrStates = std::make_unique<ConsumedStateMap>();
  ConsumedStmtVisitor Visitor(*this, CurrStates.get());

  // Add all trackable parameters to the state map.
  for (const auto *PI : D->parameters())
    Visitor.VisitParmVarDecl(PI);

  // Visit all of the function's basic blocks.
  for (const auto *CurrBlock : *SortedGraph) {
    // CurrStates is still set when the previous block flowed straight into
    // this one (see the propagation logic below); otherwise fetch the entry
    // state recorded for this block.
    if (!CurrStates)
      CurrStates = BlockInfo.getInfo(CurrBlock);

    // Blocks with no recorded state, or an unreachable one, are skipped and
    // propagate nothing to their successors.
    if (!CurrStates) {
      continue;
    } else if (!CurrStates->isReachable()) {
      CurrStates = nullptr;
      continue;
    }

    Visitor.reset(CurrStates.get());

    // Visit all of the basic block's statements.
    for (const auto &B : *CurrBlock) {
      switch (B.getKind()) {
      case CFGElement::Statement:
        Visitor.Visit(B.castAs<CFGStmt>().getStmt());
        break;

      case CFGElement::TemporaryDtor: {
        // Check the temporary's destructor call, then stop tracking it.
        const CFGTemporaryDtor &DTor = B.castAs<CFGTemporaryDtor>();
        const CXXBindTemporaryExpr *BTE = DTor.getBindTemporaryExpr();

        Visitor.checkCallability(PropagationInfo(BTE),
                                 DTor.getDestructorDecl(AC.getASTContext()),
                                 BTE->getExprLoc());
        CurrStates->remove(BTE);
        break;
      }

      case CFGElement::AutomaticObjectDtor: {
        // Check the implicit destructor call of a local variable, blamed on
        // the end of the statement that triggers it.
        const CFGAutomaticObjDtor &DTor = B.castAs<CFGAutomaticObjDtor>();
        SourceLocation Loc = DTor.getTriggerStmt()->getEndLoc();
        const VarDecl *Var = DTor.getVarDecl();

        Visitor.checkCallability(PropagationInfo(Var),
                                 DTor.getDestructorDecl(AC.getASTContext()),
                                 Loc);
        break;
      }

      default:
        break;
      }
    }

    // TODO: Handle other forms of branching with precision, including while-
    //       and for-loops. (Deferred)
    if (!splitState(CurrBlock, Visitor)) {
      CurrStates->setSource(nullptr);

      // Publish the state to successors, unless it can flow directly into a
      // single successor that has no other predecessor.  In that case
      // CurrStates is kept and reused for the next visited block.
      if (CurrBlock->succ_size() > 1 ||
          (CurrBlock->succ_size() == 1 &&
           (*CurrBlock->succ_begin())->pred_size() > 1)) {

        // addInfo may move CurrStates into a successor's entry.  RawState
        // still points at the map, which is now owned by that entry.
        auto *RawState = CurrStates.get();

        for (CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin(),
             SE = CurrBlock->succ_end(); SI != SE; ++SI) {
          if (*SI == nullptr) continue;

          if (BlockInfo.isBackEdge(CurrBlock, *SI)) {
            // Back edge: merge into the loop head's stored state, warning on
            // mismatches.  Once every edge into the loop head has been
            // processed, its stored state is no longer needed.
            BlockInfo.borrowInfo(*SI)->intersectAtLoopHead(
                *SI, CurrBlock, RawState, WarningsHandler);

            if (BlockInfo.allBackEdgesVisited(CurrBlock, *SI))
              BlockInfo.discardInfo(*SI);
          } else {
            BlockInfo.addInfo(*SI, RawState, CurrStates);
          }
        }

        CurrStates = nullptr;
      }
    }

    // A void function may reach its exit without a return statement, so
    // parameters' return typestates are also checked at the exit block.
    if (CurrBlock == &AC.getCFG()->getExit() &&
        D->getCallResultType()->isVoidType())
      CurrStates->checkParamsForReturnTypestate(D->getLocation(),
                                                WarningsHandler);
  } // End of block iterator.

  // Delete the last existing state map.
  CurrStates = nullptr;

  WarningsHandler.emitDiagnostics();
}
   1416