Home | History | Annotate | Line # | Download | only in Scalar
      1 //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
      2 //                                    intrinsics
      3 //
      4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      5 // See https://llvm.org/LICENSE.txt for license information.
      6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // This pass replaces masked memory intrinsics - when unsupported by the target
     11 // - with a chain of basic blocks, that deal with the elements one-by-one if the
     12 // appropriate mask bit is set.
     13 //
     14 //===----------------------------------------------------------------------===//
     15 
     16 #include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
     17 #include "llvm/ADT/Twine.h"
     18 #include "llvm/Analysis/DomTreeUpdater.h"
     19 #include "llvm/Analysis/TargetTransformInfo.h"
     20 #include "llvm/IR/BasicBlock.h"
     21 #include "llvm/IR/Constant.h"
     22 #include "llvm/IR/Constants.h"
     23 #include "llvm/IR/DerivedTypes.h"
     24 #include "llvm/IR/Dominators.h"
     25 #include "llvm/IR/Function.h"
     26 #include "llvm/IR/IRBuilder.h"
     27 #include "llvm/IR/InstrTypes.h"
     28 #include "llvm/IR/Instruction.h"
     29 #include "llvm/IR/Instructions.h"
     30 #include "llvm/IR/IntrinsicInst.h"
     31 #include "llvm/IR/Intrinsics.h"
     32 #include "llvm/IR/Type.h"
     33 #include "llvm/IR/Value.h"
     34 #include "llvm/InitializePasses.h"
     35 #include "llvm/Pass.h"
     36 #include "llvm/Support/Casting.h"
     37 #include "llvm/Transforms/Scalar.h"
     38 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
     39 #include <algorithm>
     40 #include <cassert>
     41 
     42 using namespace llvm;
     43 
     44 #define DEBUG_TYPE "scalarize-masked-mem-intrin"
     45 
     46 namespace {
     47 
     48 class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
     49 public:
     50   static char ID; // Pass identification, replacement for typeid
     51 
     52   explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
     53     initializeScalarizeMaskedMemIntrinLegacyPassPass(
     54         *PassRegistry::getPassRegistry());
     55   }
     56 
     57   bool runOnFunction(Function &F) override;
     58 
     59   StringRef getPassName() const override {
     60     return "Scalarize Masked Memory Intrinsics";
     61   }
     62 
     63   void getAnalysisUsage(AnalysisUsage &AU) const override {
     64     AU.addRequired<TargetTransformInfoWrapperPass>();
     65     AU.addPreserved<DominatorTreeWrapperPass>();
     66   }
     67 };
     68 
     69 } // end anonymous namespace
     70 
     71 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
     72                           const TargetTransformInfo &TTI, const DataLayout &DL,
     73                           DomTreeUpdater *DTU);
     74 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
     75                              const TargetTransformInfo &TTI,
     76                              const DataLayout &DL, DomTreeUpdater *DTU);
     77 
     78 char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
     79 
     80 INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
     81                       "Scalarize unsupported masked memory intrinsics", false,
     82                       false)
     83 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
     84 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
     85 INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
     86                     "Scalarize unsupported masked memory intrinsics", false,
     87                     false)
     88 
     89 FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
     90   return new ScalarizeMaskedMemIntrinLegacyPass();
     91 }
     92 
     93 static bool isConstantIntVector(Value *Mask) {
     94   Constant *C = dyn_cast<Constant>(Mask);
     95   if (!C)
     96     return false;
     97 
     98   unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
     99   for (unsigned i = 0; i != NumElts; ++i) {
    100     Constant *CElt = C->getAggregateElement(i);
    101     if (!CElt || !isa<ConstantInt>(CElt))
    102       return false;
    103   }
    104 
    105   return true;
    106 }
    107 
    108 static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
    109                                 unsigned Idx) {
    110   return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
    111 }
    112 
    113 // Translate a masked load intrinsic like
    114 // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
    115 //                               <16 x i1> %mask, <16 x i32> %passthru)
    116 // to a chain of basic blocks, with loading element one-by-one if
    117 // the appropriate mask bit is set
    118 //
    119 //  %1 = bitcast i8* %addr to i32*
    120 //  %2 = extractelement <16 x i1> %mask, i32 0
    121 //  br i1 %2, label %cond.load, label %else
    122 //
    123 // cond.load:                                        ; preds = %0
    124 //  %3 = getelementptr i32* %1, i32 0
    125 //  %4 = load i32* %3
    126 //  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
    127 //  br label %else
    128 //
    129 // else:                                             ; preds = %0, %cond.load
    130 //  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ %passthru, %0 ]
    131 //  %6 = extractelement <16 x i1> %mask, i32 1
    132 //  br i1 %6, label %cond.load1, label %else2
    133 //
    134 // cond.load1:                                       ; preds = %else
    135 //  %7 = getelementptr i32* %1, i32 1
    136 //  %8 = load i32* %7
    137 //  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
    138 //  br label %else2
    139 //
    140 // else2:                                          ; preds = %else, %cond.load1
    141 //  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
    142 //  %10 = extractelement <16 x i1> %mask, i32 2
    143 //  br i1 %10, label %cond.load4, label %else5
    144 //
    145 static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
    146                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
    147   Value *Ptr = CI->getArgOperand(0);
    148   Value *Alignment = CI->getArgOperand(1);
    149   Value *Mask = CI->getArgOperand(2);
    150   Value *Src0 = CI->getArgOperand(3);
    151 
    152   const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
    153   VectorType *VecType = cast<FixedVectorType>(CI->getType());
    154 
    155   Type *EltTy = VecType->getElementType();
    156 
    157   IRBuilder<> Builder(CI->getContext());
    158   Instruction *InsertPt = CI;
    159   BasicBlock *IfBlock = CI->getParent();
    160 
    161   Builder.SetInsertPoint(InsertPt);
    162   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
    163 
    164   // Short-cut if the mask is all-true.
    165   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    166     Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    167     CI->replaceAllUsesWith(NewI);
    168     CI->eraseFromParent();
    169     return;
    170   }
    171 
    172   // Adjust alignment for the scalar instruction.
    173   const Align AdjustedAlignVal =
    174       commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
    175   // Bitcast %addr from i8* to EltTy*
    176   Type *NewPtrType =
    177       EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
    178   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
    179   unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
    180 
    181   // The result vector
    182   Value *VResult = Src0;
    183 
    184   if (isConstantIntVector(Mask)) {
    185     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    186       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
    187         continue;
    188       Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    189       LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    190       VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    191     }
    192     CI->replaceAllUsesWith(VResult);
    193     CI->eraseFromParent();
    194     return;
    195   }
    196 
    197   // If the mask is not v1i1, use scalar bit test operations. This generates
    198   // better results on X86 at least.
    199   Value *SclrMask;
    200   if (VectorWidth != 1) {
    201     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    202     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
    203   }
    204 
    205   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    206     // Fill the "else" block, created in the previous iteration
    207     //
    208     //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    209     //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    210     //  %cond = icmp ne i16 %mask_1, 0
    211     //  br i1 %mask_1, label %cond.load, label %else
    212     //
    213     Value *Predicate;
    214     if (VectorWidth != 1) {
    215       Value *Mask = Builder.getInt(APInt::getOneBitSet(
    216           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
    217       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
    218                                        Builder.getIntN(VectorWidth, 0));
    219     } else {
    220       Predicate = Builder.CreateExtractElement(Mask, Idx);
    221     }
    222 
    223     // Create "cond" block
    224     //
    225     //  %EltAddr = getelementptr i32* %1, i32 0
    226     //  %Elt = load i32* %EltAddr
    227     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    228     //
    229     Instruction *ThenTerm =
    230         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
    231                                   /*BranchWeights=*/nullptr, DTU);
    232 
    233     BasicBlock *CondBlock = ThenTerm->getParent();
    234     CondBlock->setName("cond.load");
    235 
    236     Builder.SetInsertPoint(CondBlock->getTerminator());
    237     Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    238     LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    239     Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
    240 
    241     // Create "else" block, fill it in the next iteration
    242     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    243     NewIfBlock->setName("else");
    244     BasicBlock *PrevIfBlock = IfBlock;
    245     IfBlock = NewIfBlock;
    246 
    247     // Create the phi to join the new and previous value.
    248     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    249     PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    250     Phi->addIncoming(NewVResult, CondBlock);
    251     Phi->addIncoming(VResult, PrevIfBlock);
    252     VResult = Phi;
    253   }
    254 
    255   CI->replaceAllUsesWith(VResult);
    256   CI->eraseFromParent();
    257 
    258   ModifiedDT = true;
    259 }
    260 
    261 // Translate a masked store intrinsic, like
    262 // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
    263 //                               <16 x i1> %mask)
    264 // to a chain of basic blocks, that stores element one-by-one if
    265 // the appropriate mask bit is set
    266 //
    267 //   %1 = bitcast i8* %addr to i32*
    268 //   %2 = extractelement <16 x i1> %mask, i32 0
    269 //   br i1 %2, label %cond.store, label %else
    270 //
    271 // cond.store:                                       ; preds = %0
    272 //   %3 = extractelement <16 x i32> %val, i32 0
    273 //   %4 = getelementptr i32* %1, i32 0
    274 //   store i32 %3, i32* %4
    275 //   br label %else
    276 //
    277 // else:                                             ; preds = %0, %cond.store
    278 //   %5 = extractelement <16 x i1> %mask, i32 1
    279 //   br i1 %5, label %cond.store1, label %else2
    280 //
    281 // cond.store1:                                      ; preds = %else
    282 //   %6 = extractelement <16 x i32> %val, i32 1
    283 //   %7 = getelementptr i32* %1, i32 1
    284 //   store i32 %6, i32* %7
    285 //   br label %else2
    286 //   . . .
    287 static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
    288                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
    289   Value *Src = CI->getArgOperand(0);
    290   Value *Ptr = CI->getArgOperand(1);
    291   Value *Alignment = CI->getArgOperand(2);
    292   Value *Mask = CI->getArgOperand(3);
    293 
    294   const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
    295   auto *VecType = cast<VectorType>(Src->getType());
    296 
    297   Type *EltTy = VecType->getElementType();
    298 
    299   IRBuilder<> Builder(CI->getContext());
    300   Instruction *InsertPt = CI;
    301   Builder.SetInsertPoint(InsertPt);
    302   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
    303 
    304   // Short-cut if the mask is all-true.
    305   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    306     Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    307     CI->eraseFromParent();
    308     return;
    309   }
    310 
    311   // Adjust alignment for the scalar instruction.
    312   const Align AdjustedAlignVal =
    313       commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
    314   // Bitcast %addr from i8* to EltTy*
    315   Type *NewPtrType =
    316       EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
    317   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
    318   unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
    319 
    320   if (isConstantIntVector(Mask)) {
    321     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    322       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
    323         continue;
    324       Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    325       Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    326       Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    327     }
    328     CI->eraseFromParent();
    329     return;
    330   }
    331 
    332   // If the mask is not v1i1, use scalar bit test operations. This generates
    333   // better results on X86 at least.
    334   Value *SclrMask;
    335   if (VectorWidth != 1) {
    336     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    337     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
    338   }
    339 
    340   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    341     // Fill the "else" block, created in the previous iteration
    342     //
    343     //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    344     //  %cond = icmp ne i16 %mask_1, 0
    345     //  br i1 %mask_1, label %cond.store, label %else
    346     //
    347     Value *Predicate;
    348     if (VectorWidth != 1) {
    349       Value *Mask = Builder.getInt(APInt::getOneBitSet(
    350           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
    351       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
    352                                        Builder.getIntN(VectorWidth, 0));
    353     } else {
    354       Predicate = Builder.CreateExtractElement(Mask, Idx);
    355     }
    356 
    357     // Create "cond" block
    358     //
    359     //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    360     //  %EltAddr = getelementptr i32* %1, i32 0
    361     //  %store i32 %OneElt, i32* %EltAddr
    362     //
    363     Instruction *ThenTerm =
    364         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
    365                                   /*BranchWeights=*/nullptr, DTU);
    366 
    367     BasicBlock *CondBlock = ThenTerm->getParent();
    368     CondBlock->setName("cond.store");
    369 
    370     Builder.SetInsertPoint(CondBlock->getTerminator());
    371     Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    372     Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    373     Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    374 
    375     // Create "else" block, fill it in the next iteration
    376     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    377     NewIfBlock->setName("else");
    378 
    379     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    380   }
    381   CI->eraseFromParent();
    382 
    383   ModifiedDT = true;
    384 }
    385 
    386 // Translate a masked gather intrinsic like
    387 // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
    388 //                               <16 x i1> %Mask, <16 x i32> %Src)
    389 // to a chain of basic blocks, with loading element one-by-one if
    390 // the appropriate mask bit is set
    391 //
    392 // %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
    393 // %Mask0 = extractelement <16 x i1> %Mask, i32 0
    394 // br i1 %Mask0, label %cond.load, label %else
    395 //
    396 // cond.load:
    397 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
    398 // %Load0 = load i32, i32* %Ptr0, align 4
    399 // %Res0 = insertelement <16 x i32> %Src, i32 %Load0, i32 0
    400 // br label %else
    401 //
    402 // else:
    403 // %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [%Src, %0]
    404 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
    405 // br i1 %Mask1, label %cond.load1, label %else2
    406 //
    407 // cond.load1:
    408 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    409 // %Load1 = load i32, i32* %Ptr1, align 4
    410 // %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
    411 // br label %else2
    412 // . . .
    413 // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
    414 // ret <16 x i32> %Result
    415 static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
    416                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
    417   Value *Ptrs = CI->getArgOperand(0);
    418   Value *Alignment = CI->getArgOperand(1);
    419   Value *Mask = CI->getArgOperand(2);
    420   Value *Src0 = CI->getArgOperand(3);
    421 
    422   auto *VecType = cast<FixedVectorType>(CI->getType());
    423   Type *EltTy = VecType->getElementType();
    424 
    425   IRBuilder<> Builder(CI->getContext());
    426   Instruction *InsertPt = CI;
    427   BasicBlock *IfBlock = CI->getParent();
    428   Builder.SetInsertPoint(InsertPt);
    429   MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
    430 
    431   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
    432 
    433   // The result vector
    434   Value *VResult = Src0;
    435   unsigned VectorWidth = VecType->getNumElements();
    436 
    437   // Shorten the way if the mask is a vector of constants.
    438   if (isConstantIntVector(Mask)) {
    439     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    440       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
    441         continue;
    442       Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    443       LoadInst *Load =
    444           Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    445       VResult =
    446           Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    447     }
    448     CI->replaceAllUsesWith(VResult);
    449     CI->eraseFromParent();
    450     return;
    451   }
    452 
    453   // If the mask is not v1i1, use scalar bit test operations. This generates
    454   // better results on X86 at least.
    455   Value *SclrMask;
    456   if (VectorWidth != 1) {
    457     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    458     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
    459   }
    460 
    461   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    462     // Fill the "else" block, created in the previous iteration
    463     //
    464     //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    465     //  %cond = icmp ne i16 %mask_1, 0
    466     //  br i1 %Mask1, label %cond.load, label %else
    467     //
    468 
    469     Value *Predicate;
    470     if (VectorWidth != 1) {
    471       Value *Mask = Builder.getInt(APInt::getOneBitSet(
    472           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
    473       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
    474                                        Builder.getIntN(VectorWidth, 0));
    475     } else {
    476       Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    477     }
    478 
    479     // Create "cond" block
    480     //
    481     //  %EltAddr = getelementptr i32* %1, i32 0
    482     //  %Elt = load i32* %EltAddr
    483     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    484     //
    485     Instruction *ThenTerm =
    486         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
    487                                   /*BranchWeights=*/nullptr, DTU);
    488 
    489     BasicBlock *CondBlock = ThenTerm->getParent();
    490     CondBlock->setName("cond.load");
    491 
    492     Builder.SetInsertPoint(CondBlock->getTerminator());
    493     Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    494     LoadInst *Load =
    495         Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    496     Value *NewVResult =
    497         Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    498 
    499     // Create "else" block, fill it in the next iteration
    500     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    501     NewIfBlock->setName("else");
    502     BasicBlock *PrevIfBlock = IfBlock;
    503     IfBlock = NewIfBlock;
    504 
    505     // Create the phi to join the new and previous value.
    506     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    507     PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    508     Phi->addIncoming(NewVResult, CondBlock);
    509     Phi->addIncoming(VResult, PrevIfBlock);
    510     VResult = Phi;
    511   }
    512 
    513   CI->replaceAllUsesWith(VResult);
    514   CI->eraseFromParent();
    515 
    516   ModifiedDT = true;
    517 }
    518 
    519 // Translate a masked scatter intrinsic, like
    520 // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
    521 //                                  <16 x i1> %Mask)
    522 // to a chain of basic blocks, that stores element one-by-one if
    523 // the appropriate mask bit is set.
    524 //
    525 // %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
    526 // %Mask0 = extractelement <16 x i1> %Mask, i32 0
    527 // br i1 %Mask0, label %cond.store, label %else
    528 //
    529 // cond.store:
    530 // %Elt0 = extractelement <16 x i32> %Src, i32 0
    531 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
    532 // store i32 %Elt0, i32* %Ptr0, align 4
    533 // br label %else
    534 //
    535 // else:
    536 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
    537 // br i1 %Mask1, label %cond.store1, label %else2
    538 //
    539 // cond.store1:
    540 // %Elt1 = extractelement <16 x i32> %Src, i32 1
    541 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    542 // store i32 %Elt1, i32* %Ptr1, align 4
    543 // br label %else2
    544 //   . . .
    545 static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
    546                                    DomTreeUpdater *DTU, bool &ModifiedDT) {
    547   Value *Src = CI->getArgOperand(0);
    548   Value *Ptrs = CI->getArgOperand(1);
    549   Value *Alignment = CI->getArgOperand(2);
    550   Value *Mask = CI->getArgOperand(3);
    551 
    552   auto *SrcFVTy = cast<FixedVectorType>(Src->getType());
    553 
    554   assert(
    555       isa<VectorType>(Ptrs->getType()) &&
    556       isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
    557       "Vector of pointers is expected in masked scatter intrinsic");
    558 
    559   IRBuilder<> Builder(CI->getContext());
    560   Instruction *InsertPt = CI;
    561   Builder.SetInsertPoint(InsertPt);
    562   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
    563 
    564   MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
    565   unsigned VectorWidth = SrcFVTy->getNumElements();
    566 
    567   // Shorten the way if the mask is a vector of constants.
    568   if (isConstantIntVector(Mask)) {
    569     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    570       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
    571         continue;
    572       Value *OneElt =
    573           Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    574       Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    575       Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    576     }
    577     CI->eraseFromParent();
    578     return;
    579   }
    580 
    581   // If the mask is not v1i1, use scalar bit test operations. This generates
    582   // better results on X86 at least.
    583   Value *SclrMask;
    584   if (VectorWidth != 1) {
    585     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    586     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
    587   }
    588 
    589   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    590     // Fill the "else" block, created in the previous iteration
    591     //
    592     //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    593     //  %cond = icmp ne i16 %mask_1, 0
    594     //  br i1 %Mask1, label %cond.store, label %else
    595     //
    596     Value *Predicate;
    597     if (VectorWidth != 1) {
    598       Value *Mask = Builder.getInt(APInt::getOneBitSet(
    599           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
    600       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
    601                                        Builder.getIntN(VectorWidth, 0));
    602     } else {
    603       Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    604     }
    605 
    606     // Create "cond" block
    607     //
    608     //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    609     //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    610     //  %store i32 %Elt1, i32* %Ptr1
    611     //
    612     Instruction *ThenTerm =
    613         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
    614                                   /*BranchWeights=*/nullptr, DTU);
    615 
    616     BasicBlock *CondBlock = ThenTerm->getParent();
    617     CondBlock->setName("cond.store");
    618 
    619     Builder.SetInsertPoint(CondBlock->getTerminator());
    620     Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    621     Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    622     Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    623 
    624     // Create "else" block, fill it in the next iteration
    625     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    626     NewIfBlock->setName("else");
    627 
    628     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    629   }
    630   CI->eraseFromParent();
    631 
    632   ModifiedDT = true;
    633 }
    634 
    635 static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
    636                                       DomTreeUpdater *DTU, bool &ModifiedDT) {
    637   Value *Ptr = CI->getArgOperand(0);
    638   Value *Mask = CI->getArgOperand(1);
    639   Value *PassThru = CI->getArgOperand(2);
    640 
    641   auto *VecType = cast<FixedVectorType>(CI->getType());
    642 
    643   Type *EltTy = VecType->getElementType();
    644 
    645   IRBuilder<> Builder(CI->getContext());
    646   Instruction *InsertPt = CI;
    647   BasicBlock *IfBlock = CI->getParent();
    648 
    649   Builder.SetInsertPoint(InsertPt);
    650   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
    651 
    652   unsigned VectorWidth = VecType->getNumElements();
    653 
    654   // The result vector
    655   Value *VResult = PassThru;
    656 
    657   // Shorten the way if the mask is a vector of constants.
    658   // Create a build_vector pattern, with loads/undefs as necessary and then
    659   // shuffle blend with the pass through value.
    660   if (isConstantIntVector(Mask)) {
    661     unsigned MemIndex = 0;
    662     VResult = UndefValue::get(VecType);
    663     SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
    664     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    665       Value *InsertElt;
    666       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
    667         InsertElt = UndefValue::get(EltTy);
    668         ShuffleMask[Idx] = Idx + VectorWidth;
    669       } else {
    670         Value *NewPtr =
    671             Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
    672         InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
    673                                               "Load" + Twine(Idx));
    674         ShuffleMask[Idx] = Idx;
    675         ++MemIndex;
    676       }
    677       VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
    678                                             "Res" + Twine(Idx));
    679     }
    680     VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    681     CI->replaceAllUsesWith(VResult);
    682     CI->eraseFromParent();
    683     return;
    684   }
    685 
    686   // If the mask is not v1i1, use scalar bit test operations. This generates
    687   // better results on X86 at least.
    688   Value *SclrMask;
    689   if (VectorWidth != 1) {
    690     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    691     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
    692   }
    693 
    694   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    695     // Fill the "else" block, created in the previous iteration
    696     //
    697     //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    698     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    699     //  br i1 %mask_1, label %cond.load, label %else
    700     //
    701 
    702     Value *Predicate;
    703     if (VectorWidth != 1) {
    704       Value *Mask = Builder.getInt(APInt::getOneBitSet(
    705           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
    706       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
    707                                        Builder.getIntN(VectorWidth, 0));
    708     } else {
    709       Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    710     }
    711 
    712     // Create "cond" block
    713     //
    714     //  %EltAddr = getelementptr i32* %1, i32 0
    715     //  %Elt = load i32* %EltAddr
    716     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    717     //
    718     Instruction *ThenTerm =
    719         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
    720                                   /*BranchWeights=*/nullptr, DTU);
    721 
    722     BasicBlock *CondBlock = ThenTerm->getParent();
    723     CondBlock->setName("cond.load");
    724 
    725     Builder.SetInsertPoint(CondBlock->getTerminator());
    726     LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
    727     Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
    728 
    729     // Move the pointer if there are more blocks to come.
    730     Value *NewPtr;
    731     if ((Idx + 1) != VectorWidth)
    732       NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
    733 
    734     // Create "else" block, fill it in the next iteration
    735     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    736     NewIfBlock->setName("else");
    737     BasicBlock *PrevIfBlock = IfBlock;
    738     IfBlock = NewIfBlock;
    739 
    740     // Create the phi to join the new and previous value.
    741     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    742     PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    743     ResultPhi->addIncoming(NewVResult, CondBlock);
    744     ResultPhi->addIncoming(VResult, PrevIfBlock);
    745     VResult = ResultPhi;
    746 
    747     // Add a PHI for the pointer if this isn't the last iteration.
    748     if ((Idx + 1) != VectorWidth) {
    749       PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
    750       PtrPhi->addIncoming(NewPtr, CondBlock);
    751       PtrPhi->addIncoming(Ptr, PrevIfBlock);
    752       Ptr = PtrPhi;
    753     }
    754   }
    755 
    756   CI->replaceAllUsesWith(VResult);
    757   CI->eraseFromParent();
    758 
    759   ModifiedDT = true;
    760 }
    761 
    762 static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
    763                                          DomTreeUpdater *DTU,
    764                                          bool &ModifiedDT) {
    765   Value *Src = CI->getArgOperand(0);
    766   Value *Ptr = CI->getArgOperand(1);
    767   Value *Mask = CI->getArgOperand(2);
    768 
    769   auto *VecType = cast<FixedVectorType>(Src->getType());
    770 
    771   IRBuilder<> Builder(CI->getContext());
    772   Instruction *InsertPt = CI;
    773   BasicBlock *IfBlock = CI->getParent();
    774 
    775   Builder.SetInsertPoint(InsertPt);
    776   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
    777 
    778   Type *EltTy = VecType->getElementType();
    779 
    780   unsigned VectorWidth = VecType->getNumElements();
    781 
    782   // Shorten the way if the mask is a vector of constants.
    783   if (isConstantIntVector(Mask)) {
    784     unsigned MemIndex = 0;
    785     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    786       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
    787         continue;
    788       Value *OneElt =
    789           Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    790       Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
    791       Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
    792       ++MemIndex;
    793     }
    794     CI->eraseFromParent();
    795     return;
    796   }
    797 
    798   // If the mask is not v1i1, use scalar bit test operations. This generates
    799   // better results on X86 at least.
    800   Value *SclrMask;
    801   if (VectorWidth != 1) {
    802     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    803     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
    804   }
    805 
    806   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    807     // Fill the "else" block, created in the previous iteration
    808     //
    809     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    810     //  br i1 %mask_1, label %cond.store, label %else
    811     //
    812     Value *Predicate;
    813     if (VectorWidth != 1) {
    814       Value *Mask = Builder.getInt(APInt::getOneBitSet(
    815           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
    816       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
    817                                        Builder.getIntN(VectorWidth, 0));
    818     } else {
    819       Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    820     }
    821 
    822     // Create "cond" block
    823     //
    824     //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    825     //  %EltAddr = getelementptr i32* %1, i32 0
    826     //  %store i32 %OneElt, i32* %EltAddr
    827     //
    828     Instruction *ThenTerm =
    829         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
    830                                   /*BranchWeights=*/nullptr, DTU);
    831 
    832     BasicBlock *CondBlock = ThenTerm->getParent();
    833     CondBlock->setName("cond.store");
    834 
    835     Builder.SetInsertPoint(CondBlock->getTerminator());
    836     Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    837     Builder.CreateAlignedStore(OneElt, Ptr, Align(1));
    838 
    839     // Move the pointer if there are more blocks to come.
    840     Value *NewPtr;
    841     if ((Idx + 1) != VectorWidth)
    842       NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
    843 
    844     // Create "else" block, fill it in the next iteration
    845     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    846     NewIfBlock->setName("else");
    847     BasicBlock *PrevIfBlock = IfBlock;
    848     IfBlock = NewIfBlock;
    849 
    850     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    851 
    852     // Add a PHI for the pointer if this isn't the last iteration.
    853     if ((Idx + 1) != VectorWidth) {
    854       PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
    855       PtrPhi->addIncoming(NewPtr, CondBlock);
    856       PtrPhi->addIncoming(Ptr, PrevIfBlock);
    857       Ptr = PtrPhi;
    858     }
    859   }
    860   CI->eraseFromParent();
    861 
    862   ModifiedDT = true;
    863 }
    864 
    865 static bool runImpl(Function &F, const TargetTransformInfo &TTI,
    866                     DominatorTree *DT) {
    867   Optional<DomTreeUpdater> DTU;
    868   if (DT)
    869     DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);
    870 
    871   bool EverMadeChange = false;
    872   bool MadeChange = true;
    873   auto &DL = F.getParent()->getDataLayout();
    874   while (MadeChange) {
    875     MadeChange = false;
    876     for (Function::iterator I = F.begin(); I != F.end();) {
    877       BasicBlock *BB = &*I++;
    878       bool ModifiedDTOnIteration = false;
    879       MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration, TTI, DL,
    880                                   DTU.hasValue() ? DTU.getPointer() : nullptr);
    881 
    882 
    883       // Restart BB iteration if the dominator tree of the Function was changed
    884       if (ModifiedDTOnIteration)
    885         break;
    886     }
    887 
    888     EverMadeChange |= MadeChange;
    889   }
    890   return EverMadeChange;
    891 }
    892 
    893 bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
    894   auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    895   DominatorTree *DT = nullptr;
    896   if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    897     DT = &DTWP->getDomTree();
    898   return runImpl(F, TTI, DT);
    899 }
    900 
    901 PreservedAnalyses
    902 ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
    903   auto &TTI = AM.getResult<TargetIRAnalysis>(F);
    904   auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
    905   if (!runImpl(F, TTI, DT))
    906     return PreservedAnalyses::all();
    907   PreservedAnalyses PA;
    908   PA.preserve<TargetIRAnalysis>();
    909   PA.preserve<DominatorTreeAnalysis>();
    910   return PA;
    911 }
    912 
    913 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
    914                           const TargetTransformInfo &TTI, const DataLayout &DL,
    915                           DomTreeUpdater *DTU) {
    916   bool MadeChange = false;
    917 
    918   BasicBlock::iterator CurInstIterator = BB.begin();
    919   while (CurInstIterator != BB.end()) {
    920     if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
    921       MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
    922     if (ModifiedDT)
    923       return true;
    924   }
    925 
    926   return MadeChange;
    927 }
    928 
    929 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
    930                              const TargetTransformInfo &TTI,
    931                              const DataLayout &DL, DomTreeUpdater *DTU) {
    932   IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
    933   if (II) {
    934     // The scalarization code below does not work for scalable vectors.
    935     if (isa<ScalableVectorType>(II->getType()) ||
    936         any_of(II->arg_operands(),
    937                [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
    938       return false;
    939 
    940     switch (II->getIntrinsicID()) {
    941     default:
    942       break;
    943     case Intrinsic::masked_load:
    944       // Scalarize unsupported vector masked load
    945       if (TTI.isLegalMaskedLoad(
    946               CI->getType(),
    947               cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
    948         return false;
    949       scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
    950       return true;
    951     case Intrinsic::masked_store:
    952       if (TTI.isLegalMaskedStore(
    953               CI->getArgOperand(0)->getType(),
    954               cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
    955         return false;
    956       scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
    957       return true;
    958     case Intrinsic::masked_gather: {
    959       unsigned AlignmentInt =
    960           cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
    961       Type *LoadTy = CI->getType();
    962       Align Alignment =
    963           DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), LoadTy);
    964       if (TTI.isLegalMaskedGather(LoadTy, Alignment))
    965         return false;
    966       scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
    967       return true;
    968     }
    969     case Intrinsic::masked_scatter: {
    970       unsigned AlignmentInt =
    971           cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
    972       Type *StoreTy = CI->getArgOperand(0)->getType();
    973       Align Alignment =
    974           DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), StoreTy);
    975       if (TTI.isLegalMaskedScatter(StoreTy, Alignment))
    976         return false;
    977       scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
    978       return true;
    979     }
    980     case Intrinsic::masked_expandload:
    981       if (TTI.isLegalMaskedExpandLoad(CI->getType()))
    982         return false;
    983       scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
    984       return true;
    985     case Intrinsic::masked_compressstore:
    986       if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
    987         return false;
    988       scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
    989       return true;
    990     }
    991   }
    992 
    993   return false;
    994 }
    995