Home | History | Annotate | Line # | Download | only in Analysis
      1 //===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // Loops should be simplified before this analysis.
     10 //
     11 //===----------------------------------------------------------------------===//
     12 
     13 #include "llvm/Analysis/BranchProbabilityInfo.h"
     14 #include "llvm/ADT/PostOrderIterator.h"
     15 #include "llvm/ADT/SCCIterator.h"
     16 #include "llvm/ADT/STLExtras.h"
     17 #include "llvm/ADT/SmallVector.h"
     18 #include "llvm/Analysis/LoopInfo.h"
     19 #include "llvm/Analysis/PostDominators.h"
     20 #include "llvm/Analysis/TargetLibraryInfo.h"
     21 #include "llvm/IR/Attributes.h"
     22 #include "llvm/IR/BasicBlock.h"
     23 #include "llvm/IR/CFG.h"
     24 #include "llvm/IR/Constants.h"
     25 #include "llvm/IR/Dominators.h"
     26 #include "llvm/IR/Function.h"
     27 #include "llvm/IR/InstrTypes.h"
     28 #include "llvm/IR/Instruction.h"
     29 #include "llvm/IR/Instructions.h"
     30 #include "llvm/IR/LLVMContext.h"
     31 #include "llvm/IR/Metadata.h"
     32 #include "llvm/IR/PassManager.h"
     33 #include "llvm/IR/Type.h"
     34 #include "llvm/IR/Value.h"
     35 #include "llvm/InitializePasses.h"
     36 #include "llvm/Pass.h"
     37 #include "llvm/Support/BranchProbability.h"
     38 #include "llvm/Support/Casting.h"
     39 #include "llvm/Support/CommandLine.h"
     40 #include "llvm/Support/Debug.h"
     41 #include "llvm/Support/raw_ostream.h"
     42 #include <cassert>
     43 #include <cstdint>
     44 #include <iterator>
     45 #include <utility>
     46 
     47 using namespace llvm;
     48 
     49 #define DEBUG_TYPE "branch-prob"
     50 
     51 static cl::opt<bool> PrintBranchProb(
     52     "print-bpi", cl::init(false), cl::Hidden,
     53     cl::desc("Print the branch probability info."));
     54 
     55 cl::opt<std::string> PrintBranchProbFuncName(
     56     "print-bpi-func-name", cl::Hidden,
     57     cl::desc("The option to specify the name of the function "
     58              "whose branch probability info is printed."));
     59 
// Register the legacy wrapper pass with the pass registry under the
// command-line name "branch-prob", and declare the analyses (loop info,
// target library info, dominator and post-dominator trees) it depends on.
INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",
                      "Branch Probability Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob",
                    "Branch Probability Analysis", false, true)
     67                     "Branch Probability Analysis", false, true)
     68 
     69 BranchProbabilityInfoWrapperPass::BranchProbabilityInfoWrapperPass()
     70     : FunctionPass(ID) {
     71   initializeBranchProbabilityInfoWrapperPassPass(
     72       *PassRegistry::getPassRegistry());
     73 }
     74 
     75 char BranchProbabilityInfoWrapperPass::ID = 0;
     76 
     77 // Weights are for internal use only. They are used by heuristics to help to
     78 // estimate edges' probability. Example:
     79 //
     80 // Using "Loop Branch Heuristics" we predict weights of edges for the
     81 // block BB2.
     82 //         ...
     83 //          |
     84 //          V
     85 //         BB1<-+
     86 //          |   |
     87 //          |   | (Weight = 124)
     88 //          V   |
     89 //         BB2--+
     90 //          |
     91 //          | (Weight = 4)
     92 //          V
     93 //         BB3
     94 //
     95 // Probability of the edge BB2->BB1 = 124 / (124 + 4) = 0.96875
     96 // Probability of the edge BB2->BB3 = 4 / (124 + 4) = 0.03125
     97 static const uint32_t LBH_TAKEN_WEIGHT = 124;
     98 static const uint32_t LBH_NONTAKEN_WEIGHT = 4;
     99 
    100 /// Unreachable-terminating branch taken probability.
    101 ///
    102 /// This is the probability for a branch being taken to a block that terminates
    103 /// (eventually) in unreachable. These are predicted as unlikely as possible.
    104 /// All reachable probability will proportionally share the remaining part.
    105 static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1);
    106 
    107 static const uint32_t PH_TAKEN_WEIGHT = 20;
    108 static const uint32_t PH_NONTAKEN_WEIGHT = 12;
    109 
    110 static const uint32_t ZH_TAKEN_WEIGHT = 20;
    111 static const uint32_t ZH_NONTAKEN_WEIGHT = 12;
    112 
    113 static const uint32_t FPH_TAKEN_WEIGHT = 20;
    114 static const uint32_t FPH_NONTAKEN_WEIGHT = 12;
    115 
    116 /// This is the probability for an ordered floating point comparison.
    117 static const uint32_t FPH_ORD_WEIGHT = 1024 * 1024 - 1;
    118 /// This is the probability for an unordered floating point comparison, it means
    119 /// one or two of the operands are NaN. Usually it is used to test for an
    120 /// exceptional case, so the result is unlikely.
    121 static const uint32_t FPH_UNO_WEIGHT = 1;
    122 
/// Dedicated "absolute" execution weights for blocks. Only their relative
/// magnitudes (and values derived from them) are meaningful.
enum class BlockExecWeight : std::uint32_t {
  /// Exactly zero probability of execution.
  ZERO = 0x0,
  /// The smallest weight that is still non-zero.
  LOWEST_NON_ZERO = 0x1,
  /// Weight of a block reached only on the way to 'unreachable'.
  UNREACHABLE = ZERO,
  /// Weight of a block containing a call that does not return.
  NORETURN = LOWEST_NON_ZERO,
  /// Weight of the 'unwind' destination of an invoke instruction.
  UNWIND = LOWEST_NON_ZERO,
  /// Weight of a 'cold' block, i.e. one containing a call marked with the
  /// 'cold' attribute.
  COLD = 0xffff,
  /// Fallback used when no dedicated execution weight applies. Unlike the
  /// others, it is not propagated through the domination line.
  DEFAULT = 0xfffff
};
    143 
    144 BranchProbabilityInfo::SccInfo::SccInfo(const Function &F) {
    145   // Record SCC numbers of blocks in the CFG to identify irreducible loops.
    146   // FIXME: We could only calculate this if the CFG is known to be irreducible
    147   // (perhaps cache this info in LoopInfo if we can easily calculate it there?).
    148   int SccNum = 0;
    149   for (scc_iterator<const Function *> It = scc_begin(&F); !It.isAtEnd();
    150        ++It, ++SccNum) {
    151     // Ignore single-block SCCs since they either aren't loops or LoopInfo will
    152     // catch them.
    153     const std::vector<const BasicBlock *> &Scc = *It;
    154     if (Scc.size() == 1)
    155       continue;
    156 
    157     LLVM_DEBUG(dbgs() << "BPI: SCC " << SccNum << ":");
    158     for (const auto *BB : Scc) {
    159       LLVM_DEBUG(dbgs() << " " << BB->getName());
    160       SccNums[BB] = SccNum;
    161       calculateSccBlockType(BB, SccNum);
    162     }
    163     LLVM_DEBUG(dbgs() << "\n");
    164   }
    165 }
    166 
    167 int BranchProbabilityInfo::SccInfo::getSCCNum(const BasicBlock *BB) const {
    168   auto SccIt = SccNums.find(BB);
    169   if (SccIt == SccNums.end())
    170     return -1;
    171   return SccIt->second;
    172 }
    173 
    174 void BranchProbabilityInfo::SccInfo::getSccEnterBlocks(
    175     int SccNum, SmallVectorImpl<BasicBlock *> &Enters) const {
    176 
    177   for (auto MapIt : SccBlocks[SccNum]) {
    178     const auto *BB = MapIt.first;
    179     if (isSCCHeader(BB, SccNum))
    180       for (const auto *Pred : predecessors(BB))
    181         if (getSCCNum(Pred) != SccNum)
    182           Enters.push_back(const_cast<BasicBlock *>(BB));
    183   }
    184 }
    185 
    186 void BranchProbabilityInfo::SccInfo::getSccExitBlocks(
    187     int SccNum, SmallVectorImpl<BasicBlock *> &Exits) const {
    188   for (auto MapIt : SccBlocks[SccNum]) {
    189     const auto *BB = MapIt.first;
    190     if (isSCCExitingBlock(BB, SccNum))
    191       for (const auto *Succ : successors(BB))
    192         if (getSCCNum(Succ) != SccNum)
    193           Exits.push_back(const_cast<BasicBlock *>(BB));
    194   }
    195 }
    196 
    197 uint32_t BranchProbabilityInfo::SccInfo::getSccBlockType(const BasicBlock *BB,
    198                                                          int SccNum) const {
    199   assert(getSCCNum(BB) == SccNum);
    200 
    201   assert(SccBlocks.size() > static_cast<unsigned>(SccNum) && "Unknown SCC");
    202   const auto &SccBlockTypes = SccBlocks[SccNum];
    203 
    204   auto It = SccBlockTypes.find(BB);
    205   if (It != SccBlockTypes.end()) {
    206     return It->second;
    207   }
    208   return Inner;
    209 }
    210 
    211 void BranchProbabilityInfo::SccInfo::calculateSccBlockType(const BasicBlock *BB,
    212                                                            int SccNum) {
    213   assert(getSCCNum(BB) == SccNum);
    214   uint32_t BlockType = Inner;
    215 
    216   if (llvm::any_of(predecessors(BB), [&](const BasicBlock *Pred) {
    217         // Consider any block that is an entry point to the SCC as
    218         // a header.
    219         return getSCCNum(Pred) != SccNum;
    220       }))
    221     BlockType |= Header;
    222 
    223   if (llvm::any_of(successors(BB), [&](const BasicBlock *Succ) {
    224         return getSCCNum(Succ) != SccNum;
    225       }))
    226     BlockType |= Exiting;
    227 
    228   // Lazily compute the set of headers for a given SCC and cache the results
    229   // in the SccHeaderMap.
    230   if (SccBlocks.size() <= static_cast<unsigned>(SccNum))
    231     SccBlocks.resize(SccNum + 1);
    232   auto &SccBlockTypes = SccBlocks[SccNum];
    233 
    234   if (BlockType != Inner) {
    235     bool IsInserted;
    236     std::tie(std::ignore, IsInserted) =
    237         SccBlockTypes.insert(std::make_pair(BB, BlockType));
    238     assert(IsInserted && "Duplicated block in SCC");
    239   }
    240 }
    241 
    242 BranchProbabilityInfo::LoopBlock::LoopBlock(const BasicBlock *BB,
    243                                             const LoopInfo &LI,
    244                                             const SccInfo &SccI)
    245     : BB(BB) {
    246   LD.first = LI.getLoopFor(BB);
    247   if (!LD.first) {
    248     LD.second = SccI.getSCCNum(BB);
    249   }
    250 }
    251 
    252 bool BranchProbabilityInfo::isLoopEnteringEdge(const LoopEdge &Edge) const {
    253   const auto &SrcBlock = Edge.first;
    254   const auto &DstBlock = Edge.second;
    255   return (DstBlock.getLoop() &&
    256           !DstBlock.getLoop()->contains(SrcBlock.getLoop())) ||
    257          // Assume that SCCs can't be nested.
    258          (DstBlock.getSccNum() != -1 &&
    259           SrcBlock.getSccNum() != DstBlock.getSccNum());
    260 }
    261 
    262 bool BranchProbabilityInfo::isLoopExitingEdge(const LoopEdge &Edge) const {
    263   return isLoopEnteringEdge({Edge.second, Edge.first});
    264 }
    265 
    266 bool BranchProbabilityInfo::isLoopEnteringExitingEdge(
    267     const LoopEdge &Edge) const {
    268   return isLoopEnteringEdge(Edge) || isLoopExitingEdge(Edge);
    269 }
    270 
    271 bool BranchProbabilityInfo::isLoopBackEdge(const LoopEdge &Edge) const {
    272   const auto &SrcBlock = Edge.first;
    273   const auto &DstBlock = Edge.second;
    274   return SrcBlock.belongsToSameLoop(DstBlock) &&
    275          ((DstBlock.getLoop() &&
    276            DstBlock.getLoop()->getHeader() == DstBlock.getBlock()) ||
    277           (DstBlock.getSccNum() != -1 &&
    278            SccI->isSCCHeader(DstBlock.getBlock(), DstBlock.getSccNum())));
    279 }
    280 
    281 void BranchProbabilityInfo::getLoopEnterBlocks(
    282     const LoopBlock &LB, SmallVectorImpl<BasicBlock *> &Enters) const {
    283   if (LB.getLoop()) {
    284     auto *Header = LB.getLoop()->getHeader();
    285     Enters.append(pred_begin(Header), pred_end(Header));
    286   } else {
    287     assert(LB.getSccNum() != -1 && "LB doesn't belong to any loop?");
    288     SccI->getSccEnterBlocks(LB.getSccNum(), Enters);
    289   }
    290 }
    291 
    292 void BranchProbabilityInfo::getLoopExitBlocks(
    293     const LoopBlock &LB, SmallVectorImpl<BasicBlock *> &Exits) const {
    294   if (LB.getLoop()) {
    295     LB.getLoop()->getExitBlocks(Exits);
    296   } else {
    297     assert(LB.getSccNum() != -1 && "LB doesn't belong to any loop?");
    298     SccI->getSccExitBlocks(LB.getSccNum(), Exits);
    299   }
    300 }
    301 
    302 // Propagate existing explicit probabilities from either profile data or
    303 // 'expect' intrinsic processing. Examine metadata against unreachable
    304 // heuristic. The probability of the edge coming to unreachable block is
    305 // set to min of metadata and unreachable heuristic.
    306 bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
    307   const Instruction *TI = BB->getTerminator();
    308   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
    309   if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI) ||
    310         isa<InvokeInst>(TI)))
    311     return false;
    312 
    313   MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
    314   if (!WeightsNode)
    315     return false;
    316 
    317   // Check that the number of successors is manageable.
    318   assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");
    319 
    320   // Ensure there are weights for all of the successors. Note that the first
    321   // operand to the metadata node is a name, not a weight.
    322   if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
    323     return false;
    324 
    325   // Build up the final weights that will be used in a temporary buffer.
    326   // Compute the sum of all weights to later decide whether they need to
    327   // be scaled to fit in 32 bits.
    328   uint64_t WeightSum = 0;
    329   SmallVector<uint32_t, 2> Weights;
    330   SmallVector<unsigned, 2> UnreachableIdxs;
    331   SmallVector<unsigned, 2> ReachableIdxs;
    332   Weights.reserve(TI->getNumSuccessors());
    333   for (unsigned I = 1, E = WeightsNode->getNumOperands(); I != E; ++I) {
    334     ConstantInt *Weight =
    335         mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(I));
    336     if (!Weight)
    337       return false;
    338     assert(Weight->getValue().getActiveBits() <= 32 &&
    339            "Too many bits for uint32_t");
    340     Weights.push_back(Weight->getZExtValue());
    341     WeightSum += Weights.back();
    342     const LoopBlock SrcLoopBB = getLoopBlock(BB);
    343     const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I - 1));
    344     auto EstimatedWeight = getEstimatedEdgeWeight({SrcLoopBB, DstLoopBB});
    345     if (EstimatedWeight &&
    346         EstimatedWeight.getValue() <=
    347             static_cast<uint32_t>(BlockExecWeight::UNREACHABLE))
    348       UnreachableIdxs.push_back(I - 1);
    349     else
    350       ReachableIdxs.push_back(I - 1);
    351   }
    352   assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
    353 
    354   // If the sum of weights does not fit in 32 bits, scale every weight down
    355   // accordingly.
    356   uint64_t ScalingFactor =
    357       (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;
    358 
    359   if (ScalingFactor > 1) {
    360     WeightSum = 0;
    361     for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I) {
    362       Weights[I] /= ScalingFactor;
    363       WeightSum += Weights[I];
    364     }
    365   }
    366   assert(WeightSum <= UINT32_MAX &&
    367          "Expected weights to scale down to 32 bits");
    368 
    369   if (WeightSum == 0 || ReachableIdxs.size() == 0) {
    370     for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I)
    371       Weights[I] = 1;
    372     WeightSum = TI->getNumSuccessors();
    373   }
    374 
    375   // Set the probability.
    376   SmallVector<BranchProbability, 2> BP;
    377   for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I)
    378     BP.push_back({ Weights[I], static_cast<uint32_t>(WeightSum) });
    379 
    380   // Examine the metadata against unreachable heuristic.
    381   // If the unreachable heuristic is more strong then we use it for this edge.
    382   if (UnreachableIdxs.size() == 0 || ReachableIdxs.size() == 0) {
    383     setEdgeProbability(BB, BP);
    384     return true;
    385   }
    386 
    387   auto UnreachableProb = UR_TAKEN_PROB;
    388   for (auto I : UnreachableIdxs)
    389     if (UnreachableProb < BP[I]) {
    390       BP[I] = UnreachableProb;
    391     }
    392 
    393   // Sum of all edge probabilities must be 1.0. If we modified the probability
    394   // of some edges then we must distribute the introduced difference over the
    395   // reachable blocks.
    396   //
    397   // Proportional distribution: the relation between probabilities of the
    398   // reachable edges is kept unchanged. That is for any reachable edges i and j:
    399   //   newBP[i] / newBP[j] == oldBP[i] / oldBP[j] =>
    400   //   newBP[i] / oldBP[i] == newBP[j] / oldBP[j] == K
    401   // Where K is independent of i,j.
    402   //   newBP[i] == oldBP[i] * K
    403   // We need to find K.
    404   // Make sum of all reachables of the left and right parts:
    405   //   sum_of_reachable(newBP) == K * sum_of_reachable(oldBP)
    406   // Sum of newBP must be equal to 1.0:
    407   //   sum_of_reachable(newBP) + sum_of_unreachable(newBP) == 1.0 =>
    408   //   sum_of_reachable(newBP) = 1.0 - sum_of_unreachable(newBP)
    409   // Where sum_of_unreachable(newBP) is what has been just changed.
    410   // Finally:
    411   //   K == sum_of_reachable(newBP) / sum_of_reachable(oldBP) =>
    412   //   K == (1.0 - sum_of_unreachable(newBP)) / sum_of_reachable(oldBP)
    413   BranchProbability NewUnreachableSum = BranchProbability::getZero();
    414   for (auto I : UnreachableIdxs)
    415     NewUnreachableSum += BP[I];
    416 
    417   BranchProbability NewReachableSum =
    418       BranchProbability::getOne() - NewUnreachableSum;
    419 
    420   BranchProbability OldReachableSum = BranchProbability::getZero();
    421   for (auto I : ReachableIdxs)
    422     OldReachableSum += BP[I];
    423 
    424   if (OldReachableSum != NewReachableSum) { // Anything to dsitribute?
    425     if (OldReachableSum.isZero()) {
    426       // If all oldBP[i] are zeroes then the proportional distribution results
    427       // in all zero probabilities and the error stays big. In this case we
    428       // evenly spread NewReachableSum over the reachable edges.
    429       BranchProbability PerEdge = NewReachableSum / ReachableIdxs.size();
    430       for (auto I : ReachableIdxs)
    431         BP[I] = PerEdge;
    432     } else {
    433       for (auto I : ReachableIdxs) {
    434         // We use uint64_t to avoid double rounding error of the following
    435         // calculation: BP[i] = BP[i] * NewReachableSum / OldReachableSum
    436         // The formula is taken from the private constructor
    437         // BranchProbability(uint32_t Numerator, uint32_t Denominator)
    438         uint64_t Mul = static_cast<uint64_t>(NewReachableSum.getNumerator()) *
    439                        BP[I].getNumerator();
    440         uint32_t Div = static_cast<uint32_t>(
    441             divideNearest(Mul, OldReachableSum.getNumerator()));
    442         BP[I] = BranchProbability::getRaw(Div);
    443       }
    444     }
    445   }
    446 
    447   setEdgeProbability(BB, BP);
    448 
    449   return true;
    450 }
    451 
    452 // Calculate Edge Weights using "Pointer Heuristics". Predict a comparison
    453 // between two pointer or pointer and NULL will fail.
    454 bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) {
    455   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
    456   if (!BI || !BI->isConditional())
    457     return false;
    458 
    459   Value *Cond = BI->getCondition();
    460   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
    461   if (!CI || !CI->isEquality())
    462     return false;
    463 
    464   Value *LHS = CI->getOperand(0);
    465 
    466   if (!LHS->getType()->isPointerTy())
    467     return false;
    468 
    469   assert(CI->getOperand(1)->getType()->isPointerTy());
    470 
    471   BranchProbability TakenProb(PH_TAKEN_WEIGHT,
    472                               PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
    473   BranchProbability UntakenProb(PH_NONTAKEN_WEIGHT,
    474                                 PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
    475 
    476   // p != 0   ->   isProb = true
    477   // p == 0   ->   isProb = false
    478   // p != q   ->   isProb = true
    479   // p == q   ->   isProb = false;
    480   bool isProb = CI->getPredicate() == ICmpInst::ICMP_NE;
    481   if (!isProb)
    482     std::swap(TakenProb, UntakenProb);
    483 
    484   setEdgeProbability(
    485       BB, SmallVector<BranchProbability, 2>({TakenProb, UntakenProb}));
    486   return true;
    487 }
    488 
    489 // Compute the unlikely successors to the block BB in the loop L, specifically
    490 // those that are unlikely because this is a loop, and add them to the
    491 // UnlikelyBlocks set.
    492 static void
    493 computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
    494                           SmallPtrSetImpl<const BasicBlock*> &UnlikelyBlocks) {
    495   // Sometimes in a loop we have a branch whose condition is made false by
    496   // taking it. This is typically something like
    497   //  int n = 0;
    498   //  while (...) {
    499   //    if (++n >= MAX) {
    500   //      n = 0;
    501   //    }
    502   //  }
    503   // In this sort of situation taking the branch means that at the very least it
    504   // won't be taken again in the next iteration of the loop, so we should
    505   // consider it less likely than a typical branch.
    506   //
    507   // We detect this by looking back through the graph of PHI nodes that sets the
    508   // value that the condition depends on, and seeing if we can reach a successor
    509   // block which can be determined to make the condition false.
    510   //
    511   // FIXME: We currently consider unlikely blocks to be half as likely as other
    512   // blocks, but if we consider the example above the likelyhood is actually
    513   // 1/MAX. We could therefore be more precise in how unlikely we consider
    514   // blocks to be, but it would require more careful examination of the form
    515   // of the comparison expression.
    516   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
    517   if (!BI || !BI->isConditional())
    518     return;
    519 
    520   // Check if the branch is based on an instruction compared with a constant
    521   CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition());
    522   if (!CI || !isa<Instruction>(CI->getOperand(0)) ||
    523       !isa<Constant>(CI->getOperand(1)))
    524     return;
    525 
    526   // Either the instruction must be a PHI, or a chain of operations involving
    527   // constants that ends in a PHI which we can then collapse into a single value
    528   // if the PHI value is known.
    529   Instruction *CmpLHS = dyn_cast<Instruction>(CI->getOperand(0));
    530   PHINode *CmpPHI = dyn_cast<PHINode>(CmpLHS);
    531   Constant *CmpConst = dyn_cast<Constant>(CI->getOperand(1));
    532   // Collect the instructions until we hit a PHI
    533   SmallVector<BinaryOperator *, 1> InstChain;
    534   while (!CmpPHI && CmpLHS && isa<BinaryOperator>(CmpLHS) &&
    535          isa<Constant>(CmpLHS->getOperand(1))) {
    536     // Stop if the chain extends outside of the loop
    537     if (!L->contains(CmpLHS))
    538       return;
    539     InstChain.push_back(cast<BinaryOperator>(CmpLHS));
    540     CmpLHS = dyn_cast<Instruction>(CmpLHS->getOperand(0));
    541     if (CmpLHS)
    542       CmpPHI = dyn_cast<PHINode>(CmpLHS);
    543   }
    544   if (!CmpPHI || !L->contains(CmpPHI))
    545     return;
    546 
    547   // Trace the phi node to find all values that come from successors of BB
    548   SmallPtrSet<PHINode*, 8> VisitedInsts;
    549   SmallVector<PHINode*, 8> WorkList;
    550   WorkList.push_back(CmpPHI);
    551   VisitedInsts.insert(CmpPHI);
    552   while (!WorkList.empty()) {
    553     PHINode *P = WorkList.pop_back_val();
    554     for (BasicBlock *B : P->blocks()) {
    555       // Skip blocks that aren't part of the loop
    556       if (!L->contains(B))
    557         continue;
    558       Value *V = P->getIncomingValueForBlock(B);
    559       // If the source is a PHI add it to the work list if we haven't
    560       // already visited it.
    561       if (PHINode *PN = dyn_cast<PHINode>(V)) {
    562         if (VisitedInsts.insert(PN).second)
    563           WorkList.push_back(PN);
    564         continue;
    565       }
    566       // If this incoming value is a constant and B is a successor of BB, then
    567       // we can constant-evaluate the compare to see if it makes the branch be
    568       // taken or not.
    569       Constant *CmpLHSConst = dyn_cast<Constant>(V);
    570       if (!CmpLHSConst || !llvm::is_contained(successors(BB), B))
    571         continue;
    572       // First collapse InstChain
    573       for (Instruction *I : llvm::reverse(InstChain)) {
    574         CmpLHSConst = ConstantExpr::get(I->getOpcode(), CmpLHSConst,
    575                                         cast<Constant>(I->getOperand(1)), true);
    576         if (!CmpLHSConst)
    577           break;
    578       }
    579       if (!CmpLHSConst)
    580         continue;
    581       // Now constant-evaluate the compare
    582       Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
    583                                                   CmpLHSConst, CmpConst, true);
    584       // If the result means we don't branch to the block then that block is
    585       // unlikely.
    586       if (Result &&
    587           ((Result->isZeroValue() && B == BI->getSuccessor(0)) ||
    588            (Result->isOneValue() && B == BI->getSuccessor(1))))
    589         UnlikelyBlocks.insert(B);
    590     }
    591   }
    592 }
    593 
    594 Optional<uint32_t>
    595 BranchProbabilityInfo::getEstimatedBlockWeight(const BasicBlock *BB) const {
    596   auto WeightIt = EstimatedBlockWeight.find(BB);
    597   if (WeightIt == EstimatedBlockWeight.end())
    598     return None;
    599   return WeightIt->second;
    600 }
    601 
    602 Optional<uint32_t>
    603 BranchProbabilityInfo::getEstimatedLoopWeight(const LoopData &L) const {
    604   auto WeightIt = EstimatedLoopWeight.find(L);
    605   if (WeightIt == EstimatedLoopWeight.end())
    606     return None;
    607   return WeightIt->second;
    608 }
    609 
    610 Optional<uint32_t>
    611 BranchProbabilityInfo::getEstimatedEdgeWeight(const LoopEdge &Edge) const {
    612   // For edges entering a loop take weight of a loop rather than an individual
    613   // block in the loop.
    614   return isLoopEnteringEdge(Edge)
    615              ? getEstimatedLoopWeight(Edge.second.getLoopData())
    616              : getEstimatedBlockWeight(Edge.second.getBlock());
    617 }
    618 
    619 template <class IterT>
    620 Optional<uint32_t> BranchProbabilityInfo::getMaxEstimatedEdgeWeight(
    621     const LoopBlock &SrcLoopBB, iterator_range<IterT> Successors) const {
    622   SmallVector<uint32_t, 4> Weights;
    623   Optional<uint32_t> MaxWeight;
    624   for (const BasicBlock *DstBB : Successors) {
    625     const LoopBlock DstLoopBB = getLoopBlock(DstBB);
    626     auto Weight = getEstimatedEdgeWeight({SrcLoopBB, DstLoopBB});
    627 
    628     if (!Weight)
    629       return None;
    630 
    631     if (!MaxWeight || MaxWeight.getValue() < Weight.getValue())
    632       MaxWeight = Weight;
    633   }
    634 
    635   return MaxWeight;
    636 }
    637 
    638 // Updates \p LoopBB's weight and returns true. If \p LoopBB has already
    639 // an associated weight it is unchanged and false is returned.
    640 //
    641 // Please note by the algorithm the weight is not expected to change once set
    642 // thus 'false' status is used to track visited blocks.
    643 bool BranchProbabilityInfo::updateEstimatedBlockWeight(
    644     LoopBlock &LoopBB, uint32_t BBWeight,
    645     SmallVectorImpl<BasicBlock *> &BlockWorkList,
    646     SmallVectorImpl<LoopBlock> &LoopWorkList) {
    647   BasicBlock *BB = LoopBB.getBlock();
    648 
    649   // In general, weight is assigned to a block when it has final value and
    650   // can't/shouldn't be changed.  However, there are cases when a block
    651   // inherently has several (possibly "contradicting") weights. For example,
    652   // "unwind" block may also contain "cold" call. In that case the first
    653   // set weight is favored and all consequent weights are ignored.
    654   if (!EstimatedBlockWeight.insert({BB, BBWeight}).second)
    655     return false;
    656 
    657   for (BasicBlock *PredBlock : predecessors(BB)) {
    658     LoopBlock PredLoop = getLoopBlock(PredBlock);
    659     // Add affected block/loop to a working list.
    660     if (isLoopExitingEdge({PredLoop, LoopBB})) {
    661       if (!EstimatedLoopWeight.count(PredLoop.getLoopData()))
    662         LoopWorkList.push_back(PredLoop);
    663     } else if (!EstimatedBlockWeight.count(PredBlock))
    664       BlockWorkList.push_back(PredBlock);
    665   }
    666   return true;
    667 }
    668 
// Starting from the block of \p LoopBB (the block itself included), walk up the
// dominator tree and assign \p BBWeight to every dominator that is also
// post-dominated by that block. In other words, assign it to every block that
// executes if and only if the starting block executes. The walk stops at the
// first dominator that is not post-dominated, because none of that block's own
// dominators can be post-dominated either. It also stops at the first block
// whose weight could not be updated.
//
// Importantly, we skip loops here for two reasons. First, weights of blocks in
// a loop should be scaled by the trip count (which may be unknown). Second,
// there is no value in doing that, because it doesn't give any additional
// information about the distribution of probabilities inside the loop. The
// exception is loop 'enter' and 'exit' edges, which are handled in a special
// way in calcEstimatedHeuristics.
//
// In addition, \p BlockWorkList and \p LoopWorkList are populated with blocks
// and loops that have at least one successor (or exit) whose estimated weight
// was just updated. The caller can then keep propagating from them.
void BranchProbabilityInfo::propagateEstimatedBlockWeight(
    const LoopBlock &LoopBB, DominatorTree *DT, PostDominatorTree *PDT,
    uint32_t BBWeight, SmallVectorImpl<BasicBlock *> &BlockWorkList,
    SmallVectorImpl<LoopBlock> &LoopWorkList) {
  const BasicBlock *BB = LoopBB.getBlock();
  const auto *DTStartNode = DT->getNode(BB);
  const auto *PDTStartNode = PDT->getNode(BB);

  // TODO: Consider propagating weight down the domination line as well.
  for (const auto *DTNode = DTStartNode; DTNode != nullptr;
       DTNode = DTNode->getIDom()) {
    auto *DomBB = DTNode->getBlock();
    // Consider blocks which lie on one 'line'.
    if (!PDT->dominates(PDTStartNode, PDT->getNode(DomBB)))
      // If BB doesn't post dominate DomBB it will not post dominate dominators
      // of DomBB as well.
      break;

    LoopBlock DomLoopBB = getLoopBlock(DomBB);
    const LoopEdge Edge{DomLoopBB, LoopBB};
    // Don't propagate weight to blocks belonging to different loops.
    if (!isLoopEnteringExitingEdge(Edge)) {
      if (!updateEstimatedBlockWeight(DomLoopBB, BBWeight, BlockWorkList,
                                      LoopWorkList))
        // If DomBB has weight set then all its predecessors are already
        // processed (since we propagate weight up to the top of IR each time).
        break;
    } else if (isLoopExitingEdge(Edge)) {
      // The edge DomBB -> BB leaves DomBB's loop. Queue that loop so that
      // computeEestimateBlockWeight can estimate its weight from its exits.
      LoopWorkList.push_back(DomLoopBB);
    }
  }
}
    713 
    714 Optional<uint32_t> BranchProbabilityInfo::getInitialEstimatedBlockWeight(
    715     const BasicBlock *BB) {
    716   // Returns true if \p BB has call marked with "NoReturn" attribute.
    717   auto hasNoReturn = [&](const BasicBlock *BB) {
    718     for (const auto &I : reverse(*BB))
    719       if (const CallInst *CI = dyn_cast<CallInst>(&I))
    720         if (CI->hasFnAttr(Attribute::NoReturn))
    721           return true;
    722 
    723     return false;
    724   };
    725 
    726   // Important note regarding the order of checks. They are ordered by weight
    727   // from lowest to highest. Doing that allows to avoid "unstable" results
    728   // when several conditions heuristics can be applied simultaneously.
    729   if (isa<UnreachableInst>(BB->getTerminator()) ||
    730       // If this block is terminated by a call to
    731       // @llvm.experimental.deoptimize then treat it like an unreachable
    732       // since it is expected to practically never execute.
    733       // TODO: Should we actually treat as never returning call?
    734       BB->getTerminatingDeoptimizeCall())
    735     return hasNoReturn(BB)
    736                ? static_cast<uint32_t>(BlockExecWeight::NORETURN)
    737                : static_cast<uint32_t>(BlockExecWeight::UNREACHABLE);
    738 
    739   // Check if the block is 'unwind' handler of  some invoke instruction.
    740   for (const auto *Pred : predecessors(BB))
    741     if (Pred)
    742       if (const auto *II = dyn_cast<InvokeInst>(Pred->getTerminator()))
    743         if (II->getUnwindDest() == BB)
    744           return static_cast<uint32_t>(BlockExecWeight::UNWIND);
    745 
    746   // Check if the block contains 'cold' call.
    747   for (const auto &I : *BB)
    748     if (const CallInst *CI = dyn_cast<CallInst>(&I))
    749       if (CI->hasFnAttr(Attribute::Cold))
    750         return static_cast<uint32_t>(BlockExecWeight::COLD);
    751 
    752   return None;
    753 }
    754 
// Does an RPO traversal over all blocks in \p F and assigns initial weights to
// 'unreachable', 'noreturn', 'cold' and 'unwind' blocks (see
// getInitialEstimatedBlockWeight). It then does its best to propagate those
// weights up the IR, towards predecessors. Propagation happens along dominator
// chains, then via worklists: from successors to their predecessors, and from
// loop exits to the blocks entering the loop.
void BranchProbabilityInfo::computeEestimateBlockWeight(
    const Function &F, DominatorTree *DT, PostDominatorTree *PDT) {
  SmallVector<BasicBlock *, 8> BlockWorkList;
  SmallVector<LoopBlock, 8> LoopWorkList;

  // Visit blocks in RPO, i.e. every block after its predecessors (ignoring
  // back edges).
  ReversePostOrderTraversal<const Function *> RPOT(&F);
  for (const auto *BB : RPOT)
    if (auto BBWeight = getInitialEstimatedBlockWeight(BB))
      // If we were able to find estimated weight for the block set it to this
      // block and propagate up the IR.
      propagateEstimatedBlockWeight(getLoopBlock(BB), DT, PDT,
                                    BBWeight.getValue(), BlockWorkList,
                                    LoopWorkList);

  // BlockWorkList/LoopWorkList contain blocks/loops with at least one
  // successor/exit having an estimated weight. Try to propagate weight to such
  // blocks/loops from their successors/exits. Processing one list may refill
  // the other, so iterate until both are empty. The order is not important.
  do {
    while (!LoopWorkList.empty()) {
      const LoopBlock LoopBB = LoopWorkList.pop_back_val();

      // The loop may have been queued more than once; handle it only once.
      if (EstimatedLoopWeight.count(LoopBB.getLoopData()))
        continue;

      // A loop's weight is the maximum estimated weight over its exit edges.
      SmallVector<BasicBlock *, 4> Exits;
      getLoopExitBlocks(LoopBB, Exits);
      auto LoopWeight = getMaxEstimatedEdgeWeight(
          LoopBB, make_range(Exits.begin(), Exits.end()));

      if (LoopWeight) {
        // If we never exit the loop then we can enter it once at maximum.
        if (LoopWeight <= static_cast<uint32_t>(BlockExecWeight::UNREACHABLE))
          LoopWeight = static_cast<uint32_t>(BlockExecWeight::LOWEST_NON_ZERO);

        EstimatedLoopWeight.insert(
            {LoopBB.getLoopData(), LoopWeight.getValue()});
        // Add all blocks entering the loop into working list.
        getLoopEnterBlocks(LoopBB, BlockWorkList);
      }
    }

    while (!BlockWorkList.empty()) {
      // Blocks may be queued several times; skip ones already weighted.
      const BasicBlock *BB = BlockWorkList.pop_back_val();
      if (EstimatedBlockWeight.count(BB))
        continue;

      // We take the maximum over all weights of successors. In other words we
      // take the weight of the "hot" path. In theory there may be a better
      // function that gives more accurate results than "maximum", but none
      // comes to mind, and it is doubtful it would make any difference in
      // practice.
      const LoopBlock LoopBB = getLoopBlock(BB);
      auto MaxWeight = getMaxEstimatedEdgeWeight(LoopBB, successors(BB));

      if (MaxWeight)
        propagateEstimatedBlockWeight(LoopBB, DT, PDT, MaxWeight.getValue(),
                                      BlockWorkList, LoopWorkList);
    }
  } while (!BlockWorkList.empty() || !LoopWorkList.empty());
}
    823 
// Calculate edge probabilities based on each successor's estimated weight.
// Note that the gathered weights were not scaled for loops, so edges entering
// and exiting loops require special processing:
//  * loop-exiting edge weights are divided by the assumed trip count TC;
//  * successors found 'unlikely' by computeUnlikelySuccessors have their
//    weight halved.
// Neither adjustment is applied to a ZERO weight, and neither adjustment goes
// below LOWEST_NON_ZERO. A successor without an estimated weight is treated as
// BlockExecWeight::DEFAULT.
//
// Returns false without setting anything when no successor has an estimated
// weight, or when all weights are zero. Otherwise it sets probabilities
// proportional to the weights and returns true.
bool BranchProbabilityInfo::calcEstimatedHeuristics(const BasicBlock *BB) {
  assert(BB->getTerminator()->getNumSuccessors() > 1 &&
         "expected more than one successor!");

  const LoopBlock LoopBB = getLoopBlock(BB);

  SmallPtrSet<const BasicBlock *, 8> UnlikelyBlocks;
  // Trip count assumed for loop exits. It is derived from the loop-branch
  // heuristic (LBH_*) taken/not-taken weights.
  uint32_t TC = LBH_TAKEN_WEIGHT / LBH_NONTAKEN_WEIGHT;
  if (LoopBB.getLoop())
    computeUnlikelySuccessors(BB, LoopBB.getLoop(), UnlikelyBlocks);

  // Changed to 'true' if at least one successor has estimated weight.
  bool FoundEstimatedWeight = false;
  SmallVector<uint32_t, 4> SuccWeights;
  uint64_t TotalWeight = 0;
  // Go over all successors of BB and put their weights into SuccWeights.
  for (const BasicBlock *SuccBB : successors(BB)) {
    Optional<uint32_t> Weight;
    const LoopBlock SuccLoopBB = getLoopBlock(SuccBB);
    const LoopEdge Edge{LoopBB, SuccLoopBB};

    Weight = getEstimatedEdgeWeight(Edge);

    if (isLoopExitingEdge(Edge) &&
        // Avoid adjustment of ZERO weight since it should remain unchanged.
        Weight != static_cast<uint32_t>(BlockExecWeight::ZERO)) {
      // Scale down loop exiting weight by trip count.
      Weight = std::max(
          static_cast<uint32_t>(BlockExecWeight::LOWEST_NON_ZERO),
          Weight.getValueOr(static_cast<uint32_t>(BlockExecWeight::DEFAULT)) /
              TC);
    }
    bool IsUnlikelyEdge = LoopBB.getLoop() && UnlikelyBlocks.contains(SuccBB);
    if (IsUnlikelyEdge &&
        // Avoid adjustment of ZERO weight since it should remain unchanged.
        Weight != static_cast<uint32_t>(BlockExecWeight::ZERO)) {
      // 'Unlikely' blocks have twice lower weight.
      Weight = std::max(
          static_cast<uint32_t>(BlockExecWeight::LOWEST_NON_ZERO),
          Weight.getValueOr(static_cast<uint32_t>(BlockExecWeight::DEFAULT)) /
              2);
    }

    // Either adjustment above materializes a value, so this also counts
    // successors whose weight came only from an adjustment of DEFAULT.
    if (Weight)
      FoundEstimatedWeight = true;

    auto WeightVal =
        Weight.getValueOr(static_cast<uint32_t>(BlockExecWeight::DEFAULT));
    TotalWeight += WeightVal;
    SuccWeights.push_back(WeightVal);
  }

  // If none of the successors has an estimated weight, bail out.
  // If TotalWeight is 0, the weight of each successor is 0 as well, so they are
  // all equally likely. Bail out early to avoid a division by zero.
  if (!FoundEstimatedWeight || TotalWeight == 0)
    return false;

  assert(SuccWeights.size() == succ_size(BB) && "Missed successor?");
  const unsigned SuccCount = SuccWeights.size();

  // If the sum of weights does not fit in 32 bits, scale every weight down
  // accordingly.
  // NOTE: after scaling, any weight that is ZERO (including one that was
  // already ZERO before scaling) is raised to LOWEST_NON_ZERO.
  if (TotalWeight > UINT32_MAX) {
    uint64_t ScalingFactor = TotalWeight / UINT32_MAX + 1;
    TotalWeight = 0;
    for (unsigned Idx = 0; Idx < SuccCount; ++Idx) {
      SuccWeights[Idx] /= ScalingFactor;
      if (SuccWeights[Idx] == static_cast<uint32_t>(BlockExecWeight::ZERO))
        SuccWeights[Idx] =
            static_cast<uint32_t>(BlockExecWeight::LOWEST_NON_ZERO);
      TotalWeight += SuccWeights[Idx];
    }
    assert(TotalWeight <= UINT32_MAX && "Total weight overflows");
  }

  // Finally set probabilities to edges according to estimated block weights.
  SmallVector<BranchProbability, 4> EdgeProbabilities(
      SuccCount, BranchProbability::getUnknown());

  for (unsigned Idx = 0; Idx < SuccCount; ++Idx) {
    EdgeProbabilities[Idx] =
        BranchProbability(SuccWeights[Idx], (uint32_t)TotalWeight);
  }
  setEdgeProbability(BB, EdgeProbabilities);
  return true;
}
    914 
    915 bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB,
    916                                                const TargetLibraryInfo *TLI) {
    917   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
    918   if (!BI || !BI->isConditional())
    919     return false;
    920 
    921   Value *Cond = BI->getCondition();
    922   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
    923   if (!CI)
    924     return false;
    925 
    926   auto GetConstantInt = [](Value *V) {
    927     if (auto *I = dyn_cast<BitCastInst>(V))
    928       return dyn_cast<ConstantInt>(I->getOperand(0));
    929     return dyn_cast<ConstantInt>(V);
    930   };
    931 
    932   Value *RHS = CI->getOperand(1);
    933   ConstantInt *CV = GetConstantInt(RHS);
    934   if (!CV)
    935     return false;
    936 
    937   // If the LHS is the result of AND'ing a value with a single bit bitmask,
    938   // we don't have information about probabilities.
    939   if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0)))
    940     if (LHS->getOpcode() == Instruction::And)
    941       if (ConstantInt *AndRHS = GetConstantInt(LHS->getOperand(1)))
    942         if (AndRHS->getValue().isPowerOf2())
    943           return false;
    944 
    945   // Check if the LHS is the return value of a library function
    946   LibFunc Func = NumLibFuncs;
    947   if (TLI)
    948     if (CallInst *Call = dyn_cast<CallInst>(CI->getOperand(0)))
    949       if (Function *CalledFn = Call->getCalledFunction())
    950         TLI->getLibFunc(*CalledFn, Func);
    951 
    952   bool isProb;
    953   if (Func == LibFunc_strcasecmp ||
    954       Func == LibFunc_strcmp ||
    955       Func == LibFunc_strncasecmp ||
    956       Func == LibFunc_strncmp ||
    957       Func == LibFunc_memcmp ||
    958       Func == LibFunc_bcmp) {
    959     // strcmp and similar functions return zero, negative, or positive, if the
    960     // first string is equal, less, or greater than the second. We consider it
    961     // likely that the strings are not equal, so a comparison with zero is
    962     // probably false, but also a comparison with any other number is also
    963     // probably false given that what exactly is returned for nonzero values is
    964     // not specified. Any kind of comparison other than equality we know
    965     // nothing about.
    966     switch (CI->getPredicate()) {
    967     case CmpInst::ICMP_EQ:
    968       isProb = false;
    969       break;
    970     case CmpInst::ICMP_NE:
    971       isProb = true;
    972       break;
    973     default:
    974       return false;
    975     }
    976   } else if (CV->isZero()) {
    977     switch (CI->getPredicate()) {
    978     case CmpInst::ICMP_EQ:
    979       // X == 0   ->  Unlikely
    980       isProb = false;
    981       break;
    982     case CmpInst::ICMP_NE:
    983       // X != 0   ->  Likely
    984       isProb = true;
    985       break;
    986     case CmpInst::ICMP_SLT:
    987       // X < 0   ->  Unlikely
    988       isProb = false;
    989       break;
    990     case CmpInst::ICMP_SGT:
    991       // X > 0   ->  Likely
    992       isProb = true;
    993       break;
    994     default:
    995       return false;
    996     }
    997   } else if (CV->isOne() && CI->getPredicate() == CmpInst::ICMP_SLT) {
    998     // InstCombine canonicalizes X <= 0 into X < 1.
    999     // X <= 0   ->  Unlikely
   1000     isProb = false;
   1001   } else if (CV->isMinusOne()) {
   1002     switch (CI->getPredicate()) {
   1003     case CmpInst::ICMP_EQ:
   1004       // X == -1  ->  Unlikely
   1005       isProb = false;
   1006       break;
   1007     case CmpInst::ICMP_NE:
   1008       // X != -1  ->  Likely
   1009       isProb = true;
   1010       break;
   1011     case CmpInst::ICMP_SGT:
   1012       // InstCombine canonicalizes X >= 0 into X > -1.
   1013       // X >= 0   ->  Likely
   1014       isProb = true;
   1015       break;
   1016     default:
   1017       return false;
   1018     }
   1019   } else {
   1020     return false;
   1021   }
   1022 
   1023   BranchProbability TakenProb(ZH_TAKEN_WEIGHT,
   1024                               ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
   1025   BranchProbability UntakenProb(ZH_NONTAKEN_WEIGHT,
   1026                                 ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
   1027   if (!isProb)
   1028     std::swap(TakenProb, UntakenProb);
   1029 
   1030   setEdgeProbability(
   1031       BB, SmallVector<BranchProbability, 2>({TakenProb, UntakenProb}));
   1032   return true;
   1033 }
   1034 
   1035 bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) {
   1036   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
   1037   if (!BI || !BI->isConditional())
   1038     return false;
   1039 
   1040   Value *Cond = BI->getCondition();
   1041   FCmpInst *FCmp = dyn_cast<FCmpInst>(Cond);
   1042   if (!FCmp)
   1043     return false;
   1044 
   1045   uint32_t TakenWeight = FPH_TAKEN_WEIGHT;
   1046   uint32_t NontakenWeight = FPH_NONTAKEN_WEIGHT;
   1047   bool isProb;
   1048   if (FCmp->isEquality()) {
   1049     // f1 == f2 -> Unlikely
   1050     // f1 != f2 -> Likely
   1051     isProb = !FCmp->isTrueWhenEqual();
   1052   } else if (FCmp->getPredicate() == FCmpInst::FCMP_ORD) {
   1053     // !isnan -> Likely
   1054     isProb = true;
   1055     TakenWeight = FPH_ORD_WEIGHT;
   1056     NontakenWeight = FPH_UNO_WEIGHT;
   1057   } else if (FCmp->getPredicate() == FCmpInst::FCMP_UNO) {
   1058     // isnan -> Unlikely
   1059     isProb = false;
   1060     TakenWeight = FPH_ORD_WEIGHT;
   1061     NontakenWeight = FPH_UNO_WEIGHT;
   1062   } else {
   1063     return false;
   1064   }
   1065 
   1066   BranchProbability TakenProb(TakenWeight, TakenWeight + NontakenWeight);
   1067   BranchProbability UntakenProb(NontakenWeight, TakenWeight + NontakenWeight);
   1068   if (!isProb)
   1069     std::swap(TakenProb, UntakenProb);
   1070 
   1071   setEdgeProbability(
   1072       BB, SmallVector<BranchProbability, 2>({TakenProb, UntakenProb}));
   1073   return true;
   1074 }
   1075 
   1076 void BranchProbabilityInfo::releaseMemory() {
   1077   Probs.clear();
   1078   Handles.clear();
   1079 }
   1080 
   1081 bool BranchProbabilityInfo::invalidate(Function &, const PreservedAnalyses &PA,
   1082                                        FunctionAnalysisManager::Invalidator &) {
   1083   // Check whether the analysis, all analyses on functions, or the function's
   1084   // CFG have been preserved.
   1085   auto PAC = PA.getChecker<BranchProbabilityAnalysis>();
   1086   return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() ||
   1087            PAC.preservedSet<CFGAnalyses>());
   1088 }
   1089 
   1090 void BranchProbabilityInfo::print(raw_ostream &OS) const {
   1091   OS << "---- Branch Probabilities ----\n";
   1092   // We print the probabilities from the last function the analysis ran over,
   1093   // or the function it is currently running over.
   1094   assert(LastF && "Cannot print prior to running over a function");
   1095   for (const auto &BI : *LastF) {
   1096     for (const BasicBlock *Succ : successors(&BI))
   1097       printEdgeProbability(OS << "  ", &BI, Succ);
   1098   }
   1099 }
   1100 
   1101 bool BranchProbabilityInfo::
   1102 isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const {
   1103   // Hot probability is at least 4/5 = 80%
   1104   // FIXME: Compare against a static "hot" BranchProbability.
   1105   return getEdgeProbability(Src, Dst) > BranchProbability(4, 5);
   1106 }
   1107 
   1108 const BasicBlock *
   1109 BranchProbabilityInfo::getHotSucc(const BasicBlock *BB) const {
   1110   auto MaxProb = BranchProbability::getZero();
   1111   const BasicBlock *MaxSucc = nullptr;
   1112 
   1113   for (const auto *Succ : successors(BB)) {
   1114     auto Prob = getEdgeProbability(BB, Succ);
   1115     if (Prob > MaxProb) {
   1116       MaxProb = Prob;
   1117       MaxSucc = Succ;
   1118     }
   1119   }
   1120 
   1121   // Hot probability is at least 4/5 = 80%
   1122   if (MaxProb > BranchProbability(4, 5))
   1123     return MaxSucc;
   1124 
   1125   return nullptr;
   1126 }
   1127 
   1128 /// Get the raw edge probability for the edge. If can't find it, return a
   1129 /// default probability 1/N where N is the number of successors. Here an edge is
   1130 /// specified using PredBlock and an
   1131 /// index to the successors.
   1132 BranchProbability
   1133 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
   1134                                           unsigned IndexInSuccessors) const {
   1135   auto I = Probs.find(std::make_pair(Src, IndexInSuccessors));
   1136   assert((Probs.end() == Probs.find(std::make_pair(Src, 0))) ==
   1137              (Probs.end() == I) &&
   1138          "Probability for I-th successor must always be defined along with the "
   1139          "probability for the first successor");
   1140 
   1141   if (I != Probs.end())
   1142     return I->second;
   1143 
   1144   return {1, static_cast<uint32_t>(succ_size(Src))};
   1145 }
   1146 
   1147 BranchProbability
   1148 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
   1149                                           const_succ_iterator Dst) const {
   1150   return getEdgeProbability(Src, Dst.getSuccessorIndex());
   1151 }
   1152 
   1153 /// Get the raw edge probability calculated for the block pair. This returns the
   1154 /// sum of all raw edge probabilities from Src to Dst.
   1155 BranchProbability
   1156 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
   1157                                           const BasicBlock *Dst) const {
   1158   if (!Probs.count(std::make_pair(Src, 0)))
   1159     return BranchProbability(llvm::count(successors(Src), Dst), succ_size(Src));
   1160 
   1161   auto Prob = BranchProbability::getZero();
   1162   for (const_succ_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I)
   1163     if (*I == Dst)
   1164       Prob += Probs.find(std::make_pair(Src, I.getSuccessorIndex()))->second;
   1165 
   1166   return Prob;
   1167 }
   1168 
   1169 /// Set the edge probability for all edges at once.
   1170 void BranchProbabilityInfo::setEdgeProbability(
   1171     const BasicBlock *Src, const SmallVectorImpl<BranchProbability> &Probs) {
   1172   assert(Src->getTerminator()->getNumSuccessors() == Probs.size());
   1173   eraseBlock(Src); // Erase stale data if any.
   1174   if (Probs.size() == 0)
   1175     return; // Nothing to set.
   1176 
   1177   Handles.insert(BasicBlockCallbackVH(Src, this));
   1178   uint64_t TotalNumerator = 0;
   1179   for (unsigned SuccIdx = 0; SuccIdx < Probs.size(); ++SuccIdx) {
   1180     this->Probs[std::make_pair(Src, SuccIdx)] = Probs[SuccIdx];
   1181     LLVM_DEBUG(dbgs() << "set edge " << Src->getName() << " -> " << SuccIdx
   1182                       << " successor probability to " << Probs[SuccIdx]
   1183                       << "\n");
   1184     TotalNumerator += Probs[SuccIdx].getNumerator();
   1185   }
   1186 
   1187   // Because of rounding errors the total probability cannot be checked to be
   1188   // 1.0 exactly. That is TotalNumerator == BranchProbability::getDenominator.
   1189   // Instead, every single probability in Probs must be as accurate as possible.
   1190   // This results in error 1/denominator at most, thus the total absolute error
   1191   // should be within Probs.size / BranchProbability::getDenominator.
   1192   assert(TotalNumerator <= BranchProbability::getDenominator() + Probs.size());
   1193   assert(TotalNumerator >= BranchProbability::getDenominator() - Probs.size());
   1194 }
   1195 
   1196 void BranchProbabilityInfo::copyEdgeProbabilities(BasicBlock *Src,
   1197                                                   BasicBlock *Dst) {
   1198   eraseBlock(Dst); // Erase stale data if any.
   1199   unsigned NumSuccessors = Src->getTerminator()->getNumSuccessors();
   1200   assert(NumSuccessors == Dst->getTerminator()->getNumSuccessors());
   1201   if (NumSuccessors == 0)
   1202     return; // Nothing to set.
   1203   if (this->Probs.find(std::make_pair(Src, 0)) == this->Probs.end())
   1204     return; // No probability is set for edges from Src. Keep the same for Dst.
   1205 
   1206   Handles.insert(BasicBlockCallbackVH(Dst, this));
   1207   for (unsigned SuccIdx = 0; SuccIdx < NumSuccessors; ++SuccIdx) {
   1208     auto Prob = this->Probs[std::make_pair(Src, SuccIdx)];
   1209     this->Probs[std::make_pair(Dst, SuccIdx)] = Prob;
   1210     LLVM_DEBUG(dbgs() << "set edge " << Dst->getName() << " -> " << SuccIdx
   1211                       << " successor probability to " << Prob << "\n");
   1212   }
   1213 }
   1214 
   1215 raw_ostream &
   1216 BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
   1217                                             const BasicBlock *Src,
   1218                                             const BasicBlock *Dst) const {
   1219   const BranchProbability Prob = getEdgeProbability(Src, Dst);
   1220   OS << "edge " << Src->getName() << " -> " << Dst->getName()
   1221      << " probability is " << Prob
   1222      << (isEdgeHot(Src, Dst) ? " [HOT edge]\n" : "\n");
   1223 
   1224   return OS;
   1225 }
   1226 
   1227 void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) {
   1228   LLVM_DEBUG(dbgs() << "eraseBlock " << BB->getName() << "\n");
   1229 
   1230   // Note that we cannot use successors of BB because the terminator of BB may
   1231   // have changed when eraseBlock is called as a BasicBlockCallbackVH callback.
   1232   // Instead we remove prob data for the block by iterating successors by their
   1233   // indices from 0 till the last which exists. There could not be prob data for
   1234   // a pair (BB, N) if there is no data for (BB, N-1) because the data is always
   1235   // set for all successors from 0 to M at once by the method
   1236   // setEdgeProbability().
   1237   Handles.erase(BasicBlockCallbackVH(BB, this));
   1238   for (unsigned I = 0;; ++I) {
   1239     auto MapI = Probs.find(std::make_pair(BB, I));
   1240     if (MapI == Probs.end()) {
   1241       assert(Probs.count(std::make_pair(BB, I + 1)) == 0 &&
   1242              "Must be no more successors");
   1243       return;
   1244     }
   1245     Probs.erase(MapI);
   1246   }
   1247 }
   1248 
   1249 void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LoopI,
   1250                                       const TargetLibraryInfo *TLI,
   1251                                       DominatorTree *DT,
   1252                                       PostDominatorTree *PDT) {
   1253   LLVM_DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
   1254                     << " ----\n\n");
   1255   LastF = &F; // Store the last function we ran on for printing.
   1256   LI = &LoopI;
   1257 
   1258   SccI = std::make_unique<SccInfo>(F);
   1259 
   1260   assert(EstimatedBlockWeight.empty());
   1261   assert(EstimatedLoopWeight.empty());
   1262 
   1263   std::unique_ptr<DominatorTree> DTPtr;
   1264   std::unique_ptr<PostDominatorTree> PDTPtr;
   1265 
   1266   if (!DT) {
   1267     DTPtr = std::make_unique<DominatorTree>(const_cast<Function &>(F));
   1268     DT = DTPtr.get();
   1269   }
   1270 
   1271   if (!PDT) {
   1272     PDTPtr = std::make_unique<PostDominatorTree>(const_cast<Function &>(F));
   1273     PDT = PDTPtr.get();
   1274   }
   1275 
   1276   computeEestimateBlockWeight(F, DT, PDT);
   1277 
   1278   // Walk the basic blocks in post-order so that we can build up state about
   1279   // the successors of a block iteratively.
   1280   for (auto BB : post_order(&F.getEntryBlock())) {
   1281     LLVM_DEBUG(dbgs() << "Computing probabilities for " << BB->getName()
   1282                       << "\n");
   1283     // If there is no at least two successors, no sense to set probability.
   1284     if (BB->getTerminator()->getNumSuccessors() < 2)
   1285       continue;
   1286     if (calcMetadataWeights(BB))
   1287       continue;
   1288     if (calcEstimatedHeuristics(BB))
   1289       continue;
   1290     if (calcPointerHeuristics(BB))
   1291       continue;
   1292     if (calcZeroHeuristics(BB, TLI))
   1293       continue;
   1294     if (calcFloatingPointHeuristics(BB))
   1295       continue;
   1296   }
   1297 
   1298   EstimatedLoopWeight.clear();
   1299   EstimatedBlockWeight.clear();
   1300   SccI.reset();
   1301 
   1302   if (PrintBranchProb &&
   1303       (PrintBranchProbFuncName.empty() ||
   1304        F.getName().equals(PrintBranchProbFuncName))) {
   1305     print(dbgs());
   1306   }
   1307 }
   1308 
   1309 void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
   1310     AnalysisUsage &AU) const {
   1311   // We require DT so it's available when LI is available. The LI updating code
   1312   // asserts that DT is also present so if we don't make sure that we have DT
   1313   // here, that assert will trigger.
   1314   AU.addRequired<DominatorTreeWrapperPass>();
   1315   AU.addRequired<LoopInfoWrapperPass>();
   1316   AU.addRequired<TargetLibraryInfoWrapperPass>();
   1317   AU.addRequired<DominatorTreeWrapperPass>();
   1318   AU.addRequired<PostDominatorTreeWrapperPass>();
   1319   AU.setPreservesAll();
   1320 }
   1321 
   1322 bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
   1323   const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
   1324   const TargetLibraryInfo &TLI =
   1325       getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
   1326   DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
   1327   PostDominatorTree &PDT =
   1328       getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
   1329   BPI.calculate(F, LI, &TLI, &DT, &PDT);
   1330   return false;
   1331 }
   1332 
   1333 void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
   1334 
   1335 void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
   1336                                              const Module *) const {
   1337   BPI.print(OS);
   1338 }
   1339 
   1340 AnalysisKey BranchProbabilityAnalysis::Key;
   1341 BranchProbabilityInfo
   1342 BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
   1343   BranchProbabilityInfo BPI;
   1344   BPI.calculate(F, AM.getResult<LoopAnalysis>(F),
   1345                 &AM.getResult<TargetLibraryAnalysis>(F),
   1346                 &AM.getResult<DominatorTreeAnalysis>(F),
   1347                 &AM.getResult<PostDominatorTreeAnalysis>(F));
   1348   return BPI;
   1349 }
   1350 
   1351 PreservedAnalyses
   1352 BranchProbabilityPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
   1353   OS << "Printing analysis results of BPI for function "
   1354      << "'" << F.getName() << "':"
   1355      << "\n";
   1356   AM.getResult<BranchProbabilityAnalysis>(F).print(OS);
   1357   return PreservedAnalyses::all();
   1358 }
   1359