      1 //===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // This file contains switch inst lowering optimizations and utilities for
     10 // codegen, so that it can be used for both SelectionDAG and GlobalISel.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #include "llvm/CodeGen/SwitchLoweringUtils.h"
     15 #include "llvm/CodeGen/FunctionLoweringInfo.h"
     16 #include "llvm/CodeGen/MachineJumpTableInfo.h"
     17 #include "llvm/CodeGen/TargetLowering.h"
     18 #include "llvm/Target/TargetMachine.h"
     19 
     20 using namespace llvm;
     21 using namespace SwitchCG;
     22 
     23 uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters,
     24                                      unsigned First, unsigned Last) {
     25   assert(Last >= First);
     26   const APInt &LowCase = Clusters[First].Low->getValue();
     27   const APInt &HighCase = Clusters[Last].High->getValue();
     28   assert(LowCase.getBitWidth() == HighCase.getBitWidth());
     29 
     30   // FIXME: A range of consecutive cases has 100% density, but only requires one
     31   // comparison to lower. We should discriminate against such consecutive ranges
     32   // in jump tables.
     33   return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
     34 }
     35 
     36 uint64_t
     37 SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
     38                                unsigned First, unsigned Last) {
     39   assert(Last >= First);
     40   assert(TotalCases[Last] >= TotalCases[First]);
     41   uint64_t NumCases =
     42       TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
     43   return NumCases;
     44 }
     45 
/// Replace runs of clusters in \p Clusters with CC_JumpTable clusters, in
/// place, where the target reports that a jump table is profitable. On entry
/// the vector must be non-empty, sorted by case value, and contain only
/// CC_Range clusters; on exit it may be shorter, with some runs collapsed
/// into single jump-table clusters.
void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB,
                                              ProfileSummaryInfo *PSI,
                                              BlockFrequencyInfo *BFI) {
#ifndef NDEBUG
  // Clusters must be non-empty, sorted, and only contain Range clusters.
  assert(!Clusters.empty());
  for (CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range);
  for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  assert(TLI && "TLI not set!");
  if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
    return;

  const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
  const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;

  // Bail if not enough cases.
  const int64_t N = Clusters.size();
  if (N < 2 || N < MinJumpTableEntries)
    return;

  // Accumulated number of cases in each cluster and those prior to it.
  // TotalCases[i] is the inclusive prefix sum consumed by
  // getJumpTableNumCases below.
  SmallVector<unsigned, 8> TotalCases(N);
  for (unsigned i = 0; i < N; ++i) {
    const APInt &Hi = Clusters[i].High->getValue();
    const APInt &Lo = Clusters[i].Low->getValue();
    TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
    if (i != 0)
      TotalCases[i] += TotalCases[i - 1];
  }

  uint64_t Range = getJumpTableRange(Clusters,0, N - 1);
  uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
  assert(NumCases < UINT64_MAX / 100);
  assert(Range >= NumCases);

  // Cheap case: the whole range may be suitable for jump table.
  if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
    CaseCluster JTCluster;
    if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
      Clusters[0] = JTCluster;
      Clusters.resize(1);
      return;
    }
  }

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // Split Clusters into minimum number of dense partitions. The algorithm uses
  // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
  // for the Case Statement'" (1994), but builds the MinPartitions array in
  // reverse order to make it easier to reconstruct the partitions in ascending
  // order. In the choice between two optimal partitionings, it picks the one
  // which yields more jump tables.

  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);
  // PartitionsScore[i] is used to break ties when choosing between two
  // partitionings resulting in the same number of partitions.
  SmallVector<unsigned, 8> PartitionsScore(N);
  // For PartitionsScore, a small number of comparisons is considered as good as
  // a jump table and a single comparison is considered better than a jump
  // table.
  enum PartitionScores : unsigned {
    NoTable = 0,
    Table = 1,
    FewCases = 1,
    SingleCase = 2
  };

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;
  PartitionsScore[N - 1] = PartitionScores::SingleCase;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; i--) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;
    PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;

    // Search for a solution that results in fewer partitions.
    for (int64_t j = N - 1; j > i; j--) {
      // Try building a partition from Clusters[i..j].
      Range = getJumpTableRange(Clusters, i, j);
      NumCases = getJumpTableNumCases(TotalCases, i, j);
      assert(NumCases < UINT64_MAX / 100);
      assert(Range >= NumCases);

      if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
        unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
        unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
        int64_t NumEntries = j - i + 1;

        // Score this candidate partition: a single case beats a small run,
        // which in turn is as good as a full jump table (see enum above).
        if (NumEntries == 1)
          Score += PartitionScores::SingleCase;
        else if (NumEntries <= SmallNumberOfEntries)
          Score += PartitionScores::FewCases;
        else if (NumEntries >= MinJumpTableEntries)
          Score += PartitionScores::Table;

        // If this leads to fewer partitions, or to the same number of
        // partitions with better score, it is a better partitioning.
        if (NumPartitions < MinPartitions[i] ||
            (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
          MinPartitions[i] = NumPartitions;
          LastElement[i] = j;
          PartitionsScore[i] = Score;
        }
      }
    }
  }

  // Iterate over the partitions, replacing some with jump tables in-place.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(Last >= First);
    assert(DstIndex <= First);
    unsigned NumClusters = Last - First + 1;

    CaseCluster JTCluster;
    if (NumClusters >= MinJumpTableEntries &&
        buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
      Clusters[DstIndex++] = JTCluster;
    } else {
      // Partition was not turned into a jump table (too small, or better
      // lowered as bit tests): keep its clusters, compacted toward the front.
      for (unsigned I = First; I <= Last; ++I)
        std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
    }
  }
  Clusters.resize(DstIndex);
}
    189 
/// Build a jump table for Clusters[First..Last]: construct the table of
/// destination blocks (padding gaps between clusters with \p DefaultMBB),
/// record the lowering in JTCases, and return the replacement cluster through
/// \p JTCluster. Returns false — without touching JTCases — when the range
/// would be better lowered as bit tests instead.
bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
                                              unsigned First, unsigned Last,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB,
                                              CaseCluster &JTCluster) {
  assert(First <= Last);

  auto Prob = BranchProbability::getZero();
  unsigned NumCmps = 0;
  std::vector<MachineBasicBlock*> Table;
  DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;

  // Initialize probabilities in JTProbs.
  for (unsigned I = First; I <= Last; ++I)
    JTProbs[Clusters[I].MBB] = BranchProbability::getZero();

  for (unsigned I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Prob += Clusters[I].Prob;
    const APInt &Low = Clusters[I].Low->getValue();
    const APInt &High = Clusters[I].High->getValue();
    // A single-value cluster costs one comparison to lower; a true range
    // costs two.
    NumCmps += (Low == High) ? 1 : 2;
    if (I != First) {
      // Fill the gap between this and the previous cluster.
      const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
      assert(PreviousHigh.slt(Low));
      uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
      for (uint64_t J = 0; J < Gap; J++)
        Table.push_back(DefaultMBB);
    }
    // One table entry per case value in the cluster.
    uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
    for (uint64_t J = 0; J < ClusterSize; ++J)
      Table.push_back(Clusters[I].MBB);
    JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
  }

  unsigned NumDests = JTProbs.size();
  if (TLI->isSuitableForBitTests(NumDests, NumCmps,
                                 Clusters[First].Low->getValue(),
                                 Clusters[Last].High->getValue(), *DL)) {
    // Clusters[First..Last] should be lowered as bit tests instead.
    return false;
  }

  // Create the MBB that will load from and jump through the table.
  // Note: We create it here, but it's not inserted into the function yet.
  MachineFunction *CurMF = FuncInfo.MF;
  MachineBasicBlock *JumpTableMBB =
      CurMF->CreateMachineBasicBlock(SI->getParent());

  // Add successors. Note: use table order for determinism.
  SmallPtrSet<MachineBasicBlock *, 8> Done;
  for (MachineBasicBlock *Succ : Table) {
    if (Done.count(Succ))
      continue;
    addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
    Done.insert(Succ);
  }
  JumpTableMBB->normalizeSuccProbs();

  unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
                     ->createJumpTableIndex(Table);

  // Set up the jump table info. The -1U and nullptr fields are placeholders
  // that are not known at this point — NOTE(review): presumably filled in
  // later during lowering; confirm against JumpTable/JumpTableHeader users.
  JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
  JumpTableHeader JTH(Clusters[First].Low->getValue(),
                      Clusters[Last].High->getValue(), SI->getCondition(),
                      nullptr, false);
  JTCases.emplace_back(std::move(JTH), std::move(JT));

  JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
                                     JTCases.size() - 1, Prob);
  return true;
}
    264 
/// Scan \p Clusters for runs that can be lowered as bit tests and replace
/// each such run with a single CC_BitTests cluster, in place (the vector may
/// shrink). Input clusters must be sorted and of kind CC_Range or
/// CC_JumpTable.
void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
                                                   const SwitchInst *SI) {
  // Partition Clusters into as few subsets as possible, where each subset has a
  // range that fits in a machine word and has <= 3 unique destinations.

#ifndef NDEBUG
  // Clusters must be sorted and contain Range or JumpTable clusters.
  assert(!Clusters.empty());
  assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
  for (const CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
  for (unsigned i = 1; i < Clusters.size(); ++i)
    assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // If target does not have legal shift left, do not emit bit tests at all.
  EVT PTy = TLI->getPointerTy(*DL);
  if (!TLI->isOperationLegal(ISD::SHL, PTy))
    return;

  int BitWidth = PTy.getSizeInBits();
  const int64_t N = Clusters.size();

  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);

  // FIXME: This might not be the best algorithm for finding bit test clusters.

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; --i) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;

    // Search for a solution that results in fewer partitions.
    // Note: the search is limited by BitWidth, reducing time complexity.
    for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
      // Try building a partition from Clusters[i..j].

      // Check the range.
      if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
                                Clusters[j].High->getValue(), *DL))
        continue;

      // Check nbr of destinations and cluster types.
      // FIXME: This works, but doesn't seem very efficient.
      bool RangesOnly = true;
      BitVector Dests(FuncInfo.MF->getNumBlockIDs());
      for (int64_t k = i; k <= j; k++) {
        if (Clusters[k].Kind != CC_Range) {
          RangesOnly = false;
          break;
        }
        Dests.set(Clusters[k].MBB->getNumber());
      }
      // Conservative early exit for this start index once the candidate span
      // contains a non-range cluster or more than three destinations.
      if (!RangesOnly || Dests.count() > 3)
        break;

      // Check if it's a better partition.
      unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
      if (NumPartitions < MinPartitions[i]) {
        // Found a better partition.
        MinPartitions[i] = NumPartitions;
        LastElement[i] = j;
      }
    }
  }

  // Iterate over the partitions, replacing with bit-test clusters in-place.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(First <= Last);
    assert(DstIndex <= First);

    CaseCluster BitTestCluster;
    if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
      Clusters[DstIndex++] = BitTestCluster;
    } else {
      // Partition not suitable for bit tests: keep its clusters, shifted
      // down to their compacted positions.
      size_t NumClusters = Last - First + 1;
      std::memmove(&Clusters[DstIndex], &Clusters[First],
                   sizeof(Clusters[0]) * NumClusters);
      DstIndex += NumClusters;
    }
  }
  Clusters.resize(DstIndex);
}
    363 
/// Try to lower Clusters[First..Last] as a bit-test cluster. Returns false
/// — without touching BitTestCases — when the range is a single cluster or
/// the target deems bit tests unsuitable; otherwise records the lowering in
/// BitTestCases and returns the replacement cluster through \p BTCluster.
bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
                                             unsigned First, unsigned Last,
                                             const SwitchInst *SI,
                                             CaseCluster &BTCluster) {
  assert(First <= Last);
  // A single cluster gains nothing from bit tests.
  if (First == Last)
    return false;

  // Count unique destinations and the comparisons a non-bit-test lowering
  // would need (1 for a single value, 2 for a range).
  BitVector Dests(FuncInfo.MF->getNumBlockIDs());
  unsigned NumCmps = 0;
  for (int64_t I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Dests.set(Clusters[I].MBB->getNumber());
    NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
  }
  unsigned NumDests = Dests.count();

  APInt Low = Clusters[First].Low->getValue();
  APInt High = Clusters[Last].High->getValue();
  assert(Low.slt(High));

  if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
    return false;

  APInt LowBound;
  APInt CmpRange;

  const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
  assert(TLI->rangeFitsInWord(Low, High, *DL) &&
         "Case range must fit in bit mask!");

  // Check if the clusters cover a contiguous range such that no value in the
  // range will jump to the default statement.
  bool ContiguousRange = true;
  for (int64_t I = First + 1; I <= Last; ++I) {
    if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
      ContiguousRange = false;
      break;
    }
  }

  if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
    // Optimize the case where all the case values fit in a word without having
    // to subtract minValue. In this case, we can optimize away the subtraction.
    LowBound = APInt::getNullValue(Low.getBitWidth());
    CmpRange = High;
    // With LowBound forced to zero, values in [0, Low) fall inside the
    // compare range but still branch to the default block, so the range is
    // no longer contiguous relative to it.
    ContiguousRange = false;
  } else {
    LowBound = Low;
    CmpRange = High - Low;
  }

  // Group clusters by destination, accumulating a bit mask, bit count, and
  // probability per destination block.
  CaseBitsVector CBV;
  auto TotalProb = BranchProbability::getZero();
  for (unsigned i = First; i <= Last; ++i) {
    // Find the CaseBits for this destination.
    unsigned j;
    for (j = 0; j < CBV.size(); ++j)
      if (CBV[j].BB == Clusters[i].MBB)
        break;
    if (j == CBV.size())
      CBV.push_back(
          CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
    CaseBits *CB = &CBV[j];

    // Update Mask, Bits and ExtraProb.
    uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
    uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
    assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
    // Set bits [Lo, Hi] of the mask.
    CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
    CB->Bits += Hi - Lo + 1;
    CB->ExtraProb += Clusters[i].Prob;
    TotalProb += Clusters[i].Prob;
  }

  BitTestInfo BTI;
  llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
    // Sort by probability first, number of bits second, bit mask third.
    if (a.ExtraProb != b.ExtraProb)
      return a.ExtraProb > b.ExtraProb;
    if (a.Bits != b.Bits)
      return a.Bits > b.Bits;
    return a.Mask < b.Mask;
  });

  // One new basic block per bit test, in the order chosen above.
  for (auto &CB : CBV) {
    MachineBasicBlock *BitTestBB =
        FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
    BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
  }
  BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
                            SI->getCondition(), -1U, MVT::Other, false,
                            ContiguousRange, nullptr, nullptr, std::move(BTI),
                            TotalProb);

  BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
                                    BitTestCases.size() - 1, TotalProb);
  return true;
}
    463 
    464 void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
    465 #ifndef NDEBUG
    466   for (const CaseCluster &CC : Clusters)
    467     assert(CC.Low == CC.High && "Input clusters must be single-case");
    468 #endif
    469 
    470   llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) {
    471     return a.Low->getValue().slt(b.Low->getValue());
    472   });
    473 
    474   // Merge adjacent clusters with the same destination.
    475   const unsigned N = Clusters.size();
    476   unsigned DstIndex = 0;
    477   for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
    478     CaseCluster &CC = Clusters[SrcIndex];
    479     const ConstantInt *CaseVal = CC.Low;
    480     MachineBasicBlock *Succ = CC.MBB;
    481 
    482     if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
    483         (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
    484       // If this case has the same successor and is a neighbour, merge it into
    485       // the previous cluster.
    486       Clusters[DstIndex - 1].High = CaseVal;
    487       Clusters[DstIndex - 1].Prob += CC.Prob;
    488     } else {
    489       std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
    490                    sizeof(Clusters[SrcIndex]));
    491     }
    492   }
    493   Clusters.resize(DstIndex);
    494 }
    495