Home | History | Annotate | Line # | Download | only in ARCMigrate
      1 //===--- TransProtectedScope.cpp - Transformations to ARC mode ------------===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
// Adds braces around case statements that "contain" the initialization of a
// retaining variable and would therefore trigger the "switch case is in
// protected scope" error.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #include "Internals.h"
     15 #include "Transforms.h"
     16 #include "clang/AST/ASTContext.h"
     17 #include "clang/Basic/SourceManager.h"
     18 #include "clang/Sema/SemaDiagnostic.h"
     19 
     20 using namespace clang;
     21 using namespace arcmt;
     22 using namespace trans;
     23 
     24 namespace {
     25 
     26 class LocalRefsCollector : public RecursiveASTVisitor<LocalRefsCollector> {
     27   SmallVectorImpl<DeclRefExpr *> &Refs;
     28 
     29 public:
     30   LocalRefsCollector(SmallVectorImpl<DeclRefExpr *> &refs)
     31     : Refs(refs) { }
     32 
     33   bool VisitDeclRefExpr(DeclRefExpr *E) {
     34     if (ValueDecl *D = E->getDecl())
     35       if (D->getDeclContext()->getRedeclContext()->isFunctionOrMethod())
     36         Refs.push_back(E);
     37     return true;
     38   }
     39 };
     40 
     41 struct CaseInfo {
     42   SwitchCase *SC;
     43   SourceRange Range;
     44   enum {
     45     St_Unchecked,
     46     St_CannotFix,
     47     St_Fixed
     48   } State;
     49 
     50   CaseInfo() : SC(nullptr), State(St_Unchecked) {}
     51   CaseInfo(SwitchCase *S, SourceRange Range)
     52     : SC(S), Range(Range), State(St_Unchecked) {}
     53 };
     54 
     55 class CaseCollector : public RecursiveASTVisitor<CaseCollector> {
     56   ParentMap &PMap;
     57   SmallVectorImpl<CaseInfo> &Cases;
     58 
     59 public:
     60   CaseCollector(ParentMap &PMap, SmallVectorImpl<CaseInfo> &Cases)
     61     : PMap(PMap), Cases(Cases) { }
     62 
     63   bool VisitSwitchStmt(SwitchStmt *S) {
     64     SwitchCase *Curr = S->getSwitchCaseList();
     65     if (!Curr)
     66       return true;
     67     Stmt *Parent = getCaseParent(Curr);
     68     Curr = Curr->getNextSwitchCase();
     69     // Make sure all case statements are in the same scope.
     70     while (Curr) {
     71       if (getCaseParent(Curr) != Parent)
     72         return true;
     73       Curr = Curr->getNextSwitchCase();
     74     }
     75 
     76     SourceLocation NextLoc = S->getEndLoc();
     77     Curr = S->getSwitchCaseList();
     78     // We iterate over case statements in reverse source-order.
     79     while (Curr) {
     80       Cases.push_back(
     81           CaseInfo(Curr, SourceRange(Curr->getBeginLoc(), NextLoc)));
     82       NextLoc = Curr->getBeginLoc();
     83       Curr = Curr->getNextSwitchCase();
     84     }
     85     return true;
     86   }
     87 
     88   Stmt *getCaseParent(SwitchCase *S) {
     89     Stmt *Parent = PMap.getParent(S);
     90     while (Parent && (isa<SwitchCase>(Parent) || isa<LabelStmt>(Parent)))
     91       Parent = PMap.getParent(Parent);
     92     return Parent;
     93   }
     94 };
     95 
     96 class ProtectedScopeFixer {
     97   MigrationPass &Pass;
     98   SourceManager &SM;
     99   SmallVector<CaseInfo, 16> Cases;
    100   SmallVector<DeclRefExpr *, 16> LocalRefs;
    101 
    102 public:
    103   ProtectedScopeFixer(BodyContext &BodyCtx)
    104     : Pass(BodyCtx.getMigrationContext().Pass),
    105       SM(Pass.Ctx.getSourceManager()) {
    106 
    107     CaseCollector(BodyCtx.getParentMap(), Cases)
    108         .TraverseStmt(BodyCtx.getTopStmt());
    109     LocalRefsCollector(LocalRefs).TraverseStmt(BodyCtx.getTopStmt());
    110 
    111     SourceRange BodyRange = BodyCtx.getTopStmt()->getSourceRange();
    112     const CapturedDiagList &DiagList = Pass.getDiags();
    113     // Copy the diagnostics so we don't have to worry about invaliding iterators
    114     // from the diagnostic list.
    115     SmallVector<StoredDiagnostic, 16> StoredDiags;
    116     StoredDiags.append(DiagList.begin(), DiagList.end());
    117     SmallVectorImpl<StoredDiagnostic>::iterator
    118         I = StoredDiags.begin(), E = StoredDiags.end();
    119     while (I != E) {
    120       if (I->getID() == diag::err_switch_into_protected_scope &&
    121           isInRange(I->getLocation(), BodyRange)) {
    122         handleProtectedScopeError(I, E);
    123         continue;
    124       }
    125       ++I;
    126     }
    127   }
    128 
    129   void handleProtectedScopeError(
    130                              SmallVectorImpl<StoredDiagnostic>::iterator &DiagI,
    131                              SmallVectorImpl<StoredDiagnostic>::iterator DiagE){
    132     Transaction Trans(Pass.TA);
    133     assert(DiagI->getID() == diag::err_switch_into_protected_scope);
    134     SourceLocation ErrLoc = DiagI->getLocation();
    135     bool handledAllNotes = true;
    136     ++DiagI;
    137     for (; DiagI != DiagE && DiagI->getLevel() == DiagnosticsEngine::Note;
    138          ++DiagI) {
    139       if (!handleProtectedNote(*DiagI))
    140         handledAllNotes = false;
    141     }
    142 
    143     if (handledAllNotes)
    144       Pass.TA.clearDiagnostic(diag::err_switch_into_protected_scope, ErrLoc);
    145   }
    146 
    147   bool handleProtectedNote(const StoredDiagnostic &Diag) {
    148     assert(Diag.getLevel() == DiagnosticsEngine::Note);
    149 
    150     for (unsigned i = 0; i != Cases.size(); i++) {
    151       CaseInfo &info = Cases[i];
    152       if (isInRange(Diag.getLocation(), info.Range)) {
    153 
    154         if (info.State == CaseInfo::St_Unchecked)
    155           tryFixing(info);
    156         assert(info.State != CaseInfo::St_Unchecked);
    157 
    158         if (info.State == CaseInfo::St_Fixed) {
    159           Pass.TA.clearDiagnostic(Diag.getID(), Diag.getLocation());
    160           return true;
    161         }
    162         return false;
    163       }
    164     }
    165 
    166     return false;
    167   }
    168 
    169   void tryFixing(CaseInfo &info) {
    170     assert(info.State == CaseInfo::St_Unchecked);
    171     if (hasVarReferencedOutside(info)) {
    172       info.State = CaseInfo::St_CannotFix;
    173       return;
    174     }
    175 
    176     Pass.TA.insertAfterToken(info.SC->getColonLoc(), " {");
    177     Pass.TA.insert(info.Range.getEnd(), "}\n");
    178     info.State = CaseInfo::St_Fixed;
    179   }
    180 
    181   bool hasVarReferencedOutside(CaseInfo &info) {
    182     for (unsigned i = 0, e = LocalRefs.size(); i != e; ++i) {
    183       DeclRefExpr *DRE = LocalRefs[i];
    184       if (isInRange(DRE->getDecl()->getLocation(), info.Range) &&
    185           !isInRange(DRE->getLocation(), info.Range))
    186         return true;
    187     }
    188     return false;
    189   }
    190 
    191   bool isInRange(SourceLocation Loc, SourceRange R) {
    192     if (Loc.isInvalid())
    193       return false;
    194     return !SM.isBeforeInTranslationUnit(Loc, R.getBegin()) &&
    195             SM.isBeforeInTranslationUnit(Loc, R.getEnd());
    196   }
    197 };
    198 
    199 } // anonymous namespace
    200 
    201 void ProtectedScopeTraverser::traverseBody(BodyContext &BodyCtx) {
    202   ProtectedScopeFixer Fix(BodyCtx);
    203 }
    204