Home | History | Annotate | Line # | Download | only in AggressiveInstCombine
      1 //===- AggressiveInstCombine.cpp ------------------------------------------===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // This file implements the aggressive expression pattern combiner classes.
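// Currently, it handles expression patterns for:
//  * Truncate instruction
//  * Any-bits-set / all-bits-set checks built from shifts and 'and'/'or' chains
//  * Guarded funnel shifts and rotates
//  * Population-count (popcount) idioms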
     12 //
     13 //===----------------------------------------------------------------------===//
     14 
     15 #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
     16 #include "AggressiveInstCombineInternal.h"
     17 #include "llvm-c/Initialization.h"
     18 #include "llvm-c/Transforms/AggressiveInstCombine.h"
     19 #include "llvm/ADT/Statistic.h"
     20 #include "llvm/Analysis/AliasAnalysis.h"
     21 #include "llvm/Analysis/BasicAliasAnalysis.h"
     22 #include "llvm/Analysis/GlobalsModRef.h"
     23 #include "llvm/Analysis/TargetLibraryInfo.h"
     24 #include "llvm/Analysis/ValueTracking.h"
     25 #include "llvm/IR/DataLayout.h"
     26 #include "llvm/IR/Dominators.h"
     27 #include "llvm/IR/Function.h"
     28 #include "llvm/IR/IRBuilder.h"
     29 #include "llvm/IR/LegacyPassManager.h"
     30 #include "llvm/IR/PatternMatch.h"
     31 #include "llvm/InitializePasses.h"
     32 #include "llvm/Pass.h"
     33 #include "llvm/Transforms/Utils/Local.h"
     34 
     35 using namespace llvm;
     36 using namespace PatternMatch;
     37 
     38 #define DEBUG_TYPE "aggressive-instcombine"
     39 
     40 STATISTIC(NumAnyOrAllBitsSet, "Number of any/all-bits-set patterns folded");
     41 STATISTIC(NumGuardedRotates,
     42           "Number of guarded rotates transformed into funnel shifts");
     43 STATISTIC(NumGuardedFunnelShifts,
     44           "Number of guarded funnel shifts transformed into funnel shifts");
     45 STATISTIC(NumPopCountRecognized, "Number of popcount idioms recognized");
     46 
     47 namespace {
     48 /// Contains expression pattern combiner logic.
     49 /// This class provides both the logic to combine expression patterns and
     50 /// combine them. It differs from InstCombiner class in that each pattern
     51 /// combiner runs only once as opposed to InstCombine's multi-iteration,
     52 /// which allows pattern combiner to have higher complexity than the O(1)
     53 /// required by the instruction combiner.
     54 class AggressiveInstCombinerLegacyPass : public FunctionPass {
     55 public:
     56   static char ID; // Pass identification, replacement for typeid
     57 
     58   AggressiveInstCombinerLegacyPass() : FunctionPass(ID) {
     59     initializeAggressiveInstCombinerLegacyPassPass(
     60         *PassRegistry::getPassRegistry());
     61   }
     62 
     63   void getAnalysisUsage(AnalysisUsage &AU) const override;
     64 
     65   /// Run all expression pattern optimizations on the given /p F function.
     66   ///
     67   /// \param F function to optimize.
     68   /// \returns true if the IR is changed.
     69   bool runOnFunction(Function &F) override;
     70 };
     71 } // namespace
     72 
/// Match a pattern for a bitwise funnel/rotate operation that partially guards
/// against undefined behavior by branching around the funnel-shift/rotation
/// when the shift amount is 0.
///
/// \p I must be a two-operand PHI. One incoming value must be an open-coded,
/// single-use funnel shift (shl/lshr/or with a "Width - ShAmt" sub). The other
/// must be the value that funnel shift yields for a zero shift amount. That
/// value must arrive from a block ending in "br (icmp eq ShAmt, 0), PhiBB,
/// FunnelBB". On success, all uses of the PHI are replaced by a call to
/// llvm.fshl / llvm.fshr. The call is inserted at the first insertion point of
/// the PHI's block. The PHI itself and the now-dead math are not erased here.
///
/// \param I  candidate instruction; anything but a 2-operand PHI of
///           power-of-2 scalar width is rejected immediately.
/// \param DT dominator tree, used to check that both shifted values dominate
///           the guard block's terminator.
/// \returns true if the IR was changed.
static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
  if (I.getOpcode() != Instruction::PHI || I.getNumOperands() != 2)
    return false;

  // As with the one-use checks below, this is not strictly necessary, but we
  // are being cautious to avoid potential perf regressions on targets that
  // do not actually have a funnel/rotate instruction (where the funnel shift
  // would be expanded back into math/shift/logic ops).
  if (!isPowerOf2_32(I.getType()->getScalarSizeInBits()))
    return false;

  // Match V to funnel shift left/right and capture the source operands and
  // shift amount. Returns Intrinsic::fshl / Intrinsic::fshr on success and
  // Intrinsic::not_intrinsic otherwise. The out-parameters may be written
  // even when the result is not_intrinsic, so callers must check the result.
  auto matchFunnelShift = [](Value *V, Value *&ShVal0, Value *&ShVal1,
                             Value *&ShAmt) {
    Value *SubAmt;
    unsigned Width = V->getType()->getScalarSizeInBits();

    // fshl(ShVal0, ShVal1, ShAmt)
    //  == (ShVal0 << ShAmt) | (ShVal1 >> (Width - ShAmt))
    if (match(V, m_OneUse(m_c_Or(
                     m_Shl(m_Value(ShVal0), m_Value(ShAmt)),
                     m_LShr(m_Value(ShVal1),
                            m_Sub(m_SpecificInt(Width), m_Value(SubAmt))))))) {
      if (ShAmt == SubAmt) // TODO: Use m_Specific
        return Intrinsic::fshl;
    }

    // fshr(ShVal0, ShVal1, ShAmt)
    //  == (ShVal0 << (Width - ShAmt)) | (ShVal1 >> ShAmt)
    if (match(V,
              m_OneUse(m_c_Or(m_Shl(m_Value(ShVal0), m_Sub(m_SpecificInt(Width),
                                                           m_Value(SubAmt))),
                              m_LShr(m_Value(ShVal1), m_Value(ShAmt)))))) {
      if (ShAmt == SubAmt) // TODO: Use m_Specific
        return Intrinsic::fshr;
    }

    return Intrinsic::not_intrinsic;
  };

  // One phi operand must be a funnel/rotate operation, and the other phi
  // operand must be the source value of that funnel/rotate operation:
  // phi [ rotate(RotSrc, ShAmt), FunnelBB ], [ RotSrc, GuardBB ]
  // phi [ fshl(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal0, GuardBB ]
  // phi [ fshr(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal1, GuardBB ]
  // Operand 0 is tried as the funnel shift first. If that fails, operand 1
  // is tried, and the roles of the two incoming edges are swapped.
  PHINode &Phi = cast<PHINode>(I);
  unsigned FunnelOp = 0, GuardOp = 1;
  Value *P0 = Phi.getOperand(0), *P1 = Phi.getOperand(1);
  Value *ShVal0, *ShVal1, *ShAmt;
  Intrinsic::ID IID = matchFunnelShift(P0, ShVal0, ShVal1, ShAmt);
  if (IID == Intrinsic::not_intrinsic ||
      (IID == Intrinsic::fshl && ShVal0 != P1) ||
      (IID == Intrinsic::fshr && ShVal1 != P1)) {
    IID = matchFunnelShift(P1, ShVal0, ShVal1, ShAmt);
    if (IID == Intrinsic::not_intrinsic ||
        (IID == Intrinsic::fshl && ShVal0 != P0) ||
        (IID == Intrinsic::fshr && ShVal1 != P0))
      return false;
    assert((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
           "Pattern must match funnel shift left or right");
    std::swap(FunnelOp, GuardOp);
  }

  // The incoming block with our source operand must be the "guard" block.
  // That must contain a cmp+branch to avoid the funnel/rotate when the shift
  // amount is equal to 0. The other incoming block is the block with the
  // funnel/rotate.
  BasicBlock *GuardBB = Phi.getIncomingBlock(GuardOp);
  BasicBlock *FunnelBB = Phi.getIncomingBlock(FunnelOp);
  Instruction *TermI = GuardBB->getTerminator();

  // Ensure that both shifted values dominate the guard block's terminator.
  if (!DT.dominates(ShVal0, TermI) || !DT.dominates(ShVal1, TermI))
    return false;

  // The guard must branch straight to the phi block when ShAmt == 0 and to
  // the funnel block otherwise.
  ICmpInst::Predicate Pred;
  BasicBlock *PhiBB = Phi.getParent();
  if (!match(TermI, m_Br(m_ICmp(Pred, m_Specific(ShAmt), m_ZeroInt()),
                         m_SpecificBB(PhiBB), m_SpecificBB(FunnelBB))))
    return false;

  if (Pred != CmpInst::ICMP_EQ)
    return false;

  // New instructions go after the phis at the top of the phi block.
  IRBuilder<> Builder(PhiBB, PhiBB->getFirstInsertionPt());

  if (ShVal0 == ShVal1)
    ++NumGuardedRotates;
  else
    ++NumGuardedFunnelShifts;

  // If this is not a rotate, the guard branch kept the operand that is unused
  // for a zero shift amount (ShVal1 for fshl, ShVal0 for fshr) from affecting
  // the result. The funnel-shift intrinsic does not block poison from that
  // operand, so freeze it unless it is known not to be poison.
  bool IsFshl = IID == Intrinsic::fshl;
  if (ShVal0 != ShVal1) {
    if (IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal1))
      ShVal1 = Builder.CreateFreeze(ShVal1);
    else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal0))
      ShVal0 = Builder.CreateFreeze(ShVal0);
  }

  // We matched a variation of this IR pattern:
  // GuardBB:
  //   %cmp = icmp eq i32 %ShAmt, 0
  //   br i1 %cmp, label %PhiBB, label %FunnelBB
  // FunnelBB:
  //   %sub = sub i32 32, %ShAmt
  //   %shr = lshr i32 %ShVal1, %sub
  //   %shl = shl i32 %ShVal0, %ShAmt
  //   %fsh = or i32 %shr, %shl
  //   br label %PhiBB
  // PhiBB:
  //   %cond = phi i32 [ %fsh, %FunnelBB ], [ %ShVal0, %GuardBB ]
  // -->
  // llvm.fshl.i32(i32 %ShVal0, i32 %ShVal1, i32 %ShAmt)
  Function *F = Intrinsic::getDeclaration(Phi.getModule(), IID, Phi.getType());
  Phi.replaceAllUsesWith(Builder.CreateCall(F, {ShVal0, ShVal1, ShAmt}));
  return true;
}
    196 
    197 /// This is used by foldAnyOrAllBitsSet() to capture a source value (Root) and
    198 /// the bit indexes (Mask) needed by a masked compare. If we're matching a chain
    199 /// of 'and' ops, then we also need to capture the fact that we saw an
    200 /// "and X, 1", so that's an extra return value for that case.
    201 struct MaskOps {
    202   Value *Root;
    203   APInt Mask;
    204   bool MatchAndChain;
    205   bool FoundAnd1;
    206 
    207   MaskOps(unsigned BitWidth, bool MatchAnds)
    208       : Root(nullptr), Mask(APInt::getNullValue(BitWidth)),
    209         MatchAndChain(MatchAnds), FoundAnd1(false) {}
    210 };
    211 
    212 /// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a
    213 /// chain of 'and' or 'or' instructions looking for shift ops of a common source
    214 /// value. Examples:
    215 ///   or (or (or X, (X >> 3)), (X >> 5)), (X >> 8)
    216 /// returns { X, 0x129 }
    217 ///   and (and (X >> 1), 1), (X >> 4)
    218 /// returns { X, 0x12 }
    219 static bool matchAndOrChain(Value *V, MaskOps &MOps) {
    220   Value *Op0, *Op1;
    221   if (MOps.MatchAndChain) {
    222     // Recurse through a chain of 'and' operands. This requires an extra check
    223     // vs. the 'or' matcher: we must find an "and X, 1" instruction somewhere
    224     // in the chain to know that all of the high bits are cleared.
    225     if (match(V, m_And(m_Value(Op0), m_One()))) {
    226       MOps.FoundAnd1 = true;
    227       return matchAndOrChain(Op0, MOps);
    228     }
    229     if (match(V, m_And(m_Value(Op0), m_Value(Op1))))
    230       return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
    231   } else {
    232     // Recurse through a chain of 'or' operands.
    233     if (match(V, m_Or(m_Value(Op0), m_Value(Op1))))
    234       return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
    235   }
    236 
    237   // We need a shift-right or a bare value representing a compare of bit 0 of
    238   // the original source operand.
    239   Value *Candidate;
    240   const APInt *BitIndex = nullptr;
    241   if (!match(V, m_LShr(m_Value(Candidate), m_APInt(BitIndex))))
    242     Candidate = V;
    243 
    244   // Initialize result source operand.
    245   if (!MOps.Root)
    246     MOps.Root = Candidate;
    247 
    248   // The shift constant is out-of-range? This code hasn't been simplified.
    249   if (BitIndex && BitIndex->uge(MOps.Mask.getBitWidth()))
    250     return false;
    251 
    252   // Fill in the mask bit derived from the shift constant.
    253   MOps.Mask.setBit(BitIndex ? BitIndex->getZExtValue() : 0);
    254   return MOps.Root == Candidate;
    255 }
    256 
    257 /// Match patterns that correspond to "any-bits-set" and "all-bits-set".
    258 /// These will include a chain of 'or' or 'and'-shifted bits from a
    259 /// common source value:
    260 /// and (or  (lshr X, C), ...), 1 --> (X & CMask) != 0
    261 /// and (and (lshr X, C), ...), 1 --> (X & CMask) == CMask
    262 /// Note: "any-bits-clear" and "all-bits-clear" are variations of these patterns
    263 /// that differ only with a final 'not' of the result. We expect that final
    264 /// 'not' to be folded with the compare that we create here (invert predicate).
    265 static bool foldAnyOrAllBitsSet(Instruction &I) {
    266   // The 'any-bits-set' ('or' chain) pattern is simpler to match because the
    267   // final "and X, 1" instruction must be the final op in the sequence.
    268   bool MatchAllBitsSet;
    269   if (match(&I, m_c_And(m_OneUse(m_And(m_Value(), m_Value())), m_Value())))
    270     MatchAllBitsSet = true;
    271   else if (match(&I, m_And(m_OneUse(m_Or(m_Value(), m_Value())), m_One())))
    272     MatchAllBitsSet = false;
    273   else
    274     return false;
    275 
    276   MaskOps MOps(I.getType()->getScalarSizeInBits(), MatchAllBitsSet);
    277   if (MatchAllBitsSet) {
    278     if (!matchAndOrChain(cast<BinaryOperator>(&I), MOps) || !MOps.FoundAnd1)
    279       return false;
    280   } else {
    281     if (!matchAndOrChain(cast<BinaryOperator>(&I)->getOperand(0), MOps))
    282       return false;
    283   }
    284 
    285   // The pattern was found. Create a masked compare that replaces all of the
    286   // shift and logic ops.
    287   IRBuilder<> Builder(&I);
    288   Constant *Mask = ConstantInt::get(I.getType(), MOps.Mask);
    289   Value *And = Builder.CreateAnd(MOps.Root, Mask);
    290   Value *Cmp = MatchAllBitsSet ? Builder.CreateICmpEQ(And, Mask)
    291                                : Builder.CreateIsNotNull(And);
    292   Value *Zext = Builder.CreateZExt(Cmp, I.getType());
    293   I.replaceAllUsesWith(Zext);
    294   ++NumAnyOrAllBitsSet;
    295   return true;
    296 }
    297 
    298 // Try to recognize below function as popcount intrinsic.
    299 // This is the "best" algorithm from
    300 // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
    301 // Also used in TargetLowering::expandCTPOP().
    302 //
    303 // int popcount(unsigned int i) {
    304 //   i = i - ((i >> 1) & 0x55555555);
    305 //   i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
    306 //   i = ((i + (i >> 4)) & 0x0F0F0F0F);
    307 //   return (i * 0x01010101) >> 24;
    308 // }
    309 static bool tryToRecognizePopCount(Instruction &I) {
    310   if (I.getOpcode() != Instruction::LShr)
    311     return false;
    312 
    313   Type *Ty = I.getType();
    314   if (!Ty->isIntOrIntVectorTy())
    315     return false;
    316 
    317   unsigned Len = Ty->getScalarSizeInBits();
    318   // FIXME: fix Len == 8 and other irregular type lengths.
    319   if (!(Len <= 128 && Len > 8 && Len % 8 == 0))
    320     return false;
    321 
    322   APInt Mask55 = APInt::getSplat(Len, APInt(8, 0x55));
    323   APInt Mask33 = APInt::getSplat(Len, APInt(8, 0x33));
    324   APInt Mask0F = APInt::getSplat(Len, APInt(8, 0x0F));
    325   APInt Mask01 = APInt::getSplat(Len, APInt(8, 0x01));
    326   APInt MaskShift = APInt(Len, Len - 8);
    327 
    328   Value *Op0 = I.getOperand(0);
    329   Value *Op1 = I.getOperand(1);
    330   Value *MulOp0;
    331   // Matching "(i * 0x01010101...) >> 24".
    332   if ((match(Op0, m_Mul(m_Value(MulOp0), m_SpecificInt(Mask01)))) &&
    333        match(Op1, m_SpecificInt(MaskShift))) {
    334     Value *ShiftOp0;
    335     // Matching "((i + (i >> 4)) & 0x0F0F0F0F...)".
    336     if (match(MulOp0, m_And(m_c_Add(m_LShr(m_Value(ShiftOp0), m_SpecificInt(4)),
    337                                     m_Deferred(ShiftOp0)),
    338                             m_SpecificInt(Mask0F)))) {
    339       Value *AndOp0;
    340       // Matching "(i & 0x33333333...) + ((i >> 2) & 0x33333333...)".
    341       if (match(ShiftOp0,
    342                 m_c_Add(m_And(m_Value(AndOp0), m_SpecificInt(Mask33)),
    343                         m_And(m_LShr(m_Deferred(AndOp0), m_SpecificInt(2)),
    344                               m_SpecificInt(Mask33))))) {
    345         Value *Root, *SubOp1;
    346         // Matching "i - ((i >> 1) & 0x55555555...)".
    347         if (match(AndOp0, m_Sub(m_Value(Root), m_Value(SubOp1))) &&
    348             match(SubOp1, m_And(m_LShr(m_Specific(Root), m_SpecificInt(1)),
    349                                 m_SpecificInt(Mask55)))) {
    350           LLVM_DEBUG(dbgs() << "Recognized popcount intrinsic\n");
    351           IRBuilder<> Builder(&I);
    352           Function *Func = Intrinsic::getDeclaration(
    353               I.getModule(), Intrinsic::ctpop, I.getType());
    354           I.replaceAllUsesWith(Builder.CreateCall(Func, {Root}));
    355           ++NumPopCountRecognized;
    356           return true;
    357         }
    358       }
    359     }
    360   }
    361 
    362   return false;
    363 }
    364 
    365 /// This is the entry point for folds that could be implemented in regular
    366 /// InstCombine, but they are separated because they are not expected to
    367 /// occur frequently and/or have more than a constant-length pattern match.
    368 static bool foldUnusualPatterns(Function &F, DominatorTree &DT) {
    369   bool MadeChange = false;
    370   for (BasicBlock &BB : F) {
    371     // Ignore unreachable basic blocks.
    372     if (!DT.isReachableFromEntry(&BB))
    373       continue;
    374     // Do not delete instructions under here and invalidate the iterator.
    375     // Walk the block backwards for efficiency. We're matching a chain of
    376     // use->defs, so we're more likely to succeed by starting from the bottom.
    377     // Also, we want to avoid matching partial patterns.
    378     // TODO: It would be more efficient if we removed dead instructions
    379     // iteratively in this loop rather than waiting until the end.
    380     for (Instruction &I : make_range(BB.rbegin(), BB.rend())) {
    381       MadeChange |= foldAnyOrAllBitsSet(I);
    382       MadeChange |= foldGuardedFunnelShift(I, DT);
    383       MadeChange |= tryToRecognizePopCount(I);
    384     }
    385   }
    386 
    387   // We're done with transforms, so remove dead instructions.
    388   if (MadeChange)
    389     for (BasicBlock &BB : F)
    390       SimplifyInstructionsInBlock(&BB);
    391 
    392   return MadeChange;
    393 }
    394 
    395 /// This is the entry point for all transforms. Pass manager differences are
    396 /// handled in the callers of this function.
    397 static bool runImpl(Function &F, TargetLibraryInfo &TLI, DominatorTree &DT) {
    398   bool MadeChange = false;
    399   const DataLayout &DL = F.getParent()->getDataLayout();
    400   TruncInstCombine TIC(TLI, DL, DT);
    401   MadeChange |= TIC.run(F);
    402   MadeChange |= foldUnusualPatterns(F, DT);
    403   return MadeChange;
    404 }
    405 
    406 void AggressiveInstCombinerLegacyPass::getAnalysisUsage(
    407     AnalysisUsage &AU) const {
    408   AU.setPreservesCFG();
    409   AU.addRequired<DominatorTreeWrapperPass>();
    410   AU.addRequired<TargetLibraryInfoWrapperPass>();
    411   AU.addPreserved<AAResultsWrapperPass>();
    412   AU.addPreserved<BasicAAWrapperPass>();
    413   AU.addPreserved<DominatorTreeWrapperPass>();
    414   AU.addPreserved<GlobalsAAWrapperPass>();
    415 }
    416 
    417 bool AggressiveInstCombinerLegacyPass::runOnFunction(Function &F) {
    418   auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    419   auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    420   return runImpl(F, TLI, DT);
    421 }
    422 
    423 PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
    424                                                  FunctionAnalysisManager &AM) {
    425   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
    426   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
    427   if (!runImpl(F, TLI, DT)) {
    428     // No changes, all analyses are preserved.
    429     return PreservedAnalyses::all();
    430   }
    431   // Mark all the analyses that instcombine updates as preserved.
    432   PreservedAnalyses PA;
    433   PA.preserveSet<CFGAnalyses>();
    434   return PA;
    435 }
    436 
// Legacy pass registration. The pass is exposed under the command-line name
// "aggressive-instcombine", and the dominator tree and target library info
// passes are declared as dependencies to be initialized first. The trailing
// `false, false` arguments mark it as neither CFG-only nor an analysis pass.
char AggressiveInstCombinerLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(AggressiveInstCombinerLegacyPass,
                      "aggressive-instcombine",
                      "Combine pattern based expressions", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(AggressiveInstCombinerLegacyPass, "aggressive-instcombine",
                    "Combine pattern based expressions", false, false)
    445 
    446 // Initialization Routines
    447 void llvm::initializeAggressiveInstCombine(PassRegistry &Registry) {
    448   initializeAggressiveInstCombinerLegacyPassPass(Registry);
    449 }
    450 
    451 void LLVMInitializeAggressiveInstCombiner(LLVMPassRegistryRef R) {
    452   initializeAggressiveInstCombinerLegacyPassPass(*unwrap(R));
    453 }
    454 
    455 FunctionPass *llvm::createAggressiveInstCombinerPass() {
    456   return new AggressiveInstCombinerLegacyPass();
    457 }
    458 
    459 void LLVMAddAggressiveInstCombinerPass(LLVMPassManagerRef PM) {
    460   unwrap(PM)->add(createAggressiveInstCombinerPass());
    461 }
    462