Home | History | Annotate | Line # | Download | only in CodeGen
      1 //===--- ExpandReductions.cpp - Expand experimental reduction intrinsics --===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // This pass implements IR expansion for reduction intrinsics, allowing targets
     10 // to enable the intrinsics until just before codegen.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #include "llvm/CodeGen/ExpandReductions.h"
     15 #include "llvm/Analysis/TargetTransformInfo.h"
     16 #include "llvm/CodeGen/Passes.h"
     17 #include "llvm/IR/Function.h"
     18 #include "llvm/IR/IRBuilder.h"
     19 #include "llvm/IR/InstIterator.h"
     20 #include "llvm/IR/IntrinsicInst.h"
     21 #include "llvm/IR/Intrinsics.h"
     22 #include "llvm/IR/Module.h"
     23 #include "llvm/InitializePasses.h"
     24 #include "llvm/Pass.h"
     25 #include "llvm/Transforms/Utils/LoopUtils.h"
     26 
     27 using namespace llvm;
     28 
     29 namespace {
     30 
     31 unsigned getOpcode(Intrinsic::ID ID) {
     32   switch (ID) {
     33   case Intrinsic::vector_reduce_fadd:
     34     return Instruction::FAdd;
     35   case Intrinsic::vector_reduce_fmul:
     36     return Instruction::FMul;
     37   case Intrinsic::vector_reduce_add:
     38     return Instruction::Add;
     39   case Intrinsic::vector_reduce_mul:
     40     return Instruction::Mul;
     41   case Intrinsic::vector_reduce_and:
     42     return Instruction::And;
     43   case Intrinsic::vector_reduce_or:
     44     return Instruction::Or;
     45   case Intrinsic::vector_reduce_xor:
     46     return Instruction::Xor;
     47   case Intrinsic::vector_reduce_smax:
     48   case Intrinsic::vector_reduce_smin:
     49   case Intrinsic::vector_reduce_umax:
     50   case Intrinsic::vector_reduce_umin:
     51     return Instruction::ICmp;
     52   case Intrinsic::vector_reduce_fmax:
     53   case Intrinsic::vector_reduce_fmin:
     54     return Instruction::FCmp;
     55   default:
     56     llvm_unreachable("Unexpected ID");
     57   }
     58 }
     59 
     60 RecurKind getRK(Intrinsic::ID ID) {
     61   switch (ID) {
     62   case Intrinsic::vector_reduce_smax:
     63     return RecurKind::SMax;
     64   case Intrinsic::vector_reduce_smin:
     65     return RecurKind::SMin;
     66   case Intrinsic::vector_reduce_umax:
     67     return RecurKind::UMax;
     68   case Intrinsic::vector_reduce_umin:
     69     return RecurKind::UMin;
     70   case Intrinsic::vector_reduce_fmax:
     71     return RecurKind::FMax;
     72   case Intrinsic::vector_reduce_fmin:
     73     return RecurKind::FMin;
     74   default:
     75     return RecurKind::None;
     76   }
     77 }
     78 
     79 bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
     80   bool Changed = false;
     81   SmallVector<IntrinsicInst *, 4> Worklist;
     82   for (auto &I : instructions(F)) {
     83     if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
     84       switch (II->getIntrinsicID()) {
     85       default: break;
     86       case Intrinsic::vector_reduce_fadd:
     87       case Intrinsic::vector_reduce_fmul:
     88       case Intrinsic::vector_reduce_add:
     89       case Intrinsic::vector_reduce_mul:
     90       case Intrinsic::vector_reduce_and:
     91       case Intrinsic::vector_reduce_or:
     92       case Intrinsic::vector_reduce_xor:
     93       case Intrinsic::vector_reduce_smax:
     94       case Intrinsic::vector_reduce_smin:
     95       case Intrinsic::vector_reduce_umax:
     96       case Intrinsic::vector_reduce_umin:
     97       case Intrinsic::vector_reduce_fmax:
     98       case Intrinsic::vector_reduce_fmin:
     99         if (TTI->shouldExpandReduction(II))
    100           Worklist.push_back(II);
    101 
    102         break;
    103       }
    104     }
    105   }
    106 
    107   for (auto *II : Worklist) {
    108     FastMathFlags FMF =
    109         isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
    110     Intrinsic::ID ID = II->getIntrinsicID();
    111     RecurKind RK = getRK(ID);
    112 
    113     Value *Rdx = nullptr;
    114     IRBuilder<> Builder(II);
    115     IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
    116     Builder.setFastMathFlags(FMF);
    117     switch (ID) {
    118     default: llvm_unreachable("Unexpected intrinsic!");
    119     case Intrinsic::vector_reduce_fadd:
    120     case Intrinsic::vector_reduce_fmul: {
    121       // FMFs must be attached to the call, otherwise it's an ordered reduction
    122       // and it can't be handled by generating a shuffle sequence.
    123       Value *Acc = II->getArgOperand(0);
    124       Value *Vec = II->getArgOperand(1);
    125       if (!FMF.allowReassoc())
    126         Rdx = getOrderedReduction(Builder, Acc, Vec, getOpcode(ID), RK);
    127       else {
    128         if (!isPowerOf2_32(
    129                 cast<FixedVectorType>(Vec->getType())->getNumElements()))
    130           continue;
    131 
    132         Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
    133         Rdx = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(ID),
    134                                   Acc, Rdx, "bin.rdx");
    135       }
    136       break;
    137     }
    138     case Intrinsic::vector_reduce_add:
    139     case Intrinsic::vector_reduce_mul:
    140     case Intrinsic::vector_reduce_and:
    141     case Intrinsic::vector_reduce_or:
    142     case Intrinsic::vector_reduce_xor:
    143     case Intrinsic::vector_reduce_smax:
    144     case Intrinsic::vector_reduce_smin:
    145     case Intrinsic::vector_reduce_umax:
    146     case Intrinsic::vector_reduce_umin: {
    147       Value *Vec = II->getArgOperand(0);
    148       if (!isPowerOf2_32(
    149               cast<FixedVectorType>(Vec->getType())->getNumElements()))
    150         continue;
    151 
    152       Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
    153       break;
    154     }
    155     case Intrinsic::vector_reduce_fmax:
    156     case Intrinsic::vector_reduce_fmin: {
    157       // We require "nnan" to use a shuffle reduction; "nsz" is implied by the
    158       // semantics of the reduction.
    159       Value *Vec = II->getArgOperand(0);
    160       if (!isPowerOf2_32(
    161               cast<FixedVectorType>(Vec->getType())->getNumElements()) ||
    162           !FMF.noNaNs())
    163         continue;
    164 
    165       Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
    166       break;
    167     }
    168     }
    169     II->replaceAllUsesWith(Rdx);
    170     II->eraseFromParent();
    171     Changed = true;
    172   }
    173   return Changed;
    174 }
    175 
    176 class ExpandReductions : public FunctionPass {
    177 public:
    178   static char ID;
    179   ExpandReductions() : FunctionPass(ID) {
    180     initializeExpandReductionsPass(*PassRegistry::getPassRegistry());
    181   }
    182 
    183   bool runOnFunction(Function &F) override {
    184     const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    185     return expandReductions(F, TTI);
    186   }
    187 
    188   void getAnalysisUsage(AnalysisUsage &AU) const override {
    189     AU.addRequired<TargetTransformInfoWrapperPass>();
    190     AU.setPreservesCFG();
    191   }
    192 };
    193 }
    194 
    195 char ExpandReductions::ID;
    196 INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions",
    197                       "Expand reduction intrinsics", false, false)
    198 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
    199 INITIALIZE_PASS_END(ExpandReductions, "expand-reductions",
    200                     "Expand reduction intrinsics", false, false)
    201 
    202 FunctionPass *llvm::createExpandReductionsPass() {
    203   return new ExpandReductions();
    204 }
    205 
    206 PreservedAnalyses ExpandReductionsPass::run(Function &F,
    207                                             FunctionAnalysisManager &AM) {
    208   const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
    209   if (!expandReductions(F, &TTI))
    210     return PreservedAnalyses::all();
    211   PreservedAnalyses PA;
    212   PA.preserveSet<CFGAnalyses>();
    213   return PA;
    214 }
    215