Home | History | Annotate | Line # | Download | only in Core
      1 //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 ///
/// This file contains functions which are used to decide if a loop is worth
/// unrolling. Moreover, these functions manage the stack of loops which is
/// tracked by the ProgramState.
     12 ///
     13 //===----------------------------------------------------------------------===//
     14 
#include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
#include <algorithm>
#include <cstdint>
#include <limits>
     20 
     21 using namespace clang;
     22 using namespace ento;
     23 using namespace clang::ast_matchers;
     24 
// Upper bound on the estimated number of steps (multiplied across nested
// unrolled loops) up to which a loop is still unrolled completely.
static constexpr int MAXIMUM_STEP_UNROLLED = 128;
     26 
     27 struct LoopState {
     28 private:
     29   enum Kind { Normal, Unrolled } K;
     30   const Stmt *LoopStmt;
     31   const LocationContext *LCtx;
     32   unsigned maxStep;
     33   LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N)
     34       : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {}
     35 
     36 public:
     37   static LoopState getNormal(const Stmt *S, const LocationContext *L,
     38                              unsigned N) {
     39     return LoopState(Normal, S, L, N);
     40   }
     41   static LoopState getUnrolled(const Stmt *S, const LocationContext *L,
     42                                unsigned N) {
     43     return LoopState(Unrolled, S, L, N);
     44   }
     45   bool isUnrolled() const { return K == Unrolled; }
     46   unsigned getMaxStep() const { return maxStep; }
     47   const Stmt *getLoopStmt() const { return LoopStmt; }
     48   const LocationContext *getLocationContext() const { return LCtx; }
     49   bool operator==(const LoopState &X) const {
     50     return K == X.K && LoopStmt == X.LoopStmt;
     51   }
     52   void Profile(llvm::FoldingSetNodeID &ID) const {
     53     ID.AddInteger(K);
     54     ID.AddPointer(LoopStmt);
     55     ID.AddPointer(LCtx);
     56     ID.AddInteger(maxStep);
     57   }
     58 };
     59 
// The tracked stack of loops. It records which loops contain the element that
// is currently being simulated; the most recently entered loop is at the head.
// Each entry is marked depending on whether we decided to unroll that loop.
// TODO: The loop stack should not need to be in the program state since it is
// lexical in nature. Instead, the stack of loops should be tracked in the
// LocationContext.
REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState)
     67 
     68 namespace clang {
     69 namespace ento {
     70 
     71 static bool isLoopStmt(const Stmt *S) {
     72   return S && (isa<ForStmt>(S) || isa<WhileStmt>(S) || isa<DoStmt>(S));
     73 }
     74 
     75 ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) {
     76   auto LS = State->get<LoopStack>();
     77   if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt)
     78     State = State->set<LoopStack>(LS.getTail());
     79   return State;
     80 }
     81 
     82 static internal::Matcher<Stmt> simpleCondition(StringRef BindName) {
     83   return binaryOperator(anyOf(hasOperatorName("<"), hasOperatorName(">"),
     84                               hasOperatorName("<="), hasOperatorName(">="),
     85                               hasOperatorName("!=")),
     86                         hasEitherOperand(ignoringParenImpCasts(declRefExpr(
     87                             to(varDecl(hasType(isInteger())).bind(BindName))))),
     88                         hasEitherOperand(ignoringParenImpCasts(
     89                             integerLiteral().bind("boundNum"))))
     90       .bind("conditionOperator");
     91 }
     92 
     93 static internal::Matcher<Stmt>
     94 changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) {
     95   return anyOf(
     96       unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")),
     97                     hasUnaryOperand(ignoringParenImpCasts(
     98                         declRefExpr(to(varDecl(VarNodeMatcher)))))),
     99       binaryOperator(isAssignmentOperator(),
    100                      hasLHS(ignoringParenImpCasts(
    101                          declRefExpr(to(varDecl(VarNodeMatcher)))))));
    102 }
    103 
    104 static internal::Matcher<Stmt>
    105 callByRef(internal::Matcher<Decl> VarNodeMatcher) {
    106   return callExpr(forEachArgumentWithParam(
    107       declRefExpr(to(varDecl(VarNodeMatcher))),
    108       parmVarDecl(hasType(references(qualType(unless(isConstQualified())))))));
    109 }
    110 
    111 static internal::Matcher<Stmt>
    112 assignedToRef(internal::Matcher<Decl> VarNodeMatcher) {
    113   return declStmt(hasDescendant(varDecl(
    114       allOf(hasType(referenceType()),
    115             hasInitializer(anyOf(
    116                 initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))),
    117                 declRefExpr(to(varDecl(VarNodeMatcher)))))))));
    118 }
    119 
    120 static internal::Matcher<Stmt>
    121 getAddrTo(internal::Matcher<Decl> VarNodeMatcher) {
    122   return unaryOperator(
    123       hasOperatorName("&"),
    124       hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher))));
    125 }
    126 
    127 static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) {
    128   return hasDescendant(stmt(
    129       anyOf(gotoStmt(), switchStmt(), returnStmt(),
    130             // Escaping and not known mutation of the loop counter is handled
    131             // by exclusion of assigning and address-of operators and
    132             // pass-by-ref function calls on the loop counter from the body.
    133             changeIntBoundNode(equalsBoundNode(std::string(NodeName))),
    134             callByRef(equalsBoundNode(std::string(NodeName))),
    135             getAddrTo(equalsBoundNode(std::string(NodeName))),
    136             assignedToRef(equalsBoundNode(std::string(NodeName))))));
    137 }
    138 
    139 static internal::Matcher<Stmt> forLoopMatcher() {
    140   return forStmt(
    141              hasCondition(simpleCondition("initVarName")),
    142              // Initialization should match the form: 'int i = 6' or 'i = 42'.
    143              hasLoopInit(
    144                  anyOf(declStmt(hasSingleDecl(
    145                            varDecl(allOf(hasInitializer(ignoringParenImpCasts(
    146                                              integerLiteral().bind("initNum"))),
    147                                          equalsBoundNode("initVarName"))))),
    148                        binaryOperator(hasLHS(declRefExpr(to(varDecl(
    149                                           equalsBoundNode("initVarName"))))),
    150                                       hasRHS(ignoringParenImpCasts(
    151                                           integerLiteral().bind("initNum")))))),
    152              // Incrementation should be a simple increment or decrement
    153              // operator call.
    154              hasIncrement(unaryOperator(
    155                  anyOf(hasOperatorName("++"), hasOperatorName("--")),
    156                  hasUnaryOperand(declRefExpr(
    157                      to(varDecl(allOf(equalsBoundNode("initVarName"),
    158                                       hasType(isInteger())))))))),
    159              unless(hasBody(hasSuspiciousStmt("initVarName")))).bind("forLoop");
    160 }
    161 
    162 static bool isPossiblyEscaped(const VarDecl *VD, ExplodedNode *N) {
    163   // Global variables assumed as escaped variables.
    164   if (VD->hasGlobalStorage())
    165     return true;
    166 
    167   const bool isParm = isa<ParmVarDecl>(VD);
    168   // Reference parameters are assumed as escaped variables.
    169   if (isParm && VD->getType()->isReferenceType())
    170     return true;
    171 
    172   while (!N->pred_empty()) {
    173     // FIXME: getStmtForDiagnostics() does nasty things in order to provide
    174     // a valid statement for body farms, do we need this behavior here?
    175     const Stmt *S = N->getStmtForDiagnostics();
    176     if (!S) {
    177       N = N->getFirstPred();
    178       continue;
    179     }
    180 
    181     if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) {
    182       for (const Decl *D : DS->decls()) {
    183         // Once we reach the declaration of the VD we can return.
    184         if (D->getCanonicalDecl() == VD)
    185           return false;
    186       }
    187     }
    188     // Check the usage of the pass-by-ref function calls and adress-of operator
    189     // on VD and reference initialized by VD.
    190     ASTContext &ASTCtx =
    191         N->getLocationContext()->getAnalysisDeclContext()->getASTContext();
    192     auto Match =
    193         match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)),
    194                          assignedToRef(equalsNode(VD)))),
    195               *S, ASTCtx);
    196     if (!Match.empty())
    197       return true;
    198 
    199     N = N->getFirstPred();
    200   }
    201 
    202   // Parameter declaration will not be found.
    203   if (isParm)
    204     return false;
    205 
    206   llvm_unreachable("Reached root without finding the declaration of VD");
    207 }
    208 
    209 bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx,
    210                             ExplodedNode *Pred, unsigned &maxStep) {
    211 
    212   if (!isLoopStmt(LoopStmt))
    213     return false;
    214 
    215   // TODO: Match the cases where the bound is not a concrete literal but an
    216   // integer with known value
    217   auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx);
    218   if (Matches.empty())
    219     return false;
    220 
    221   auto CounterVar = Matches[0].getNodeAs<VarDecl>("initVarName");
    222   llvm::APInt BoundNum =
    223       Matches[0].getNodeAs<IntegerLiteral>("boundNum")->getValue();
    224   llvm::APInt InitNum =
    225       Matches[0].getNodeAs<IntegerLiteral>("initNum")->getValue();
    226   auto CondOp = Matches[0].getNodeAs<BinaryOperator>("conditionOperator");
    227   if (InitNum.getBitWidth() != BoundNum.getBitWidth()) {
    228     InitNum = InitNum.zextOrSelf(BoundNum.getBitWidth());
    229     BoundNum = BoundNum.zextOrSelf(InitNum.getBitWidth());
    230   }
    231 
    232   if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE)
    233     maxStep = (BoundNum - InitNum + 1).abs().getZExtValue();
    234   else
    235     maxStep = (BoundNum - InitNum).abs().getZExtValue();
    236 
    237   // Check if the counter of the loop is not escaped before.
    238   return !isPossiblyEscaped(CounterVar->getCanonicalDecl(), Pred);
    239 }
    240 
    241 bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) {
    242   const Stmt *S = nullptr;
    243   while (!N->pred_empty()) {
    244     if (N->succ_size() > 1)
    245       return true;
    246 
    247     ProgramPoint P = N->getLocation();
    248     if (Optional<BlockEntrance> BE = P.getAs<BlockEntrance>())
    249       S = BE->getBlock()->getTerminatorStmt();
    250 
    251     if (S == LoopStmt)
    252       return false;
    253 
    254     N = N->getFirstPred();
    255   }
    256 
    257   llvm_unreachable("Reached root without encountering the previous step");
    258 }
    259 
    260 // updateLoopStack is called on every basic block, therefore it needs to be fast
    261 ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx,
    262                                 ExplodedNode *Pred, unsigned maxVisitOnPath) {
    263   auto State = Pred->getState();
    264   auto LCtx = Pred->getLocationContext();
    265 
    266   if (!isLoopStmt(LoopStmt))
    267     return State;
    268 
    269   auto LS = State->get<LoopStack>();
    270   if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() &&
    271       LCtx == LS.getHead().getLocationContext()) {
    272     if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) {
    273       State = State->set<LoopStack>(LS.getTail());
    274       State = State->add<LoopStack>(
    275           LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
    276     }
    277     return State;
    278   }
    279   unsigned maxStep;
    280   if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) {
    281     State = State->add<LoopStack>(
    282         LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
    283     return State;
    284   }
    285 
    286   unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep());
    287 
    288   unsigned innerMaxStep = maxStep * outerStep;
    289   if (innerMaxStep > MAXIMUM_STEP_UNROLLED)
    290     State = State->add<LoopStack>(
    291         LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
    292   else
    293     State = State->add<LoopStack>(
    294         LoopState::getUnrolled(LoopStmt, LCtx, innerMaxStep));
    295   return State;
    296 }
    297 
    298 bool isUnrolledState(ProgramStateRef State) {
    299   auto LS = State->get<LoopStack>();
    300   if (LS.isEmpty() || !LS.getHead().isUnrolled())
    301     return false;
    302   return true;
    303 }
    304 }
    305 }
    306