// Home | History | Annotate | Line # | Download | only in X86
      1 //===- Target/X86/X86PreAMXConfig.cpp - ------------------------*- C++ -*-===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 /// Insert tilecfg for each area of key AMX intrinsic.
     10 /// All the key AMX intrinsic's tile operand must come from tileload. And the
     11 /// def tile of key AMX intrinsic must be tilestored.
     12 /// take tdpbssd for example:
     13 /// --------------------------------------------------------------------------
     14 /// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...)                key
     15 /// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...)                 |
     16 /// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...)                amx
     17 /// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(t1, t2, t3)         |
     18 /// call void @llvm.x86.tilestored64.internal(... td)                     area
     19 /// --------------------------------------------------------------------------
     20 /// This pass will insert tilecfg before every key-amx-area, some like:
     21 /// --------------------------------------------------------------------------
     22 /// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
     23 /// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
     24 /// ...
     25 /// ... pre-config shape of %t1                                 *
     26 /// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
     27 /// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
     28 /// ...                                                         *
     29 /// ... pre-config shape of %t2                                 * shapes
     30 /// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     *
     31 /// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
     32 /// ...
     33 /// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * tile config
     34 //
     35 //===----------------------------------------------------------------------===//
     36 //
     37 #include "X86.h"
     38 #include "llvm/ADT/SmallSet.h"
     39 #include "llvm/Analysis/TargetTransformInfo.h"
     40 #include "llvm/CodeGen/Passes.h"
     41 #include "llvm/CodeGen/TargetPassConfig.h"
     42 #include "llvm/CodeGen/ValueTypes.h"
     43 #include "llvm/IR/DataLayout.h"
     44 #include "llvm/IR/Function.h"
     45 #include "llvm/IR/IRBuilder.h"
     46 #include "llvm/IR/Instructions.h"
     47 #include "llvm/IR/IntrinsicInst.h"
     48 #include "llvm/IR/IntrinsicsX86.h"
     49 #include "llvm/IR/PatternMatch.h"
     50 #include "llvm/InitializePasses.h"
     51 #include "llvm/Pass.h"
     52 #include "llvm/Support/raw_ostream.h"
     53 #include "llvm/Target/TargetMachine.h"
     54 
     55 using namespace llvm;
     56 using namespace PatternMatch;
     57 
     58 #define DEBUG_TYPE "pre-amx-config"
     59 
     60 static bool isAMXIntrinsic(IntrinsicInst *II) {
     61   for (Value *Operand : II->operands())
     62     if (Operand->getType()->isX86_AMXTy())
     63       return true;
     64   return II->getType()->isX86_AMXTy();
     65 }
     66 
     67 static bool isTileLoad(IntrinsicInst *II) {
     68   return II->getIntrinsicID() == Intrinsic::x86_tileloadd64_internal;
     69 }
     70 
     71 static bool isTileStore(IntrinsicInst *II) {
     72   return II->getIntrinsicID() == Intrinsic::x86_tilestored64_internal;
     73 }
     74 
     75 #ifndef NDEBUG
     76 static bool onlyTileDef(IntrinsicInst *II) {
     77   for (Value *Operand : II->operands())
     78     if (Operand->getType()->isX86_AMXTy())
     79       return false;
     80   return II->getType()->isX86_AMXTy();
     81 }
     82 
     83 static bool brokenVolatile(Instruction *I) {
     84   // Todo: it is weak to identify a normal call here.
     85   if ((isa<CallInst>(I) && !isa<IntrinsicInst>(I)) || I->isTerminator())
     86     return true;
     87   return false;
     88 }
     89 #endif
     90 
     91 namespace {
class X86PreAMXConfig {
  Function &F; // The function being transformed.

public:
  X86PreAMXConfig(Function &Func) : F(Func) {}
  // Entry point: insert a tile configuration before every key-AMX area.
  bool preTileConfig();
  // Allocate and zero-init the config memory, pre-write the shapes and emit
  // ldtilecfg immediately before \p ModelStart.
  bool addTileConfig(Instruction *ModelStart, SmallVector<Value *, 8> &Shapes);
  // Collect every key-AMX area's start instruction together with the shapes
  // that must be configured for it.
  bool findConfigShapes(
      DenseMap<Instruction *, SmallVector<Value *, 8>> &PosAndShapes);
  // Gather (row, col) shape operands for all tiles used/defined by \p KeyAMX.
  bool getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector<Value *, 8> &Shapes);
  // Store the palette and the shape values into the config memory at
  // \p I8Ptr, inserting all instructions before \p Pos.
  bool preWriteTileCfg(Value *I8Ptr, Instruction *Pos,
                       SmallVector<Value *, 8> &Shapes);
  // Scan from \p Iter to the area-terminating tilestore, recording shapes;
  // returns the iterator of that tilestore.
  BasicBlock::iterator
  getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                           SmallVector<Value *, 8> &Shapes);
  // Verify the loads/store around \p KeyAMX match the volatile AMX model.
  bool checkVolatileModel(SmallSet<Value *, 4> &Loads, IntrinsicInst *Store,
                          IntrinsicInst *KeyAMX);
};
    110 
    111 // Orderly write the shapes in tilecfg's mem. This maybe not right.
    112 // Because the first shape may not corresponding to the first tmm register,
// so we need to handle it at X86FastTileConfig::materializeTileCfg()
    114 // after register allocation.
    115 // For example:
    116 // --------------------------------------------------------------------------
    117 // zeroinitialize tilecfg's mem (of ldtilecfg)
    118 // --------------------------------------------------------------------------
    119 // ... pre-config shape of %t1                                 *
    120 // %amx.tmm.0.shape.row = getelementptr i8, i8* %mem, i64 48   *
    121 // %amx.tmm.0.shape.col = getelementptr i16, i16* %mem, i64 16 *
    122 // store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
    123 // store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
    124 // ...                                                         *
    125 // ... pre-config shape of %t2                                 *
    126 // %amx.tmm.1.shape.row = getelementptr i8, i8* %mem, i64 49   *
    127 // %amx.tmm.1.shape.col = getelementptr i16, i16* %mem, i64 18 *
    128 // store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
    129 // store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
    130 // ...                                                         *
    131 // ... pre-config shape of %t3                                 * of
    132 // %amx.tmm.2.shape.row = getelementptr i8, i8* %mem, i64 50   *
    133 // %amx.tmm.2.shape.col = getelementptr i16, i16* %mem, i64 20 *
    134 // store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
    135 // store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
    136 // ...                                                         * tiles
    137 // ... pre-config shape of %td                                 *
    138 // %amx.tmm.3.shape.row = getelementptr i8, i8* %mem, i64 51   *
    139 // %amx.tmm.3.shape.col = getelementptr i16, i16* %mem, i64 22 *
    140 // store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
    141 // store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
    142 // --------------------------------------------------------------------------
    143 // call void @llvm.x86.ldtilecfg(i8* %mem)                     * tile config
    144 // --------------------------------------------------------------------------
    145 // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
    146 // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
    147 // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
    148 // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
    149 // call void @llvm.x86.tilestored64.internal(... td)                     area
    150 // --------------------------------------------------------------------------
bool X86PreAMXConfig::preWriteTileCfg(Value *I8Ptr, Instruction *Pos,
                                      SmallVector<Value *, 8> &Shapes) {
  // Writes the palette byte and each tile's (row, col) shape into the
  // ldtilecfg memory pointed to by I8Ptr, inserting every instruction
  // before Pos. Returns true if at least one shape pair was written.
  bool Write = false;
  LLVMContext &Ctx = Pos->getParent()->getContext();
  Type *I8Ty = Type::getInt8Ty(Ctx);
  Type *I16Ty = Type::getInt16Ty(Ctx);

  // TODO: Currently we set Palette = 1 by default; it may be assigned
  // another value in the future.
  // The palette byte lives at offset 0 of the config memory.
  Value *PaletteOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 0);
  Value *PaletteValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
  Value *PalettePos =
      GetElementPtrInst::Create(I8Ty, I8Ptr, PaletteOffset, "", Pos);
  new StoreInst(PaletteValue, PalettePos, Pos);

  // Shapes holds interleaved (row, col) pairs: tile I's row at index 2*I,
  // its column at index 2*I + 1.
  for (int I = 0, E = Shapes.size() / 2; I < E; I++) {
    // Tile I's row count (i8) is written at byte 48 + I; its column value
    // (i16) at byte 16 + 2*I of the config memory.
    Value *RowOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 48 + I);
    Value *ColOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 16 + I * 2);
    const std::string ShapeName = "amx.tmm." + itostr(I);
    Value *RowPos = GetElementPtrInst::Create(I8Ty, I8Ptr, RowOffset,
                                              ShapeName + ".shape.row", Pos);
    Value *ColPos = GetElementPtrInst::Create(I8Ty, I8Ptr, ColOffset, "", Pos);
    // The column slot is 2 bytes wide, so reinterpret the i8 GEP as i16*.
    ColPos = new BitCastInst(ColPos, PointerType::get(I16Ty, 0),
                             ShapeName + ".shape.col", Pos);
    Value *Row = Shapes[I * 2];
    Value *Col = Shapes[I * 2 + 1];
    // The row slot is one byte, so the row value is truncated to i8
    // (presumably from i16, matching the i16 col store — TODO confirm).
    Row = new TruncInst(Row, I8Ty, "", Pos);
    new StoreInst(Row, RowPos, Pos);
    new StoreInst(Col, ColPos, Pos);
    Write = true;
  }
  return Write;
}
    184 
    185 bool X86PreAMXConfig::addTileConfig(Instruction *ModelStart,
    186                                     SmallVector<Value *, 8> &Shapes) {
    187   Module *M = F.getParent();
    188   IRBuilder<> Builder(ModelStart);
    189   const DataLayout &DL = M->getDataLayout();
    190   unsigned AddrSpace = DL.getAllocaAddrSpace();
    191   LLVMContext &Ctx = Builder.getContext();
    192   Type *V512Ty = VectorType::get(Builder.getInt32Ty(), 16, false);
    193   Align Alignment = DL.getPrefTypeAlign(Type::getInt32Ty(Ctx));
    194 
    195   AllocaInst *Addr =
    196       new AllocaInst(V512Ty, AddrSpace, "", &F.getEntryBlock().front());
    197   Addr->setAlignment(Alignment);
    198   Value *I8Ptr = Builder.CreateBitCast(Addr, Builder.getInt8PtrTy());
    199 
    200   std::array<Value *, 1> Args = {I8Ptr};
    201   Instruction *Cfg =
    202       Builder.CreateIntrinsic(Intrinsic::x86_ldtilecfg_internal, None, Args);
    203 
    204   Value *Val0 = Constant::getNullValue(V512Ty);
    205   Instruction *Init0 = new StoreInst(Val0, Addr, false, Alignment, Cfg);
    206   assert(Init0 && "Not Zero initilizate the cfg mem!");
    207 
    208   preWriteTileCfg(I8Ptr, Cfg, Shapes);
    209 
    210   return Init0;
    211 }
    212 
    213 // Todo: We may need to handle "more than one store" case in the future.
    214 bool X86PreAMXConfig::checkVolatileModel(SmallSet<Value *, 4> &Loads,
    215                                          IntrinsicInst *Store,
    216                                          IntrinsicInst *KeyAMX) {
    217   Value *ST = Store->getOperand(4);
    218 
    219   // Only has tileload and tilestore.
    220   if (!KeyAMX)
    221     return (Loads.size() == 1) && Loads.contains(ST);
    222 
    223   // All Loads should be operands of KeyAMX.
    224   // All tile operands of KeyAMX should come from Loads.
    225   for (Value *Op : KeyAMX->operands()) {
    226     if (Op->getType()->isX86_AMXTy())
    227       if (!Loads.erase(Op))
    228         return false;
    229   }
    230 
    231   // The def of KeyAMX should be stored into mem.
    232   // Todo: is it key amx can be no def?
    233   return Loads.empty() && (ST == cast<Value>(KeyAMX));
    234 }
    235 
    236 bool X86PreAMXConfig::getKeyAMXShapes(IntrinsicInst *KeyAMX,
    237                                       SmallVector<Value *, 8> &Shapes) {
    238   for (unsigned I = 0; I < KeyAMX->getNumOperands(); I++) {
    239     Value *Op = KeyAMX->getOperand(I);
    240     if (!Op->getType()->isX86_AMXTy())
    241       continue;
    242     IntrinsicInst *TileDef = dyn_cast<IntrinsicInst>(Op);
    243     assert((TileDef && isTileLoad(TileDef)) &&
    244            "All KeyAMX's tile definiation should comes from TileLoad!");
    245     Shapes.push_back(TileDef->getOperand(0));
    246     Shapes.push_back(TileDef->getOperand(1));
    247   }
    248   if (!isTileStore(KeyAMX)) {
    249     Shapes.push_back(KeyAMX->getOperand(0));
    250     Shapes.push_back(KeyAMX->getOperand(1));
    251   }
    252   return Shapes.size() != 0;
    253 }
    254 
    255 // Collect the shapes and skip the area of current key amx intrinsic.
    256 //
    257 // For example:
    258 // ...
    259 // --------------------------------------------------------------------------
    260 // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)  record (m,k)
    261 // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)  record (m,k)
    262 // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)  record (m,k)
    263 // %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)
    264 // call void @llvm.x86.tilestored64.internal(m, n,... td) <--PosEnd record (m,k)
    265 // --------------------------------------------------------------------------
    266 BasicBlock::iterator
    267 X86PreAMXConfig::getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
    268                                           SmallVector<Value *, 8> &Shapes) {
    269   IntrinsicInst *KeyAMX = nullptr;
    270   BasicBlock *BB = Iter->getParent();
    271   BasicBlock::iterator PosEnd = BB->end();
    272   SmallSet<Value *, 4> Loads;
    273 
    274   // See TileStore as "Config Position End" and check volatile model.
    275   for (auto I = Iter, E = BB->end(); I != E; ++I) {
    276     assert(!brokenVolatile(&*I) && "Not reach tile store!");
    277     IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
    278     if (!II || !isAMXIntrinsic(II))
    279       continue;
    280 
    281     if (isTileLoad(II)) {
    282       Loads.insert(II);
    283     } else if (isTileStore(II)) {
    284       if (!checkVolatileModel(Loads, II, KeyAMX))
    285         report_fatal_error("Not Volatile AMX Model!");
    286       PosEnd = I;
    287       break;
    288     } else {
    289       assert(!KeyAMX && "Too many key amx intrinsic!");
    290       KeyAMX = II;
    291     }
    292   }
    293   assert(PosEnd != BB->end() && "Not find TileStore!");
    294 
    295   // See KeyAMX as TileStore if only TileLoad and TileStore.
    296   if (!KeyAMX)
    297     KeyAMX = dyn_cast<IntrinsicInst>(&*PosEnd);
    298 
    299   // Get Shapes in order.
    300   assert(Shapes.empty() && "Shapes should be clean.");
    301   getKeyAMXShapes(KeyAMX, Shapes);
    302 
    303   return PosEnd;
    304 }
    305 
    306 // Record a key amx area's shapes with its position.
    307 // Use the first tileload as its position.
    308 // For example:
    309 // ...
    310 // --------------------------------------------------------------------------
    311 // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)   <--  pos
    312 // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)        /
    313 // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)     shapes:
    314 // %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)    (m,k)(k,n)
    315 // call void @llvm.x86.tilestored64.internal(m, n,... td)          (m,n)(m,n)
    316 // --------------------------------------------------------------------------
    317 bool X86PreAMXConfig::findConfigShapes(
    318     DenseMap<Instruction *, SmallVector<Value *, 8>> &PosAndShapes) {
    319   bool Find = false;
    320   for (BasicBlock &BB : F) {
    321     for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
    322       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
    323       if (!II)
    324         continue;
    325       if (!isAMXIntrinsic(II))
    326         continue;
    327       assert(onlyTileDef(II) && "Not volatile model for AMX at O0!");
    328 
    329       I = getShapesAndConfigPosEnd(I, PosAndShapes[&*I]);
    330       Find = true;
    331     }
    332   }
    333   return Find;
    334 }
    335 
    336 // Insert ldtilecfg and preconfig the shapes for each area of key AMX intrinsic.
    337 // e.g. (key amx = tdpbssd)
    338 // --------------------------------------------------------------------------
    339 // %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
    340 // store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
    341 // ...
    342 // ... pre-config shape of %t1                                 *
    343 // store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
    344 // store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
    345 // ...                                                         *
    346 // ... pre-config shape of %t2                                 *
    347 // store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
    348 // store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
    349 // ...                                                         *
    350 // ... pre-config shape of %t3                                 * of
    351 // store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
    352 // store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
    353 // ...                                                         * tiles
    354 // ... pre-config shape of %td                                 *
    355 // store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
    356 // store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
    357 //
    358 // call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * pre-config
    359 // --------------------------------------------------------------------------
    360 // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
    361 // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
    362 // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
    363 // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
    364 // call void @llvm.x86.tilestored64.internal(... td)                     area
    365 // --------------------------------------------------------------------------
    366 bool X86PreAMXConfig::preTileConfig() {
    367   DenseMap<Instruction *, SmallVector<Value *, 8>> PosAndShapes;
    368   bool NeedCfg = findConfigShapes(PosAndShapes);
    369   if (!NeedCfg)
    370     return false;
    371   for (auto &IPAndShapes : PosAndShapes)
    372     addTileConfig(IPAndShapes.first, IPAndShapes.second);
    373 
    374   return true;
    375 }
    376 } // anonymous namespace
    377 
    378 namespace {
    379 
    380 class X86PreAMXConfigPass : public FunctionPass {
    381 public:
    382   static char ID;
    383 
    384   X86PreAMXConfigPass() : FunctionPass(ID) {
    385     initializeX86PreAMXConfigPassPass(*PassRegistry::getPassRegistry());
    386   }
    387 
    388   bool runOnFunction(Function &F) override {
    389     TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    390     bool C = false;
    391 
    392     // Prepare for fast register allocation at O0.
    393     if (TM->getOptLevel() == CodeGenOpt::None) {
    394 
    395       // We pre-config each key AMX intrinsic at O0.
    396       // In theory, one tile config can cover several AMX intrinsics, but
    397       // it is very diffcult to classify the tile shapes at O0. So here we
    398       // let thing be easy, pre-config every key AMX intrinsic.
    399       X86PreAMXConfig PCFG(F);
    400       C = PCFG.preTileConfig();
    401     }
    402 
    403     return C;
    404   }
    405 
    406   void getAnalysisUsage(AnalysisUsage &AU) const override {
    407     AU.setPreservesCFG();
    408     AU.addRequired<TargetPassConfig>();
    409   }
    410 };
    411 
    412 } // anonymous namespace
    413 
static const char PassName[] = "Pre AMX Tile Config";
char X86PreAMXConfigPass::ID = 0;
// Register the pass (and its TargetPassConfig dependency) with the legacy
// pass manager.
INITIALIZE_PASS_BEGIN(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)

// Factory used by the X86 target to add this pass to the codegen pipeline.
FunctionPass *llvm::createX86PreAMXConfigPass() {
  return new X86PreAMXConfigPass();
}
    423