Home | History | Annotate | Line # | Download | only in X86
      1 //===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
/// \file Pass to transform AMX intrinsics to scalar operations.
/// The pass does nothing unless -enable-x86-scalar-amx is given, and even then
/// it skips every function that is compiled above -O0 and lacks the optnone
/// attribute. With -O0 or optnone, the definitions of the tile shapes stay
/// close to the AMX intrinsics that use them, so we cannot find a point that
/// post-dominates all the shape definitions and dominates all the AMX
/// intrinsics. To decouple the intrinsics from their shapes, we transform the
/// AMX intrinsics into scalar operations so that compilation does not fail. In
/// the long term, we should improve fast register allocation so that it can
/// allocate AMX registers.
     17 //===----------------------------------------------------------------------===//
     18 //
     19 #include "X86.h"
     20 #include "llvm/ADT/DenseSet.h"
     21 #include "llvm/ADT/PostOrderIterator.h"
     22 #include "llvm/Analysis/DomTreeUpdater.h"
     23 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
     24 #include "llvm/Analysis/TargetTransformInfo.h"
     25 #include "llvm/CodeGen/Passes.h"
     26 #include "llvm/CodeGen/TargetPassConfig.h"
     27 #include "llvm/CodeGen/ValueTypes.h"
     28 #include "llvm/IR/DataLayout.h"
     29 #include "llvm/IR/Function.h"
     30 #include "llvm/IR/IRBuilder.h"
     31 #include "llvm/IR/Instructions.h"
     32 #include "llvm/IR/IntrinsicInst.h"
     33 #include "llvm/IR/IntrinsicsX86.h"
     34 #include "llvm/IR/PatternMatch.h"
     35 #include "llvm/InitializePasses.h"
     36 #include "llvm/Pass.h"
     37 #include "llvm/Support/CommandLine.h"
     38 #include "llvm/Target/TargetMachine.h"
     39 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
     40 #include "llvm/Transforms/Utils/LoopUtils.h"
     41 
     42 using namespace llvm;
     43 using namespace PatternMatch;
     44 
     45 #define DEBUG_TYPE "lower-amx-intrinsics"
     46 
     47 #ifndef NDEBUG
     48 static bool isV256I32Ty(Type *Ty) {
     49   if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
     50     return FVT->getNumElements() == 256 &&
     51            FVT->getElementType()->isIntegerTy(32);
     52   return false;
     53 }
     54 #endif
     55 
     56 static cl::opt<bool>
     57     X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden,
     58                     cl::desc("X86: enable AMX scalarizition."));
     59 
     60 namespace {
     61 class X86LowerAMXIntrinsics {
     62   Function &Func;
     63 
     64 public:
     65   X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
     66       : Func(F), DTU(DomTU), LI(LoopI) {}
     67   bool visit();
     68 
     69 private:
     70   DomTreeUpdater &DTU;
     71   LoopInfo *LI;
     72   BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
     73                          Value *Step, StringRef Name, IRBuilderBase &B,
     74                          Loop *L);
     75   template <bool IsTileLoad>
     76   Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
     77                                   IRBuilderBase &B, Value *Row, Value *Col,
     78                                   Value *Ptr, Value *Stride, Value *Tile);
     79   template <Intrinsic::ID IntrID>
     80   typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
     81                               IntrID == Intrinsic::x86_tdpbsud_internal ||
     82                               IntrID == Intrinsic::x86_tdpbusd_internal ||
     83                               IntrID == Intrinsic::x86_tdpbuud_internal ||
     84                               IntrID == Intrinsic::x86_tdpbf16ps_internal,
     85                           Value *>::type
     86   createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
     87                     Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
     88                     Value *RHS);
     89   template <bool IsTileLoad>
     90   bool lowerTileLoadStore(Instruction *TileLoadStore);
     91   template <Intrinsic::ID IntrID>
     92   typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
     93                               IntrID == Intrinsic::x86_tdpbsud_internal ||
     94                               IntrID == Intrinsic::x86_tdpbusd_internal ||
     95                               IntrID == Intrinsic::x86_tdpbuud_internal ||
     96                               IntrID == Intrinsic::x86_tdpbf16ps_internal,
     97                           bool>::type
     98   lowerTileDP(Instruction *TileDP);
     99   bool lowerTileZero(Instruction *TileZero);
    100 };
    101 } // anonymous namespace
    102 
    103 BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
    104                                               BasicBlock *Exit, Value *Bound,
    105                                               Value *Step, StringRef Name,
    106                                               IRBuilderBase &B, Loop *L) {
    107   LLVMContext &Ctx = Preheader->getContext();
    108   BasicBlock *Header =
    109       BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
    110   BasicBlock *Body =
    111       BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
    112   BasicBlock *Latch =
    113       BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);
    114 
    115   Type *I16Ty = Type::getInt16Ty(Ctx);
    116   BranchInst::Create(Body, Header);
    117   BranchInst::Create(Latch, Body);
    118   PHINode *IV =
    119       PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator());
    120   IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);
    121 
    122   B.SetInsertPoint(Latch);
    123   Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
    124   Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
    125   BranchInst::Create(Header, Exit, Cond, Latch);
    126   IV->addIncoming(Inc, Latch);
    127 
    128   BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
    129   BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
    130   PreheaderBr->setSuccessor(0, Header);
    131   DTU.applyUpdatesPermissive({
    132       {DominatorTree::Delete, Preheader, Tmp},
    133       {DominatorTree::Insert, Header, Body},
    134       {DominatorTree::Insert, Body, Latch},
    135       {DominatorTree::Insert, Latch, Header},
    136       {DominatorTree::Insert, Latch, Exit},
    137       {DominatorTree::Insert, Preheader, Header},
    138   });
    139   if (LI) {
    140     L->addBasicBlockToLoop(Header, *LI);
    141     L->addBasicBlockToLoop(Body, *LI);
    142     L->addBasicBlockToLoop(Latch, *LI);
    143   }
    144   return Body;
    145 }
    146 
    147 template <bool IsTileLoad>
    148 Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
    149     BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
    150     Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
    151   std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
    152   Loop *RowLoop = nullptr;
    153   Loop *ColLoop = nullptr;
    154   if (LI) {
    155     RowLoop = LI->AllocateLoop();
    156     ColLoop = LI->AllocateLoop();
    157     RowLoop->addChildLoop(ColLoop);
    158     if (Loop *ParentL = LI->getLoopFor(Start))
    159       ParentL->addChildLoop(RowLoop);
    160     else
    161       LI->addTopLevelLoop(RowLoop);
    162   }
    163 
    164   BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
    165                                    IntrinName + ".scalarize.rows", B, RowLoop);
    166   BasicBlock *RowLatch = RowBody->getSingleSuccessor();
    167 
    168   BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
    169                                    IntrinName + ".scalarize.cols", B, ColLoop);
    170 
    171   BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
    172   BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
    173   BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
    174   Value *CurrentRow = &*RowLoopHeader->begin();
    175   Value *CurrentCol = &*ColLoopHeader->begin();
    176   Type *EltTy = B.getInt32Ty();
    177   FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);
    178 
    179   // Common part for tileload and tilestore
    180   // *.scalarize.cols.body:
    181   // Calculate %idxmem and %idxvec
    182   B.SetInsertPoint(ColBody->getTerminator());
    183   Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
    184   Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
    185   Value *Offset =
    186       B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
    187   unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace();
    188   Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS));
    189   Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset);
    190   Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
    191   if (IsTileLoad) {
    192     // tileload.scalarize.rows.header:
    193     // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
    194     // %tileload.scalarize.rows.latch ]
    195     B.SetInsertPoint(RowLoopHeader->getTerminator());
    196     Value *VecZero = Constant::getNullValue(V256I32Ty);
    197     PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
    198     VecCPhiRowLoop->addIncoming(VecZero, Start);
    199 
    200     // tileload.scalarize.cols.header:
    201     // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
    202     // ], [ %ResVec, %tileload.scalarize.cols.latch ]
    203     B.SetInsertPoint(ColLoopHeader->getTerminator());
    204     PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
    205     VecPhi->addIncoming(VecCPhiRowLoop, RowBody);
    206 
    207     // tileload.scalarize.cols.body:
    208     // Calculate %idxmem and %idxvec
    209     // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
    210     // %elt = load i32, i32* %ptr
    211     // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
    212     B.SetInsertPoint(ColBody->getTerminator());
    213     Value *Elt = B.CreateLoad(EltTy, EltPtr);
    214     Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
    215     VecPhi->addIncoming(ResVec, ColLoopLatch);
    216     VecCPhiRowLoop->addIncoming(ResVec, RowLatch);
    217 
    218     return ResVec;
    219   } else {
    220     auto *BitCast = cast<BitCastInst>(Tile);
    221     Value *Vec = BitCast->getOperand(0);
    222     assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
    223     // tilestore.scalarize.cols.body:
    224     // %mul = mul i16 %row.iv, i16 16
    225     // %idx = add i16 %mul, i16 %col.iv
    226     // %vec = extractelement <16 x i32> %vec, i16 %idx
    227     // store i32 %vec, i32* %ptr
    228     B.SetInsertPoint(ColBody->getTerminator());
    229     Value *Elt = B.CreateExtractElement(Vec, Idx);
    230 
    231     B.CreateStore(Elt, EltPtr);
    232     return nullptr;
    233   }
    234 }
    235 
    236 template <Intrinsic::ID IntrID>
    237 typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
    238                             IntrID == Intrinsic::x86_tdpbsud_internal ||
    239                             IntrID == Intrinsic::x86_tdpbusd_internal ||
    240                             IntrID == Intrinsic::x86_tdpbuud_internal ||
    241                             IntrID == Intrinsic::x86_tdpbf16ps_internal,
    242                         Value *>::type
    243 X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
    244                                          IRBuilderBase &B, Value *Row,
    245                                          Value *Col, Value *K, Value *Acc,
    246                                          Value *LHS, Value *RHS) {
    247   std::string IntrinName;
    248   switch (IntrID) {
    249   case Intrinsic::x86_tdpbssd_internal:
    250     IntrinName = "tiledpbssd";
    251     break;
    252   case Intrinsic::x86_tdpbsud_internal:
    253     IntrinName = "tiledpbsud";
    254     break;
    255   case Intrinsic::x86_tdpbusd_internal:
    256     IntrinName = "tiledpbusd";
    257     break;
    258   case Intrinsic::x86_tdpbuud_internal:
    259     IntrinName = "tiledpbuud";
    260     break;
    261   case Intrinsic::x86_tdpbf16ps_internal:
    262     IntrinName = "tiledpbf16ps";
    263     break;
    264   }
    265   Loop *RowLoop = nullptr;
    266   Loop *ColLoop = nullptr;
    267   Loop *InnerLoop = nullptr;
    268   if (LI) {
    269     RowLoop = LI->AllocateLoop();
    270     ColLoop = LI->AllocateLoop();
    271     InnerLoop = LI->AllocateLoop();
    272     ColLoop->addChildLoop(InnerLoop);
    273     RowLoop->addChildLoop(ColLoop);
    274     if (Loop *ParentL = LI->getLoopFor(Start))
    275       ParentL->addChildLoop(RowLoop);
    276     else
    277       LI->addTopLevelLoop(RowLoop);
    278   }
    279 
    280   BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
    281                                    IntrinName + ".scalarize.rows", B, RowLoop);
    282   BasicBlock *RowLatch = RowBody->getSingleSuccessor();
    283 
    284   BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
    285                                    IntrinName + ".scalarize.cols", B, ColLoop);
    286 
    287   BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
    288 
    289   B.SetInsertPoint(ColBody->getTerminator());
    290   BasicBlock *InnerBody =
    291       createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
    292                  IntrinName + ".scalarize.inner", B, InnerLoop);
    293 
    294   BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
    295   BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
    296   BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
    297   BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
    298   Value *CurrentRow = &*RowLoopHeader->begin();
    299   Value *CurrentCol = &*ColLoopHeader->begin();
    300   Value *CurrentInner = &*InnerLoopHeader->begin();
    301 
    302   FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
    303   auto *BitCastAcc = cast<BitCastInst>(Acc);
    304   Value *VecC = BitCastAcc->getOperand(0);
    305   assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
    306   // TODO else create BitCast from x86amx to v256i32.
    307   // Store x86amx to memory, and reload from memory
    308   // to vector. However with -O0, it doesn't happen.
    309   auto *BitCastLHS = cast<BitCastInst>(LHS);
    310   Value *VecA = BitCastLHS->getOperand(0);
    311   assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
    312   auto *BitCastRHS = cast<BitCastInst>(RHS);
    313   Value *VecB = BitCastRHS->getOperand(0);
    314   assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");
    315 
    316   // tiledpbssd.scalarize.rows.header:
    317   // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
    318   // %tiledpbssd.scalarize.rows.latch ]
    319 
    320   // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
    321   // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
    322   B.SetInsertPoint(RowLoopHeader->getTerminator());
    323   PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
    324   VecCPhiRowLoop->addIncoming(VecC, Start);
    325   Value *VecZero = Constant::getNullValue(V256I32Ty);
    326   PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
    327   VecDPhiRowLoop->addIncoming(VecZero, Start);
    328 
    329   // tiledpbssd.scalarize.cols.header:
    330   // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
    331   // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
    332   // %tiledpbssd.scalarize.cols.latch ]
    333 
    334   // %vec.d.phi.col = phi <256 x i32> [
    335   // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
    336   // %tiledpbssd.scalarize.cols.latch ]
    337 
    338   // calculate idxc.
    339   B.SetInsertPoint(ColLoopHeader->getTerminator());
    340   PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
    341   VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
    342   PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
    343   VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
    344   Value *IdxC =
    345       B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
    346 
    347   // tiledpbssd.scalarize.inner.header:
    348   // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
    349   // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
    350   // %tiledpbssd.scalarize.inner.latch ]
    351 
    352   B.SetInsertPoint(InnerLoopHeader->getTerminator());
    353   PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
    354   VecCPhi->addIncoming(VecCPhiColLoop, ColBody);
    355 
    356   B.SetInsertPoint(InnerBody->getTerminator());
    357   Value *IdxA =
    358       B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
    359   Value *IdxB =
    360       B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
    361   Value *NewVecC = nullptr;
    362 
    363   if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
    364     // tiledpbssd.scalarize.inner.body:
    365     // calculate idxa, idxb
    366     // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    367     // %elta = extractelement <256 x i32> %veca, i16 %idxa
    368     // %eltav4i8 = bitcast i32 %elta to <4 x i8>
    369     // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    370     // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
    371     // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
    372     // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
    373     // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
    374     // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131)
    375     // %neweltc = add i32 %elt, %acc
    376     // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    377     // i16 %idxc
    378     FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
    379     FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
    380     Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    381     Value *EltA = B.CreateExtractElement(VecA, IdxA);
    382     Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
    383     Value *EltB = B.CreateExtractElement(VecB, IdxB);
    384     Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
    385     Value *SEXTSubVecB = nullptr;
    386     Value *SEXTSubVecA = nullptr;
    387     switch (IntrID) {
    388     case Intrinsic::x86_tdpbssd_internal:
    389       SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
    390       SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
    391       break;
    392     case Intrinsic::x86_tdpbsud_internal:
    393       SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
    394       SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
    395       break;
    396     case Intrinsic::x86_tdpbusd_internal:
    397       SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
    398       SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
    399       break;
    400     case Intrinsic::x86_tdpbuud_internal:
    401       SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
    402       SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
    403       break;
    404     default:
    405       llvm_unreachable("Invalid intrinsic ID!");
    406     }
    407     Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
    408     Value *ResElt = B.CreateAdd(EltC, SubVecR);
    409     NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
    410   } else {
    411     // tiledpbf16ps.scalarize.inner.body:
    412     // calculate idxa, idxb, idxc
    413     // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    414     // %eltcf32 = bitcast i32 %eltc to float
    415     // %elta = extractelement <256 x i32> %veca, i16 %idxa
    416     // %eltav2i16 = bitcast i32 %elta to <2 x i16>
    417     // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    418     // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
    419     // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4
    420     // x i32> <i32 2, i32 0, i32 3, i32 1>
    421     // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
    422     // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x
    423     // i32> <i32 2, i32 0, i32 3, i32 1>
    424     // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
    425     // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
    426     // %acc = call float
    427     // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
    428     // %neweltc = bitcast float %acc to i32
    429     // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    430     // i16 %idxc
    431     // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
    432     // i16 %idxc
    433     FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
    434     FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
    435     Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    436     Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
    437     Value *EltA = B.CreateExtractElement(VecA, IdxA);
    438     Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
    439     Value *EltB = B.CreateExtractElement(VecB, IdxB);
    440     Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
    441     Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
    442     int ShuffleMask[4] = {2, 0, 3, 1};
    443     auto ShuffleArray = makeArrayRef(ShuffleMask);
    444     Value *AV2F32 = B.CreateBitCast(
    445         B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
    446     Value *BV2F32 = B.CreateBitCast(
    447         B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
    448     Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
    449     Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
    450     NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
    451   }
    452 
    453   // tiledpbssd.scalarize.cols.latch:
    454   // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc
    455   // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
    456   // i16 %idxc
    457   B.SetInsertPoint(ColLoopLatch->getTerminator());
    458   Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
    459   Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);
    460 
    461   VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
    462   VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
    463   VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
    464   VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
    465   VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);
    466 
    467   return NewVecD;
    468 }
    469 
    470 template <Intrinsic::ID IntrID>
    471 typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
    472                             IntrID == Intrinsic::x86_tdpbsud_internal ||
    473                             IntrID == Intrinsic::x86_tdpbusd_internal ||
    474                             IntrID == Intrinsic::x86_tdpbuud_internal ||
    475                             IntrID == Intrinsic::x86_tdpbf16ps_internal,
    476                         bool>::type
    477 X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
    478   Value *M, *N, *K, *C, *A, *B;
    479   match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
    480                                     m_Value(C), m_Value(A), m_Value(B)));
    481   Instruction *InsertI = TileDP;
    482   IRBuilder<> PreBuilder(TileDP);
    483   PreBuilder.SetInsertPoint(TileDP);
    484   // We visit the loop with (m, n/4, k/4):
    485   // %n_dword = lshr i16 %n, 2
    486   // %k_dword = lshr i16 %k, 2
    487   Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
    488   Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
    489   BasicBlock *Start = InsertI->getParent();
    490   BasicBlock *End =
    491       SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
    492   IRBuilder<> Builder(TileDP);
    493   Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
    494                                             KDWord, C, A, B);
    495   // we cannot assume there always be bitcast after tiledpbssd. So we need to
    496   // insert one bitcast as required
    497   Builder.SetInsertPoint(End->getFirstNonPHI());
    498   Value *ResAMX =
    499       Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
    500   // Delete TileDP intrinsic and do some clean-up.
    501   for (auto UI = TileDP->use_begin(), UE = TileDP->use_end(); UI != UE;) {
    502     Instruction *I = cast<Instruction>((UI++)->getUser());
    503     Value *Vec;
    504     if (match(I, m_BitCast(m_Value(Vec)))) {
    505       I->replaceAllUsesWith(ResVec);
    506       I->eraseFromParent();
    507     }
    508   }
    509   TileDP->replaceAllUsesWith(ResAMX);
    510   TileDP->eraseFromParent();
    511   return true;
    512 }
    513 
    514 template <bool IsTileLoad>
    515 bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
    516   Value *M, *N, *Ptr, *Stride, *Tile;
    517   if (IsTileLoad)
    518     match(TileLoadStore,
    519           m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
    520               m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
    521   else
    522     match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
    523                              m_Value(M), m_Value(N), m_Value(Ptr),
    524                              m_Value(Stride), m_Value(Tile)));
    525 
    526   Instruction *InsertI = TileLoadStore;
    527   IRBuilder<> PreBuilder(TileLoadStore);
    528   PreBuilder.SetInsertPoint(TileLoadStore);
    529   Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
    530   Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
    531   BasicBlock *Start = InsertI->getParent();
    532   BasicBlock *End =
    533       SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
    534   IRBuilder<> Builder(TileLoadStore);
    535   Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
    536       Start, End, Builder, M, NDWord, Ptr, StrideDWord,
    537       IsTileLoad ? nullptr : Tile);
    538   if (IsTileLoad) {
    539     // we cannot assume there always be bitcast after tileload. So we need to
    540     // insert one bitcast as required
    541     Builder.SetInsertPoint(End->getFirstNonPHI());
    542     Value *ResAMX =
    543         Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
    544     // Delete tileloadd6 intrinsic and do some clean-up
    545     for (auto UI = TileLoadStore->use_begin(), UE = TileLoadStore->use_end();
    546          UI != UE;) {
    547       Instruction *I = cast<Instruction>((UI++)->getUser());
    548       Value *Vec;
    549       if (match(I, m_BitCast(m_Value(Vec)))) {
    550         I->replaceAllUsesWith(ResVec);
    551         I->eraseFromParent();
    552       }
    553     }
    554     TileLoadStore->replaceAllUsesWith(ResAMX);
    555   }
    556   TileLoadStore->eraseFromParent();
    557   return true;
    558 }
    559 
    560 bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
    561   IRBuilder<> Builder(TileZero);
    562   FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
    563   Value *VecZero = Constant::getNullValue(V256I32Ty);
    564   for (auto UI = TileZero->use_begin(), UE = TileZero->use_end(); UI != UE;) {
    565     Instruction *I = cast<Instruction>((UI++)->getUser());
    566     Value *Vec;
    567     if (match(I, m_BitCast(m_Value(Vec)))) {
    568       I->replaceAllUsesWith(VecZero);
    569       I->eraseFromParent();
    570     }
    571   }
    572   TileZero->eraseFromParent();
    573   return true;
    574 }
    575 
    576 bool X86LowerAMXIntrinsics::visit() {
    577   bool C = false;
    578   SmallVector<IntrinsicInst *, 8> WorkList;
    579   for (BasicBlock *BB : depth_first(&Func)) {
    580     for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
    581       if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
    582         switch (Inst->getIntrinsicID()) {
    583         case Intrinsic::x86_tdpbssd_internal:
    584         case Intrinsic::x86_tdpbsud_internal:
    585         case Intrinsic::x86_tdpbusd_internal:
    586         case Intrinsic::x86_tdpbuud_internal:
    587         case Intrinsic::x86_tileloadd64_internal:
    588         case Intrinsic::x86_tilestored64_internal:
    589         case Intrinsic::x86_tilezero_internal:
    590         case Intrinsic::x86_tdpbf16ps_internal:
    591           WorkList.push_back(Inst);
    592           break;
    593         default:
    594           break;
    595         }
    596       }
    597     }
    598   }
    599 
    600   for (auto *Inst : WorkList) {
    601     switch (Inst->getIntrinsicID()) {
    602     case Intrinsic::x86_tdpbssd_internal:
    603       C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
    604       break;
    605     case Intrinsic::x86_tdpbsud_internal:
    606       C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
    607       break;
    608     case Intrinsic::x86_tdpbusd_internal:
    609       C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
    610       break;
    611     case Intrinsic::x86_tdpbuud_internal:
    612       C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
    613       break;
    614     case Intrinsic::x86_tdpbf16ps_internal:
    615       C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
    616       break;
    617     case Intrinsic::x86_tileloadd64_internal:
    618       C = lowerTileLoadStore<true>(Inst) || C;
    619       break;
    620     case Intrinsic::x86_tilestored64_internal:
    621       C = lowerTileLoadStore<false>(Inst) || C;
    622       break;
    623     case Intrinsic::x86_tilezero_internal:
    624       C = lowerTileZero(Inst) || C;
    625       break;
    626     default:
    627       llvm_unreachable("invalid amx intrinsics!");
    628     }
    629   }
    630 
    631   return C;
    632 }
    633 
    634 class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
    635 public:
    636   static char ID;
    637 
    638   X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
    639     initializeX86LowerAMXIntrinsicsLegacyPassPass(
    640         *PassRegistry::getPassRegistry());
    641   }
    642 
    643   bool runOnFunction(Function &F) override {
    644     if (!X86ScalarizeAMX)
    645       return false;
    646     TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    647     if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
    648         TM->getOptLevel() != CodeGenOpt::None)
    649       return false;
    650 
    651     auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
    652     auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
    653     auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
    654     auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
    655     DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
    656 
    657     X86LowerAMXIntrinsics LAT(F, DTU, LI);
    658     return LAT.visit();
    659   }
    660   StringRef getPassName() const override { return "Lower AMX intrinsics"; }
    661 
    662   void getAnalysisUsage(AnalysisUsage &AU) const override {
    663     AU.addPreserved<DominatorTreeWrapperPass>();
    664     AU.addPreserved<LoopInfoWrapperPass>();
    665     AU.addRequired<TargetPassConfig>();
    666   }
    667 };
    668 
    669 static const char PassName[] = "Lower AMX intrinsics";
    670 char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
    671 INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
    672                       false, false)
    673 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
    674 INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
    675                     false, false)
    676 
    677 FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
    678   return new X86LowerAMXIntrinsicsLegacyPass();
    679 }
    680