Home | History | Annotate | Line # | Download | only in Instrumentation
      1 //===-- CFGMST.h - Minimum Spanning Tree for CFG ----------------*- C++ -*-===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // This file implements a Union-find algorithm to compute Minimum Spanning Tree
     10 // for a given CFG.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #ifndef LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H
     15 #define LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H
     16 
     17 #include "llvm/ADT/DenseMap.h"
     18 #include "llvm/ADT/STLExtras.h"
     19 #include "llvm/Analysis/BlockFrequencyInfo.h"
     20 #include "llvm/Analysis/BranchProbabilityInfo.h"
     21 #include "llvm/Analysis/CFG.h"
     22 #include "llvm/Support/BranchProbability.h"
     23 #include "llvm/Support/Debug.h"
     24 #include "llvm/Support/raw_ostream.h"
     25 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
     26 #include <utility>
     27 #include <vector>
     28 
     29 #define DEBUG_TYPE "cfgmst"
     30 
     31 using namespace llvm;
     32 
     33 namespace llvm {
     34 
     35 /// An union-find based Minimum Spanning Tree for CFG
     36 ///
     37 /// Implements a Union-find algorithm to compute Minimum Spanning Tree
     38 /// for a given CFG.
     39 template <class Edge, class BBInfo> class CFGMST {
     40 public:
     41   Function &F;
     42 
     43   // Store all the edges in CFG. It may contain some stale edges
     44   // when Removed is set.
     45   std::vector<std::unique_ptr<Edge>> AllEdges;
     46 
     47   // This map records the auxiliary information for each BB.
     48   DenseMap<const BasicBlock *, std::unique_ptr<BBInfo>> BBInfos;
     49 
     50   // Whehter the function has an exit block with no successors.
     51   // (For function with an infinite loop, this block may be absent)
     52   bool ExitBlockFound = false;
     53 
     54   // Find the root group of the G and compress the path from G to the root.
     55   BBInfo *findAndCompressGroup(BBInfo *G) {
     56     if (G->Group != G)
     57       G->Group = findAndCompressGroup(static_cast<BBInfo *>(G->Group));
     58     return static_cast<BBInfo *>(G->Group);
     59   }
     60 
     61   // Union BB1 and BB2 into the same group and return true.
     62   // Returns false if BB1 and BB2 are already in the same group.
     63   bool unionGroups(const BasicBlock *BB1, const BasicBlock *BB2) {
     64     BBInfo *BB1G = findAndCompressGroup(&getBBInfo(BB1));
     65     BBInfo *BB2G = findAndCompressGroup(&getBBInfo(BB2));
     66 
     67     if (BB1G == BB2G)
     68       return false;
     69 
     70     // Make the smaller rank tree a direct child or the root of high rank tree.
     71     if (BB1G->Rank < BB2G->Rank)
     72       BB1G->Group = BB2G;
     73     else {
     74       BB2G->Group = BB1G;
     75       // If the ranks are the same, increment root of one tree by one.
     76       if (BB1G->Rank == BB2G->Rank)
     77         BB1G->Rank++;
     78     }
     79     return true;
     80   }
     81 
     82   // Give BB, return the auxiliary information.
     83   BBInfo &getBBInfo(const BasicBlock *BB) const {
     84     auto It = BBInfos.find(BB);
     85     assert(It->second.get() != nullptr);
     86     return *It->second.get();
     87   }
     88 
     89   // Give BB, return the auxiliary information if it's available.
     90   BBInfo *findBBInfo(const BasicBlock *BB) const {
     91     auto It = BBInfos.find(BB);
     92     if (It == BBInfos.end())
     93       return nullptr;
     94     return It->second.get();
     95   }
     96 
     97   // Traverse the CFG using a stack. Find all the edges and assign the weight.
     98   // Edges with large weight will be put into MST first so they are less likely
     99   // to be instrumented.
    100   void buildEdges() {
    101     LLVM_DEBUG(dbgs() << "Build Edge on " << F.getName() << "\n");
    102 
    103     const BasicBlock *Entry = &(F.getEntryBlock());
    104     uint64_t EntryWeight = (BFI != nullptr ? BFI->getEntryFreq() : 2);
    105     // If we want to instrument the entry count, lower the weight to 0.
    106     if (InstrumentFuncEntry)
    107       EntryWeight = 0;
    108     Edge *EntryIncoming = nullptr, *EntryOutgoing = nullptr,
    109          *ExitOutgoing = nullptr, *ExitIncoming = nullptr;
    110     uint64_t MaxEntryOutWeight = 0, MaxExitOutWeight = 0, MaxExitInWeight = 0;
    111 
    112     // Add a fake edge to the entry.
    113     EntryIncoming = &addEdge(nullptr, Entry, EntryWeight);
    114     LLVM_DEBUG(dbgs() << "  Edge: from fake node to " << Entry->getName()
    115                       << " w = " << EntryWeight << "\n");
    116 
    117     // Special handling for single BB functions.
    118     if (succ_empty(Entry)) {
    119       addEdge(Entry, nullptr, EntryWeight);
    120       return;
    121     }
    122 
    123     static const uint32_t CriticalEdgeMultiplier = 1000;
    124 
    125     for (BasicBlock &BB : F) {
    126       Instruction *TI = BB.getTerminator();
    127       uint64_t BBWeight =
    128           (BFI != nullptr ? BFI->getBlockFreq(&BB).getFrequency() : 2);
    129       uint64_t Weight = 2;
    130       if (int successors = TI->getNumSuccessors()) {
    131         for (int i = 0; i != successors; ++i) {
    132           BasicBlock *TargetBB = TI->getSuccessor(i);
    133           bool Critical = isCriticalEdge(TI, i);
    134           uint64_t scaleFactor = BBWeight;
    135           if (Critical) {
    136             if (scaleFactor < UINT64_MAX / CriticalEdgeMultiplier)
    137               scaleFactor *= CriticalEdgeMultiplier;
    138             else
    139               scaleFactor = UINT64_MAX;
    140           }
    141           if (BPI != nullptr)
    142             Weight = BPI->getEdgeProbability(&BB, TargetBB).scale(scaleFactor);
    143           if (Weight == 0)
    144             Weight++;
    145           auto *E = &addEdge(&BB, TargetBB, Weight);
    146           E->IsCritical = Critical;
    147           LLVM_DEBUG(dbgs() << "  Edge: from " << BB.getName() << " to "
    148                             << TargetBB->getName() << "  w=" << Weight << "\n");
    149 
    150           // Keep track of entry/exit edges:
    151           if (&BB == Entry) {
    152             if (Weight > MaxEntryOutWeight) {
    153               MaxEntryOutWeight = Weight;
    154               EntryOutgoing = E;
    155             }
    156           }
    157 
    158           auto *TargetTI = TargetBB->getTerminator();
    159           if (TargetTI && !TargetTI->getNumSuccessors()) {
    160             if (Weight > MaxExitInWeight) {
    161               MaxExitInWeight = Weight;
    162               ExitIncoming = E;
    163             }
    164           }
    165         }
    166       } else {
    167         ExitBlockFound = true;
    168         Edge *ExitO = &addEdge(&BB, nullptr, BBWeight);
    169         if (BBWeight > MaxExitOutWeight) {
    170           MaxExitOutWeight = BBWeight;
    171           ExitOutgoing = ExitO;
    172         }
    173         LLVM_DEBUG(dbgs() << "  Edge: from " << BB.getName() << " to fake exit"
    174                           << " w = " << BBWeight << "\n");
    175       }
    176     }
    177 
    178     // Entry/exit edge adjustment heurisitic:
    179     // prefer instrumenting entry edge over exit edge
    180     // if possible. Those exit edges may never have a chance to be
    181     // executed (for instance the program is an event handling loop)
    182     // before the profile is asynchronously dumped.
    183     //
    184     // If EntryIncoming and ExitOutgoing has similar weight, make sure
    185     // ExitOutging is selected as the min-edge. Similarly, if EntryOutgoing
    186     // and ExitIncoming has similar weight, make sure ExitIncoming becomes
    187     // the min-edge.
    188     uint64_t EntryInWeight = EntryWeight;
    189 
    190     if (EntryInWeight >= MaxExitOutWeight &&
    191         EntryInWeight * 2 < MaxExitOutWeight * 3) {
    192       EntryIncoming->Weight = MaxExitOutWeight;
    193       ExitOutgoing->Weight = EntryInWeight + 1;
    194     }
    195 
    196     if (MaxEntryOutWeight >= MaxExitInWeight &&
    197         MaxEntryOutWeight * 2 < MaxExitInWeight * 3) {
    198       EntryOutgoing->Weight = MaxExitInWeight;
    199       ExitIncoming->Weight = MaxEntryOutWeight + 1;
    200     }
    201   }
    202 
    203   // Sort CFG edges based on its weight.
    204   void sortEdgesByWeight() {
    205     llvm::stable_sort(AllEdges, [](const std::unique_ptr<Edge> &Edge1,
    206                                    const std::unique_ptr<Edge> &Edge2) {
    207       return Edge1->Weight > Edge2->Weight;
    208     });
    209   }
    210 
    211   // Traverse all the edges and compute the Minimum Weight Spanning Tree
    212   // using union-find algorithm.
    213   void computeMinimumSpanningTree() {
    214     // First, put all the critical edge with landing-pad as the Dest to MST.
    215     // This works around the insufficient support of critical edges split
    216     // when destination BB is a landing pad.
    217     for (auto &Ei : AllEdges) {
    218       if (Ei->Removed)
    219         continue;
    220       if (Ei->IsCritical) {
    221         if (Ei->DestBB && Ei->DestBB->isLandingPad()) {
    222           if (unionGroups(Ei->SrcBB, Ei->DestBB))
    223             Ei->InMST = true;
    224         }
    225       }
    226     }
    227 
    228     for (auto &Ei : AllEdges) {
    229       if (Ei->Removed)
    230         continue;
    231       // If we detect infinite loops, force
    232       // instrumenting the entry edge:
    233       if (!ExitBlockFound && Ei->SrcBB == nullptr)
    234         continue;
    235       if (unionGroups(Ei->SrcBB, Ei->DestBB))
    236         Ei->InMST = true;
    237     }
    238   }
    239 
    240   // Dump the Debug information about the instrumentation.
    241   void dumpEdges(raw_ostream &OS, const Twine &Message) const {
    242     if (!Message.str().empty())
    243       OS << Message << "\n";
    244     OS << "  Number of Basic Blocks: " << BBInfos.size() << "\n";
    245     for (auto &BI : BBInfos) {
    246       const BasicBlock *BB = BI.first;
    247       OS << "  BB: " << (BB == nullptr ? "FakeNode" : BB->getName()) << "  "
    248          << BI.second->infoString() << "\n";
    249     }
    250 
    251     OS << "  Number of Edges: " << AllEdges.size()
    252        << " (*: Instrument, C: CriticalEdge, -: Removed)\n";
    253     uint32_t Count = 0;
    254     for (auto &EI : AllEdges)
    255       OS << "  Edge " << Count++ << ": " << getBBInfo(EI->SrcBB).Index << "-->"
    256          << getBBInfo(EI->DestBB).Index << EI->infoString() << "\n";
    257   }
    258 
    259   // Add an edge to AllEdges with weight W.
    260   Edge &addEdge(const BasicBlock *Src, const BasicBlock *Dest, uint64_t W) {
    261     uint32_t Index = BBInfos.size();
    262     auto Iter = BBInfos.end();
    263     bool Inserted;
    264     std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Src, nullptr));
    265     if (Inserted) {
    266       // Newly inserted, update the real info.
    267       Iter->second = std::move(std::make_unique<BBInfo>(Index));
    268       Index++;
    269     }
    270     std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Dest, nullptr));
    271     if (Inserted)
    272       // Newly inserted, update the real info.
    273       Iter->second = std::move(std::make_unique<BBInfo>(Index));
    274     AllEdges.emplace_back(new Edge(Src, Dest, W));
    275     return *AllEdges.back();
    276   }
    277 
    278   BranchProbabilityInfo *BPI;
    279   BlockFrequencyInfo *BFI;
    280 
    281   // If function entry will be always instrumented.
    282   bool InstrumentFuncEntry;
    283 
    284 public:
    285   CFGMST(Function &Func, bool InstrumentFuncEntry_,
    286          BranchProbabilityInfo *BPI_ = nullptr,
    287          BlockFrequencyInfo *BFI_ = nullptr)
    288       : F(Func), BPI(BPI_), BFI(BFI_),
    289         InstrumentFuncEntry(InstrumentFuncEntry_) {
    290     buildEdges();
    291     sortEdgesByWeight();
    292     computeMinimumSpanningTree();
    293     if (AllEdges.size() > 1 && InstrumentFuncEntry)
    294       std::iter_swap(std::move(AllEdges.begin()),
    295                      std::move(AllEdges.begin() + AllEdges.size() - 1));
    296   }
    297 };
    298 
    299 } // end namespace llvm
    300 
    301 #undef DEBUG_TYPE // "cfgmst"
    302 
    303 #endif // LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H
    304