Home | History | Annotate | Line # | Download | only in Analysis
      1 //===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // This file implements an algorithm that returns for a divergent branch
     10 // the set of basic blocks whose phi nodes become divergent due to divergent
     11 // control. These are the blocks that are reachable by two disjoint paths from
     12 // the branch or loop exits that have a reaching path that is disjoint from a
     13 // path to the loop latch.
     14 //
     15 // The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
     16 // control-induced divergence in phi nodes.
     17 //
     18 // -- Summary --
     19 // The SyncDependenceAnalysis lazily computes sync dependences [3].
     20 // The analysis evaluates the disjoint path criterion [2] by a reduction
     21 // to SSA construction. The SSA construction algorithm is implemented as
     22 // a simple data-flow analysis [1].
     23 //
     24 // [1] "A Simple, Fast Dominance Algorithm", SPI '01, Cooper, Harvey and Kennedy
     25 // [2] "Efficiently Computing Static Single Assignment Form
     26 //     and the Control Dependence Graph", TOPLAS '91,
     27 //           Cytron, Ferrante, Rosen, Wegman and Zadeck
     28 // [3] "Improving Performance of OpenCL on CPUs", CC '12, Karrenberg and Hack
     29 // [4] "Divergence Analysis", TOPLAS '13, Sampaio, Souza, Collange and Pereira
     30 //
     31 // -- Sync dependence --
     32 // Sync dependence [4] characterizes the control flow aspect of the
     33 // propagation of branch divergence. For example,
     34 //
     35 //   %cond = icmp slt i32 %tid, 10
     36 //   br i1 %cond, label %then, label %else
     37 // then:
     38 //   br label %merge
     39 // else:
     40 //   br label %merge
     41 // merge:
     42 //   %a = phi i32 [ 0, %then ], [ 1, %else ]
     43 //
     44 // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
     45 // because %tid is not on its use-def chains, %a is sync dependent on %tid
     46 // because the branch "br i1 %cond" depends on %tid and affects which value %a
     47 // is assigned to.
     48 //
     49 // -- Reduction to SSA construction --
     50 // There are two disjoint paths from A to X, if a certain variant of SSA
     51 // construction places a phi node in X under the following set-up scheme [2].
     52 //
     53 // This variant of SSA construction ignores incoming undef values.
     54 // That is paths from the entry without a definition do not result in
     55 // phi nodes.
     56 //
     57 //       entry
     58 //     /      \
     59 //    A        \
     60 //  /   \       Y
     61 // B     C     /
     62 //  \   /  \  /
     63 //    D     E
     64 //     \   /
     65 //       F
     66 // Assume that A contains a divergent branch. We are interested
     67 // in the set of all blocks where each block is reachable from A
     68 // via two disjoint paths. This would be the set {D, F} in this
     69 // case.
     70 // To generally reduce this query to SSA construction we introduce
     71 // a virtual variable x and assign to x different values in each
     72 // successor block of A.
     73 //           entry
     74 //         /      \
     75 //        A        \
     76 //      /   \       Y
     77 // x = 0   x = 1   /
     78 //      \  /   \  /
     79 //        D     E
     80 //         \   /
     81 //           F
     82 // Our flavor of SSA construction for x will construct the following
     83 //            entry
     84 //          /      \
     85 //         A        \
     86 //       /   \       Y
     87 // x0 = 0   x1 = 1  /
     88 //       \   /   \ /
     89 //      x2=phi    E
     90 //         \     /
     91 //          x3=phi
     92 // The blocks D and F contain phi nodes and are thus each reachable
     93 // by two disjoint paths from A.
     94 //
     95 // -- Remarks --
     96 // In case of loop exits we need to check the disjoint path criterion for loops
     97 // [2]. To this end, we check whether the definition of x differs between the
     98 // loop exit and the loop header (_after_ SSA construction).
     99 //
    100 //===----------------------------------------------------------------------===//
    101 #include "llvm/Analysis/SyncDependenceAnalysis.h"
    102 #include "llvm/ADT/PostOrderIterator.h"
    103 #include "llvm/ADT/SmallPtrSet.h"
    104 #include "llvm/Analysis/PostDominators.h"
    105 #include "llvm/IR/BasicBlock.h"
    106 #include "llvm/IR/CFG.h"
    107 #include "llvm/IR/Dominators.h"
    108 #include "llvm/IR/Function.h"
    109 
    110 #include <functional>
    111 #include <stack>
    112 #include <unordered_set>
    113 
    114 #define DEBUG_TYPE "sync-dependence"
    115 
    116 // The SDA algorithm operates on a modified CFG - we modify the edges leaving
    117 // loop headers as follows:
    118 //
    119 // * We remove all edges leaving all loop headers.
    120 // * We add additional edges from the loop headers to their exit blocks.
    121 //
    122 // The modification is virtual, that is whenever we visit a loop header we
    123 // pretend it had different successors.
    124 namespace {
    125 using namespace llvm;
    126 
    127 // Custom Post-Order Traversal
    128 //
    129 // We cannot use the vanilla (R)PO computation of LLVM because:
    130 // * We (virtually) modify the CFG.
    131 // * We want a loop-compact block enumeration, that is the numbers assigned by
    132 //   the traversal to the blocks of a loop are an interval.
    133 using POCB = std::function<void(const BasicBlock &)>;
    134 using VisitedSet = std::set<const BasicBlock *>;
    135 using BlockStack = std::vector<const BasicBlock *>;
    136 
    137 // forward
    138 static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
    139                           VisitedSet &Finalized);
    140 
    141 // for a nested region (top-level loop or nested loop)
    142 static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop,
    143                            POCB CallBack, VisitedSet &Finalized) {
    144   const auto *LoopHeader = Loop ? Loop->getHeader() : nullptr;
    145   while (!Stack.empty()) {
    146     const auto *NextBB = Stack.back();
    147 
    148     auto *NestedLoop = LI.getLoopFor(NextBB);
    149     bool IsNestedLoop = NestedLoop != Loop;
    150 
    151     // Treat the loop as a node
    152     if (IsNestedLoop) {
    153       SmallVector<BasicBlock *, 3> NestedExits;
    154       NestedLoop->getUniqueExitBlocks(NestedExits);
    155       bool PushedNodes = false;
    156       for (const auto *NestedExitBB : NestedExits) {
    157         if (NestedExitBB == LoopHeader)
    158           continue;
    159         if (Loop && !Loop->contains(NestedExitBB))
    160           continue;
    161         if (Finalized.count(NestedExitBB))
    162           continue;
    163         PushedNodes = true;
    164         Stack.push_back(NestedExitBB);
    165       }
    166       if (!PushedNodes) {
    167         // All loop exits finalized -> finish this node
    168         Stack.pop_back();
    169         computeLoopPO(LI, *NestedLoop, CallBack, Finalized);
    170       }
    171       continue;
    172     }
    173 
    174     // DAG-style
    175     bool PushedNodes = false;
    176     for (const auto *SuccBB : successors(NextBB)) {
    177       if (SuccBB == LoopHeader)
    178         continue;
    179       if (Loop && !Loop->contains(SuccBB))
    180         continue;
    181       if (Finalized.count(SuccBB))
    182         continue;
    183       PushedNodes = true;
    184       Stack.push_back(SuccBB);
    185     }
    186     if (!PushedNodes) {
    187       // Never push nodes twice
    188       Stack.pop_back();
    189       if (!Finalized.insert(NextBB).second)
    190         continue;
    191       CallBack(*NextBB);
    192     }
    193   }
    194 }
    195 
    196 static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) {
    197   VisitedSet Finalized;
    198   BlockStack Stack;
    199   Stack.reserve(24); // FIXME made-up number
    200   Stack.push_back(&F.getEntryBlock());
    201   computeStackPO(Stack, LI, nullptr, CallBack, Finalized);
    202 }
    203 
    204 static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
    205                           VisitedSet &Finalized) {
    206   /// Call CallBack on all loop blocks.
    207   std::vector<const BasicBlock *> Stack;
    208   const auto *LoopHeader = Loop.getHeader();
    209 
    210   // Visit the header last
    211   Finalized.insert(LoopHeader);
    212   CallBack(*LoopHeader);
    213 
    214   // Initialize with immediate successors
    215   for (const auto *BB : successors(LoopHeader)) {
    216     if (!Loop.contains(BB))
    217       continue;
    218     if (BB == LoopHeader)
    219       continue;
    220     Stack.push_back(BB);
    221   }
    222 
    223   // Compute PO inside region
    224   computeStackPO(Stack, LI, &Loop, CallBack, Finalized);
    225 }
    226 
    227 } // namespace
    228 
    229 namespace llvm {
    230 
    231 ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc;
    232 
    233 SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT,
    234                                                const PostDominatorTree &PDT,
    235                                                const LoopInfo &LI)
    236     : DT(DT), PDT(PDT), LI(LI) {
    237   computeTopLevelPO(*DT.getRoot()->getParent(), LI,
    238                     [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); });
    239 }
    240 
    241 SyncDependenceAnalysis::~SyncDependenceAnalysis() {}
    242 
    243 // divergence propagator for reducible CFGs
    244 struct DivergencePropagator {
    245   const ModifiedPO &LoopPOT;
    246   const DominatorTree &DT;
    247   const PostDominatorTree &PDT;
    248   const LoopInfo &LI;
    249   const BasicBlock &DivTermBlock;
    250 
    251   // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at
    252   //   block B
    253   // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet
    254   // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths
    255   // from X or B is an immediate successor of X (initial value).
    256   using BlockLabelVec = std::vector<const BasicBlock *>;
    257   BlockLabelVec BlockLabels;
    258   // divergent join and loop exit descriptor.
    259   std::unique_ptr<ControlDivergenceDesc> DivDesc;
    260 
    261   DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT,
    262                        const PostDominatorTree &PDT, const LoopInfo &LI,
    263                        const BasicBlock &DivTermBlock)
    264       : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock),
    265         BlockLabels(LoopPOT.size(), nullptr),
    266         DivDesc(new ControlDivergenceDesc) {}
    267 
    268   void printDefs(raw_ostream &Out) {
    269     Out << "Propagator::BlockLabels {\n";
    270     for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) {
    271       const auto *Label = BlockLabels[BlockIdx];
    272       Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx
    273           << ") : ";
    274       if (!Label) {
    275         Out << "<null>\n";
    276       } else {
    277         Out << Label->getName() << "\n";
    278       }
    279     }
    280     Out << "}\n";
    281   }
    282 
    283   // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
    284   // causes a divergent join.
    285   bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) {
    286     auto SuccIdx = LoopPOT.getIndexOf(SuccBlock);
    287 
    288     // unset or same reaching label
    289     const auto *OldLabel = BlockLabels[SuccIdx];
    290     if (!OldLabel || (OldLabel == &PushedLabel)) {
    291       BlockLabels[SuccIdx] = &PushedLabel;
    292       return false;
    293     }
    294 
    295     // Update the definition
    296     BlockLabels[SuccIdx] = &SuccBlock;
    297     return true;
    298   }
    299 
    300   // visiting a virtual loop exit edge from the loop header --> temporal
    301   // divergence on join
    302   bool visitLoopExitEdge(const BasicBlock &ExitBlock,
    303                          const BasicBlock &DefBlock, bool FromParentLoop) {
    304     // Pushing from a non-parent loop cannot cause temporal divergence.
    305     if (!FromParentLoop)
    306       return visitEdge(ExitBlock, DefBlock);
    307 
    308     if (!computeJoin(ExitBlock, DefBlock))
    309       return false;
    310 
    311     // Identified a divergent loop exit
    312     DivDesc->LoopDivBlocks.insert(&ExitBlock);
    313     LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName()
    314                       << "\n");
    315     return true;
    316   }
    317 
    318   // process \p SuccBlock with reaching definition \p DefBlock
    319   bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) {
    320     if (!computeJoin(SuccBlock, DefBlock))
    321       return false;
    322 
    323     // Divergent, disjoint paths join.
    324     DivDesc->JoinDivBlocks.insert(&SuccBlock);
    325     LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName());
    326     return true;
    327   }
    328 
    329   std::unique_ptr<ControlDivergenceDesc> computeJoinPoints() {
    330     assert(DivDesc);
    331 
    332     LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName()
    333                       << "\n");
    334 
    335     const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock);
    336 
    337     // Early stopping criterion
    338     int FloorIdx = LoopPOT.size() - 1;
    339     const BasicBlock *FloorLabel = nullptr;
    340 
    341     // bootstrap with branch targets
    342     int BlockIdx = 0;
    343 
    344     for (const auto *SuccBlock : successors(&DivTermBlock)) {
    345       auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock);
    346       BlockLabels[SuccIdx] = SuccBlock;
    347 
    348       // Find the successor with the highest index to start with
    349       BlockIdx = std::max<int>(BlockIdx, SuccIdx);
    350       FloorIdx = std::min<int>(FloorIdx, SuccIdx);
    351 
    352       // Identify immediate divergent loop exits
    353       if (!DivBlockLoop)
    354         continue;
    355 
    356       const auto *BlockLoop = LI.getLoopFor(SuccBlock);
    357       if (BlockLoop && DivBlockLoop->contains(BlockLoop))
    358         continue;
    359       DivDesc->LoopDivBlocks.insert(SuccBlock);
    360       LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: "
    361                         << SuccBlock->getName() << "\n");
    362     }
    363 
    364     // propagate definitions at the immediate successors of the node in RPO
    365     for (; BlockIdx >= FloorIdx; --BlockIdx) {
    366       LLVM_DEBUG(dbgs() << "Before next visit:\n"; printDefs(dbgs()));
    367 
    368       // Any label available here
    369       const auto *Label = BlockLabels[BlockIdx];
    370       if (!Label)
    371         continue;
    372 
    373       // Ok. Get the block
    374       const auto *Block = LoopPOT.getBlockAt(BlockIdx);
    375       LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");
    376 
    377       auto *BlockLoop = LI.getLoopFor(Block);
    378       bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block;
    379       bool CausedJoin = false;
    380       int LoweredFloorIdx = FloorIdx;
    381       if (IsLoopHeader) {
    382         // Disconnect from immediate successors and propagate directly to loop
    383         // exits.
    384         SmallVector<BasicBlock *, 4> BlockLoopExits;
    385         BlockLoop->getExitBlocks(BlockLoopExits);
    386 
    387         bool IsParentLoop = BlockLoop->contains(&DivTermBlock);
    388         for (const auto *BlockLoopExit : BlockLoopExits) {
    389           CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop);
    390           LoweredFloorIdx = std::min<int>(LoweredFloorIdx,
    391                                           LoopPOT.getIndexOf(*BlockLoopExit));
    392         }
    393       } else {
    394         // Acyclic successor case
    395         for (const auto *SuccBlock : successors(Block)) {
    396           CausedJoin |= visitEdge(*SuccBlock, *Label);
    397           LoweredFloorIdx =
    398               std::min<int>(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock));
    399         }
    400       }
    401 
    402       // Floor update
    403       if (CausedJoin) {
    404         // 1. Different labels pushed to successors
    405         FloorIdx = LoweredFloorIdx;
    406       } else if (FloorLabel != Label) {
    407         // 2. No join caused BUT we pushed a label that is different than the
    408         // last pushed label
    409         FloorIdx = LoweredFloorIdx;
    410         FloorLabel = Label;
    411       }
    412     }
    413 
    414     LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));
    415 
    416     return std::move(DivDesc);
    417   }
    418 };
    419 
    420 #ifndef NDEBUG
    421 static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) {
    422   Out << "[";
    423   ListSeparator LS;
    424   for (const auto *BB : Blocks)
    425     Out << LS << BB->getName();
    426   Out << "]";
    427 }
    428 #endif
    429 
    430 const ControlDivergenceDesc &
    431 SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) {
    432   // trivial case
    433   if (Term.getNumSuccessors() <= 1) {
    434     return EmptyDivergenceDesc;
    435   }
    436 
    437   // already available in cache?
    438   auto ItCached = CachedControlDivDescs.find(&Term);
    439   if (ItCached != CachedControlDivDescs.end())
    440     return *ItCached->second;
    441 
    442   // compute all join points
    443   // Special handling of divergent loop exits is not needed for LCSSA
    444   const auto &TermBlock = *Term.getParent();
    445   DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock);
    446   auto DivDesc = Propagator.computeJoinPoints();
    447 
    448   LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n";
    449              dbgs() << "JoinDivBlocks: ";
    450              printBlockSet(DivDesc->JoinDivBlocks, dbgs());
    451              dbgs() << "\nLoopDivBlocks: ";
    452              printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";);
    453 
    454   auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc));
    455   assert(ItInserted.second);
    456   return *ItInserted.first->second;
    457 }
    458 
    459 } // namespace llvm
    460