Home | History | Annotate | Line # | Download | only in IR
      1 //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // This file defines the MatrixBuilder class, which is used as a convenient way
     10 // to lower matrix operations to LLVM IR.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #ifndef LLVM_IR_MATRIXBUILDER_H
     15 #define LLVM_IR_MATRIXBUILDER_H
     16 
     17 #include "llvm/IR/Constant.h"
     18 #include "llvm/IR/Constants.h"
     19 #include "llvm/IR/IRBuilder.h"
     20 #include "llvm/IR/InstrTypes.h"
     21 #include "llvm/IR/Instruction.h"
     22 #include "llvm/IR/IntrinsicInst.h"
     23 #include "llvm/IR/Type.h"
     24 #include "llvm/IR/Value.h"
     25 #include "llvm/Support/Alignment.h"
     26 
     27 namespace llvm {
     28 
     29 class Function;
     30 class Twine;
     31 class Module;
     32 
     33 template <class IRBuilderTy> class MatrixBuilder {
     34   IRBuilderTy &B;
     35   Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
     36 
     37   std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
     38                                                          Value *RHS) {
     39     assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
     40            "One of the operands must be a matrix (embedded in a vector)");
     41     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
     42       assert(!isa<ScalableVectorType>(LHS->getType()) &&
     43              "LHS Assumed to be fixed width");
     44       RHS = B.CreateVectorSplat(
     45           cast<VectorType>(LHS->getType())->getElementCount(), RHS,
     46           "scalar.splat");
     47     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
     48       assert(!isa<ScalableVectorType>(RHS->getType()) &&
     49              "RHS Assumed to be fixed width");
     50       LHS = B.CreateVectorSplat(
     51           cast<VectorType>(RHS->getType())->getElementCount(), LHS,
     52           "scalar.splat");
     53     }
     54     return {LHS, RHS};
     55   }
     56 
     57 public:
     58   MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {}
     59 
     60   /// Create a column major, strided matrix load.
     61   /// \p DataPtr - Start address of the matrix read
     62   /// \p Rows    - Number of rows in matrix (must be a constant)
     63   /// \p Columns - Number of columns in matrix (must be a constant)
     64   /// \p Stride  - Space between columns
     65   CallInst *CreateColumnMajorLoad(Value *DataPtr, Align Alignment,
     66                                   Value *Stride, bool IsVolatile, unsigned Rows,
     67                                   unsigned Columns, const Twine &Name = "") {
     68 
     69     // Deal with the pointer
     70     PointerType *PtrTy = cast<PointerType>(DataPtr->getType());
     71     Type *EltTy = PtrTy->getElementType();
     72 
     73     auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
     74 
     75     Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
     76                     B.getInt32(Columns)};
     77     Type *OverloadedTypes[] = {RetType};
     78 
     79     Function *TheFn = Intrinsic::getDeclaration(
     80         getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
     81 
     82     CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
     83     Attribute AlignAttr =
     84         Attribute::getWithAlignment(Call->getContext(), Alignment);
     85     Call->addAttribute(1, AlignAttr);
     86     return Call;
     87   }
     88 
     89   /// Create a column major, strided matrix store.
     90   /// \p Matrix  - Matrix to store
     91   /// \p Ptr     - Pointer to write back to
     92   /// \p Stride  - Space between columns
     93   CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment,
     94                                    Value *Stride, bool IsVolatile,
     95                                    unsigned Rows, unsigned Columns,
     96                                    const Twine &Name = "") {
     97     Value *Ops[] = {Matrix,           Ptr,
     98                     Stride,           B.getInt1(IsVolatile),
     99                     B.getInt32(Rows), B.getInt32(Columns)};
    100     Type *OverloadedTypes[] = {Matrix->getType()};
    101 
    102     Function *TheFn = Intrinsic::getDeclaration(
    103         getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
    104 
    105     CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
    106     Attribute AlignAttr =
    107         Attribute::getWithAlignment(Call->getContext(), Alignment);
    108     Call->addAttribute(2, AlignAttr);
    109     return Call;
    110   }
    111 
    112   /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
    113   /// rows and \p Columns columns.
    114   CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
    115                                   unsigned Columns, const Twine &Name = "") {
    116     auto *OpType = cast<VectorType>(Matrix->getType());
    117     auto *ReturnType =
    118         FixedVectorType::get(OpType->getElementType(), Rows * Columns);
    119 
    120     Type *OverloadedTypes[] = {ReturnType};
    121     Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
    122     Function *TheFn = Intrinsic::getDeclaration(
    123         getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
    124 
    125     return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
    126   }
    127 
    128   /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
    129   /// RHS.
    130   CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
    131                                  unsigned LHSColumns, unsigned RHSColumns,
    132                                  const Twine &Name = "") {
    133     auto *LHSType = cast<VectorType>(LHS->getType());
    134     auto *RHSType = cast<VectorType>(RHS->getType());
    135 
    136     auto *ReturnType =
    137         FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
    138 
    139     Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
    140                     B.getInt32(RHSColumns)};
    141     Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
    142 
    143     Function *TheFn = Intrinsic::getDeclaration(
    144         getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
    145     return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
    146   }
    147 
    148   /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
    149   /// ColumnIdx).
    150   Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
    151                             Value *ColumnIdx, unsigned NumRows) {
    152     return B.CreateInsertElement(
    153         Matrix, NewVal,
    154         B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
    155                                                ColumnIdx->getType(), NumRows)),
    156                     RowIdx));
    157   }
    158 
    159   /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
    160   /// matrixes.
    161   Value *CreateAdd(Value *LHS, Value *RHS) {
    162     assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
    163     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
    164       assert(!isa<ScalableVectorType>(LHS->getType()) &&
    165              "LHS Assumed to be fixed width");
    166       RHS = B.CreateVectorSplat(
    167           cast<VectorType>(LHS->getType())->getElementCount(), RHS,
    168           "scalar.splat");
    169     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
    170       assert(!isa<ScalableVectorType>(RHS->getType()) &&
    171              "RHS Assumed to be fixed width");
    172       LHS = B.CreateVectorSplat(
    173           cast<VectorType>(RHS->getType())->getElementCount(), LHS,
    174           "scalar.splat");
    175     }
    176 
    177     return cast<VectorType>(LHS->getType())
    178                    ->getElementType()
    179                    ->isFloatingPointTy()
    180                ? B.CreateFAdd(LHS, RHS)
    181                : B.CreateAdd(LHS, RHS);
    182   }
    183 
    184   /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
    185   /// point matrixes.
    186   Value *CreateSub(Value *LHS, Value *RHS) {
    187     assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
    188     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
    189       assert(!isa<ScalableVectorType>(LHS->getType()) &&
    190              "LHS Assumed to be fixed width");
    191       RHS = B.CreateVectorSplat(
    192           cast<VectorType>(LHS->getType())->getElementCount(), RHS,
    193           "scalar.splat");
    194     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
    195       assert(!isa<ScalableVectorType>(RHS->getType()) &&
    196              "RHS Assumed to be fixed width");
    197       LHS = B.CreateVectorSplat(
    198           cast<VectorType>(RHS->getType())->getElementCount(), LHS,
    199           "scalar.splat");
    200     }
    201 
    202     return cast<VectorType>(LHS->getType())
    203                    ->getElementType()
    204                    ->isFloatingPointTy()
    205                ? B.CreateFSub(LHS, RHS)
    206                : B.CreateSub(LHS, RHS);
    207   }
    208 
    209   /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
    210   /// RHS.
    211   Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
    212     std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
    213     if (LHS->getType()->getScalarType()->isFloatingPointTy())
    214       return B.CreateFMul(LHS, RHS);
    215     return B.CreateMul(LHS, RHS);
    216   }
    217 
    218   /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p
    219   /// IsUnsigned indicates whether UDiv or SDiv should be used.
    220   Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
    221     assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
    222     assert(!isa<ScalableVectorType>(LHS->getType()) &&
    223            "LHS Assumed to be fixed width");
    224     RHS =
    225         B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(),
    226                             RHS, "scalar.splat");
    227     return cast<VectorType>(LHS->getType())
    228                    ->getElementType()
    229                    ->isFloatingPointTy()
    230                ? B.CreateFDiv(LHS, RHS)
    231                : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
    232   }
    233 
    234   /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix.
    235   Value *CreateExtractElement(Value *Matrix, Value *RowIdx, Value *ColumnIdx,
    236                               unsigned NumRows, Twine const &Name = "") {
    237 
    238     unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
    239                                  ColumnIdx->getType()->getScalarSizeInBits());
    240     Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
    241     RowIdx = B.CreateZExt(RowIdx, IntTy);
    242     ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
    243     Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
    244     return B.CreateExtractElement(
    245         Matrix, B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx),
    246         "matext");
    247   }
    248 };
    249 
    250 } // end namespace llvm
    251 
    252 #endif // LLVM_IR_MATRIXBUILDER_H
    253