Home | History | Annotate | Line # | Download | only in Support
      1 //===- BranchProbability.h - Branch Probability Wrapper ---------*- C++ -*-===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // Definition of BranchProbability shared by IR and Machine Instructions.
     10 //
     11 //===----------------------------------------------------------------------===//
     12 
     13 #ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H
     14 #define LLVM_SUPPORT_BRANCHPROBABILITY_H
     15 
     16 #include "llvm/Support/DataTypes.h"
     17 #include <algorithm>
     18 #include <cassert>
     19 #include <climits>
     20 #include <numeric>
     21 
     22 namespace llvm {
     23 
     24 class raw_ostream;
     25 
     26 // This class represents Branch Probability as a non-negative fraction that is
     27 // no greater than 1. It uses a fixed-point-like implementation, in which the
     28 // denominator is always a constant value (here we use 1<<31 for maximum
     29 // precision).
     30 class BranchProbability {
     31   // Numerator
     32   uint32_t N;
     33 
     34   // Denominator, which is a constant value.
     35   static constexpr uint32_t D = 1u << 31;
     36   static constexpr uint32_t UnknownN = UINT32_MAX;
     37 
     38   // Construct a BranchProbability with only numerator assuming the denominator
     39   // is 1<<31. For internal use only.
     40   explicit BranchProbability(uint32_t n) : N(n) {}
     41 
     42 public:
     43   BranchProbability() : N(UnknownN) {}
     44   BranchProbability(uint32_t Numerator, uint32_t Denominator);
     45 
     46   bool isZero() const { return N == 0; }
     47   bool isUnknown() const { return N == UnknownN; }
     48 
     49   static BranchProbability getZero() { return BranchProbability(0); }
     50   static BranchProbability getOne() { return BranchProbability(D); }
     51   static BranchProbability getUnknown() { return BranchProbability(UnknownN); }
     52   // Create a BranchProbability object with the given numerator and 1<<31
     53   // as denominator.
     54   static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); }
     55   // Create a BranchProbability object from 64-bit integers.
     56   static BranchProbability getBranchProbability(uint64_t Numerator,
     57                                                 uint64_t Denominator);
     58 
     59   // Normalize given probabilties so that the sum of them becomes approximate
     60   // one.
     61   template <class ProbabilityIter>
     62   static void normalizeProbabilities(ProbabilityIter Begin,
     63                                      ProbabilityIter End);
     64 
     65   uint32_t getNumerator() const { return N; }
     66   static uint32_t getDenominator() { return D; }
     67 
     68   // Return (1 - Probability).
     69   BranchProbability getCompl() const { return BranchProbability(D - N); }
     70 
     71   raw_ostream &print(raw_ostream &OS) const;
     72 
     73   void dump() const;
     74 
     75   /// Scale a large integer.
     76   ///
     77   /// Scales \c Num.  Guarantees full precision.  Returns the floor of the
     78   /// result.
     79   ///
     80   /// \return \c Num times \c this.
     81   uint64_t scale(uint64_t Num) const;
     82 
     83   /// Scale a large integer by the inverse.
     84   ///
     85   /// Scales \c Num by the inverse of \c this.  Guarantees full precision.
     86   /// Returns the floor of the result.
     87   ///
     88   /// \return \c Num divided by \c this.
     89   uint64_t scaleByInverse(uint64_t Num) const;
     90 
     91   BranchProbability &operator+=(BranchProbability RHS) {
     92     assert(N != UnknownN && RHS.N != UnknownN &&
     93            "Unknown probability cannot participate in arithmetics.");
     94     // Saturate the result in case of overflow.
     95     N = (uint64_t(N) + RHS.N > D) ? D : N + RHS.N;
     96     return *this;
     97   }
     98 
     99   BranchProbability &operator-=(BranchProbability RHS) {
    100     assert(N != UnknownN && RHS.N != UnknownN &&
    101            "Unknown probability cannot participate in arithmetics.");
    102     // Saturate the result in case of underflow.
    103     N = N < RHS.N ? 0 : N - RHS.N;
    104     return *this;
    105   }
    106 
    107   BranchProbability &operator*=(BranchProbability RHS) {
    108     assert(N != UnknownN && RHS.N != UnknownN &&
    109            "Unknown probability cannot participate in arithmetics.");
    110     N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D;
    111     return *this;
    112   }
    113 
    114   BranchProbability &operator*=(uint32_t RHS) {
    115     assert(N != UnknownN &&
    116            "Unknown probability cannot participate in arithmetics.");
    117     N = (uint64_t(N) * RHS > D) ? D : N * RHS;
    118     return *this;
    119   }
    120 
    121   BranchProbability &operator/=(BranchProbability RHS) {
    122     assert(N != UnknownN && RHS.N != UnknownN &&
    123            "Unknown probability cannot participate in arithmetics.");
    124     N = (static_cast<uint64_t>(N) * D + RHS.N / 2) / RHS.N;
    125     return *this;
    126   }
    127 
    128   BranchProbability &operator/=(uint32_t RHS) {
    129     assert(N != UnknownN &&
    130            "Unknown probability cannot participate in arithmetics.");
    131     assert(RHS > 0 && "The divider cannot be zero.");
    132     N /= RHS;
    133     return *this;
    134   }
    135 
    136   BranchProbability operator+(BranchProbability RHS) const {
    137     BranchProbability Prob(*this);
    138     Prob += RHS;
    139     return Prob;
    140   }
    141 
    142   BranchProbability operator-(BranchProbability RHS) const {
    143     BranchProbability Prob(*this);
    144     Prob -= RHS;
    145     return Prob;
    146   }
    147 
    148   BranchProbability operator*(BranchProbability RHS) const {
    149     BranchProbability Prob(*this);
    150     Prob *= RHS;
    151     return Prob;
    152   }
    153 
    154   BranchProbability operator*(uint32_t RHS) const {
    155     BranchProbability Prob(*this);
    156     Prob *= RHS;
    157     return Prob;
    158   }
    159 
    160   BranchProbability operator/(BranchProbability RHS) const {
    161     BranchProbability Prob(*this);
    162     Prob /= RHS;
    163     return Prob;
    164   }
    165 
    166   BranchProbability operator/(uint32_t RHS) const {
    167     BranchProbability Prob(*this);
    168     Prob /= RHS;
    169     return Prob;
    170   }
    171 
    172   bool operator==(BranchProbability RHS) const { return N == RHS.N; }
    173   bool operator!=(BranchProbability RHS) const { return !(*this == RHS); }
    174 
    175   bool operator<(BranchProbability RHS) const {
    176     assert(N != UnknownN && RHS.N != UnknownN &&
    177            "Unknown probability cannot participate in comparisons.");
    178     return N < RHS.N;
    179   }
    180 
    181   bool operator>(BranchProbability RHS) const {
    182     assert(N != UnknownN && RHS.N != UnknownN &&
    183            "Unknown probability cannot participate in comparisons.");
    184     return RHS < *this;
    185   }
    186 
    187   bool operator<=(BranchProbability RHS) const {
    188     assert(N != UnknownN && RHS.N != UnknownN &&
    189            "Unknown probability cannot participate in comparisons.");
    190     return !(RHS < *this);
    191   }
    192 
    193   bool operator>=(BranchProbability RHS) const {
    194     assert(N != UnknownN && RHS.N != UnknownN &&
    195            "Unknown probability cannot participate in comparisons.");
    196     return !(*this < RHS);
    197   }
    198 };
    199 
    200 inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) {
    201   return Prob.print(OS);
    202 }
    203 
    204 template <class ProbabilityIter>
    205 void BranchProbability::normalizeProbabilities(ProbabilityIter Begin,
    206                                                ProbabilityIter End) {
    207   if (Begin == End)
    208     return;
    209 
    210   unsigned UnknownProbCount = 0;
    211   uint64_t Sum = std::accumulate(Begin, End, uint64_t(0),
    212                                  [&](uint64_t S, const BranchProbability &BP) {
    213                                    if (!BP.isUnknown())
    214                                      return S + BP.N;
    215                                    UnknownProbCount++;
    216                                    return S;
    217                                  });
    218 
    219   if (UnknownProbCount > 0) {
    220     BranchProbability ProbForUnknown = BranchProbability::getZero();
    221     // If the sum of all known probabilities is less than one, evenly distribute
    222     // the complement of sum to unknown probabilities. Otherwise, set unknown
    223     // probabilities to zeros and continue to normalize known probabilities.
    224     if (Sum < BranchProbability::getDenominator())
    225       ProbForUnknown = BranchProbability::getRaw(
    226           (BranchProbability::getDenominator() - Sum) / UnknownProbCount);
    227 
    228     std::replace_if(Begin, End,
    229                     [](const BranchProbability &BP) { return BP.isUnknown(); },
    230                     ProbForUnknown);
    231 
    232     if (Sum <= BranchProbability::getDenominator())
    233       return;
    234   }
    235 
    236   if (Sum == 0) {
    237     BranchProbability BP(1, std::distance(Begin, End));
    238     std::fill(Begin, End, BP);
    239     return;
    240   }
    241 
    242   for (auto I = Begin; I != End; ++I)
    243     I->N = (I->N * uint64_t(D) + Sum / 2) / Sum;
    244 }
    245 
    246 }
    247 
    248 #endif
    249