Home | History | Annotate | Line # | Download | only in CodeGen
      1 //===- SwitchLoweringUtils.h - Switch Lowering ------------------*- C++ -*-===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 
      9 #ifndef LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
     10 #define LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
     11 
     12 #include "llvm/ADT/SmallVector.h"
     13 #include "llvm/CodeGen/ISDOpcodes.h"
     14 #include "llvm/CodeGen/SelectionDAGNodes.h"
     15 #include "llvm/IR/InstrTypes.h"
     16 #include "llvm/Support/BranchProbability.h"
     17 #include <vector>
     18 
     19 namespace llvm {
     20 
     21 class BlockFrequencyInfo;
     22 class ConstantInt;
     23 class FunctionLoweringInfo;
     24 class MachineBasicBlock;
     25 class ProfileSummaryInfo;
     26 class TargetLowering;
     27 class TargetMachine;
     28 
     29 namespace SwitchCG {
     30 
enum CaseClusterKind {
  /// Either a single case, or a run of adjacent case labels that all share
  /// the same destination.
  CC_Range = 0,
  /// A group of cases suitable for lowering as a jump table.
  CC_JumpTable = 1,
  /// A group of cases suitable for lowering as a bit test.
  CC_BitTests = 2,
};
     40 
     41 /// A cluster of case labels.
     42 struct CaseCluster {
     43   CaseClusterKind Kind;
     44   const ConstantInt *Low, *High;
     45   union {
     46     MachineBasicBlock *MBB;
     47     unsigned JTCasesIndex;
     48     unsigned BTCasesIndex;
     49   };
     50   BranchProbability Prob;
     51 
     52   static CaseCluster range(const ConstantInt *Low, const ConstantInt *High,
     53                            MachineBasicBlock *MBB, BranchProbability Prob) {
     54     CaseCluster C;
     55     C.Kind = CC_Range;
     56     C.Low = Low;
     57     C.High = High;
     58     C.MBB = MBB;
     59     C.Prob = Prob;
     60     return C;
     61   }
     62 
     63   static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High,
     64                                unsigned JTCasesIndex, BranchProbability Prob) {
     65     CaseCluster C;
     66     C.Kind = CC_JumpTable;
     67     C.Low = Low;
     68     C.High = High;
     69     C.JTCasesIndex = JTCasesIndex;
     70     C.Prob = Prob;
     71     return C;
     72   }
     73 
     74   static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High,
     75                               unsigned BTCasesIndex, BranchProbability Prob) {
     76     CaseCluster C;
     77     C.Kind = CC_BitTests;
     78     C.Low = Low;
     79     C.High = High;
     80     C.BTCasesIndex = BTCasesIndex;
     81     C.Prob = Prob;
     82     return C;
     83   }
     84 };
     85 
/// Sequence of CaseCluster objects.
using CaseClusterVector = std::vector<CaseCluster>;
/// Mutable iterator into a CaseClusterVector.
using CaseClusterIt = CaseClusterVector::iterator;
     88 
/// Sort Clusters and merge adjacent cases.
///
/// \param Clusters the cluster vector to process; it is taken by non-const
/// reference and nothing is returned, so the result is delivered through this
/// argument.
void sortAndRangeify(CaseClusterVector &Clusters);
     91 
     92 struct CaseBits {
     93   uint64_t Mask = 0;
     94   MachineBasicBlock *BB = nullptr;
     95   unsigned Bits = 0;
     96   BranchProbability ExtraProb;
     97 
     98   CaseBits() = default;
     99   CaseBits(uint64_t mask, MachineBasicBlock *bb, unsigned bits,
    100            BranchProbability Prob)
    101       : Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) {}
    102 };
    103 
/// Sequence of CaseBits records.
using CaseBitsVector = std::vector<CaseBits>;
    105 
    106 /// This structure is used to communicate between SelectionDAGBuilder and
    107 /// SDISel for the code generation of additional basic blocks needed by
    108 /// multi-case switch statements.
    109 struct CaseBlock {
    110   // For the GISel interface.
    111   struct PredInfoPair {
    112     CmpInst::Predicate Pred;
    113     // Set when no comparison should be emitted.
    114     bool NoCmp;
    115   };
    116   union {
    117     // The condition code to use for the case block's setcc node.
    118     // Besides the integer condition codes, this can also be SETTRUE, in which
    119     // case no comparison gets emitted.
    120     ISD::CondCode CC;
    121     struct PredInfoPair PredInfo;
    122   };
    123 
    124   // The LHS/MHS/RHS of the comparison to emit.
    125   // Emit by default LHS op RHS. MHS is used for range comparisons:
    126   // If MHS is not null: (LHS <= MHS) and (MHS <= RHS).
    127   const Value *CmpLHS, *CmpMHS, *CmpRHS;
    128 
    129   // The block to branch to if the setcc is true/false.
    130   MachineBasicBlock *TrueBB, *FalseBB;
    131 
    132   // The block into which to emit the code for the setcc and branches.
    133   MachineBasicBlock *ThisBB;
    134 
    135   /// The debug location of the instruction this CaseBlock was
    136   /// produced from.
    137   SDLoc DL;
    138   DebugLoc DbgLoc;
    139 
    140   // Branch weights.
    141   BranchProbability TrueProb, FalseProb;
    142 
    143   // Constructor for SelectionDAG.
    144   CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
    145             const Value *cmpmiddle, MachineBasicBlock *truebb,
    146             MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl,
    147             BranchProbability trueprob = BranchProbability::getUnknown(),
    148             BranchProbability falseprob = BranchProbability::getUnknown())
    149       : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
    150         TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl),
    151         TrueProb(trueprob), FalseProb(falseprob) {}
    152 
    153   // Constructor for GISel.
    154   CaseBlock(CmpInst::Predicate pred, bool nocmp, const Value *cmplhs,
    155             const Value *cmprhs, const Value *cmpmiddle,
    156             MachineBasicBlock *truebb, MachineBasicBlock *falsebb,
    157             MachineBasicBlock *me, DebugLoc dl,
    158             BranchProbability trueprob = BranchProbability::getUnknown(),
    159             BranchProbability falseprob = BranchProbability::getUnknown())
    160       : PredInfo({pred, nocmp}), CmpLHS(cmplhs), CmpMHS(cmpmiddle),
    161         CmpRHS(cmprhs), TrueBB(truebb), FalseBB(falsebb), ThisBB(me),
    162         DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob) {}
    163 };
    164 
    165 struct JumpTable {
    166   /// The virtual register containing the index of the jump table entry
    167   /// to jump to.
    168   unsigned Reg;
    169   /// The JumpTableIndex for this jump table in the function.
    170   unsigned JTI;
    171   /// The MBB into which to emit the code for the indirect jump.
    172   MachineBasicBlock *MBB;
    173   /// The MBB of the default bb, which is a successor of the range
    174   /// check MBB.  This is when updating PHI nodes in successors.
    175   MachineBasicBlock *Default;
    176 
    177   JumpTable(unsigned R, unsigned J, MachineBasicBlock *M, MachineBasicBlock *D)
    178       : Reg(R), JTI(J), MBB(M), Default(D) {}
    179 };
    180 struct JumpTableHeader {
    181   APInt First;
    182   APInt Last;
    183   const Value *SValue;
    184   MachineBasicBlock *HeaderBB;
    185   bool Emitted;
    186   bool OmitRangeCheck;
    187 
    188   JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H,
    189                   bool E = false)
    190       : First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H),
    191         Emitted(E), OmitRangeCheck(false) {}
    192 };
/// A jump table header paired with the jump table it describes.
using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>;
    194 
    195 struct BitTestCase {
    196   uint64_t Mask;
    197   MachineBasicBlock *ThisBB;
    198   MachineBasicBlock *TargetBB;
    199   BranchProbability ExtraProb;
    200 
    201   BitTestCase(uint64_t M, MachineBasicBlock *T, MachineBasicBlock *Tr,
    202               BranchProbability Prob)
    203       : Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) {}
    204 };
    205 
/// List of BitTestCase entries, stored in a SmallVector with inline capacity 3.
using BitTestInfo = SmallVector<BitTestCase, 3>;
    207 
    208 struct BitTestBlock {
    209   APInt First;
    210   APInt Range;
    211   const Value *SValue;
    212   unsigned Reg;
    213   MVT RegVT;
    214   bool Emitted;
    215   bool ContiguousRange;
    216   MachineBasicBlock *Parent;
    217   MachineBasicBlock *Default;
    218   BitTestInfo Cases;
    219   BranchProbability Prob;
    220   BranchProbability DefaultProb;
    221   bool OmitRangeCheck;
    222 
    223   BitTestBlock(APInt F, APInt R, const Value *SV, unsigned Rg, MVT RgVT, bool E,
    224                bool CR, MachineBasicBlock *P, MachineBasicBlock *D,
    225                BitTestInfo C, BranchProbability Pr)
    226       : First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg),
    227         RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D),
    228         Cases(std::move(C)), Prob(Pr), OmitRangeCheck(false) {}
    229 };
    230 
/// Return the range of values within a range of clusters.
///
/// \param Clusters the clusters to inspect (read-only).
/// \param First, Last delimit the cluster range.
/// NOTE(review): presumably inclusive indices into Clusters, i.e.
/// Clusters[First..Last] -- confirm against the definition.
uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First,
                           unsigned Last);
    234 
/// Return the number of cases within a range.
///
/// \param TotalCases per-position case counts (read-only).
/// NOTE(review): looks like a running total indexed by cluster position --
/// confirm against the definition.
/// \param First, Last delimit the range being queried.
uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
                              unsigned First, unsigned Last);
    238 
/// One pending unit of switch lowering work. This is a plain aggregate, so
/// the member order below is part of its interface for brace initialization.
struct SwitchWorkListItem {
  // NOTE(review): presumably the block in which this item's code is emitted
  // -- confirm against the lowering code.
  MachineBasicBlock *MBB;
  // Iterators delimiting the clusters this item covers.
  // NOTE(review): whether LastCluster is inclusive is not visible here.
  CaseClusterIt FirstCluster;
  CaseClusterIt LastCluster;
  // NOTE(review): by their names, presumably known bounds on the switch value
  // (value >= GE, value < LT) -- confirm against the lowering code.
  const ConstantInt *GE;
  const ConstantInt *LT;
  // Branch probability associated with the default destination.
  BranchProbability DefaultProb;
};
/// Work list of SwitchWorkListItem entries, stored in a SmallVector with
/// inline capacity 4.
using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>;
    248 
    249 class SwitchLowering {
    250 public:
    251   SwitchLowering(FunctionLoweringInfo &funcinfo) : FuncInfo(funcinfo) {}
    252 
    253   void init(const TargetLowering &tli, const TargetMachine &tm,
    254             const DataLayout &dl) {
    255     TLI = &tli;
    256     TM = &tm;
    257     DL = &dl;
    258   }
    259 
    260   /// Vector of CaseBlock structures used to communicate SwitchInst code
    261   /// generation information.
    262   std::vector<CaseBlock> SwitchCases;
    263 
    264   /// Vector of JumpTable structures used to communicate SwitchInst code
    265   /// generation information.
    266   std::vector<JumpTableBlock> JTCases;
    267 
    268   /// Vector of BitTestBlock structures used to communicate SwitchInst code
    269   /// generation information.
    270   std::vector<BitTestBlock> BitTestCases;
    271 
    272   void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
    273                       MachineBasicBlock *DefaultMBB,
    274                       ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI);
    275 
    276   bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First,
    277                       unsigned Last, const SwitchInst *SI,
    278                       MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster);
    279 
    280 
    281   void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI);
    282 
    283   /// Build a bit test cluster from Clusters[First..Last]. Returns false if it
    284   /// decides it's not a good idea.
    285   bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last,
    286                      const SwitchInst *SI, CaseCluster &BTCluster);
    287 
    288   virtual void addSuccessorWithProb(
    289       MachineBasicBlock *Src, MachineBasicBlock *Dst,
    290       BranchProbability Prob = BranchProbability::getUnknown()) = 0;
    291 
    292   virtual ~SwitchLowering() = default;
    293 
    294 private:
    295   const TargetLowering *TLI;
    296   const TargetMachine *TM;
    297   const DataLayout *DL;
    298   FunctionLoweringInfo &FuncInfo;
    299 };
    300 
    301 } // namespace SwitchCG
    302 } // namespace llvm
    303 
    304 #endif // LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
    305