Home | History | Annotate | Line # | Download | only in AMDGPU
      1 //===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 /// \file
     10 /// This pass resolves calls to OpenCL image attribute, image resource ID and
     11 /// sampler resource ID getter functions.
     12 ///
     13 /// Image attributes (size and format) are expected to be passed to the kernel
     14 /// as kernel arguments immediately following the image argument itself,
     15 /// therefore this pass adds image size and format arguments to the kernel
     16 /// functions in the module. The kernel functions with image arguments are
     17 /// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format".
     19 /// Note: this pass may invalidate pointers to functions.
     20 ///
     21 /// Resource IDs of read-only images, write-only images and samplers are
     22 /// defined to be their index among the kernel arguments of the same
     23 /// type and access qualifier.
     24 //
     25 //===----------------------------------------------------------------------===//
     26 
     27 #include "AMDGPU.h"
     28 #include "llvm/ADT/SmallVector.h"
     29 #include "llvm/ADT/StringRef.h"
     30 #include "llvm/IR/Constants.h"
     31 #include "llvm/IR/Function.h"
     32 #include "llvm/IR/Instructions.h"
     33 #include "llvm/IR/Metadata.h"
     34 #include "llvm/Pass.h"
     35 #include "llvm/Transforms/Utils/Cloning.h"
     36 
     37 using namespace llvm;
     38 
     39 static StringRef GetImageSizeFunc =         "llvm.OpenCL.image.get.size";
     40 static StringRef GetImageFormatFunc =       "llvm.OpenCL.image.get.format";
     41 static StringRef GetImageResourceIDFunc =   "llvm.OpenCL.image.get.resource.id";
     42 static StringRef GetSamplerResourceIDFunc =
     43     "llvm.OpenCL.sampler.get.resource.id";
     44 
     45 static StringRef ImageSizeArgMDType =   "__llvm_image_size";
     46 static StringRef ImageFormatArgMDType = "__llvm_image_format";
     47 
     48 static StringRef KernelsMDNodeName = "opencl.kernels";
     49 static StringRef KernelArgMDNodeNames[] = {
     50   "kernel_arg_addr_space",
     51   "kernel_arg_access_qual",
     52   "kernel_arg_type",
     53   "kernel_arg_base_type",
     54   "kernel_arg_type_qual"};
     55 static const unsigned NumKernelArgMDNodes = 5;
     56 
     57 namespace {
     58 
     59 using MDVector = SmallVector<Metadata *, 8>;
     60 struct KernelArgMD {
     61   MDVector ArgVector[NumKernelArgMDNodes];
     62 };
     63 
     64 } // end anonymous namespace
     65 
     66 static inline bool
     67 IsImageType(StringRef TypeString) {
     68   return TypeString == "image2d_t" || TypeString == "image3d_t";
     69 }
     70 
     71 static inline bool
     72 IsSamplerType(StringRef TypeString) {
     73   return TypeString == "sampler_t";
     74 }
     75 
     76 static Function *
     77 GetFunctionFromMDNode(MDNode *Node) {
     78   if (!Node)
     79     return nullptr;
     80 
     81   size_t NumOps = Node->getNumOperands();
     82   if (NumOps != NumKernelArgMDNodes + 1)
     83     return nullptr;
     84 
     85   auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
     86   if (!F)
     87     return nullptr;
     88 
     89   // Sanity checks.
     90   size_t ExpectNumArgNodeOps = F->arg_size() + 1;
     91   for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
     92     MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
     93     if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
     94       return nullptr;
     95     if (!ArgNode->getOperand(0))
     96       return nullptr;
     97 
     98     // FIXME: It should be possible to do image lowering when some metadata
     99     // args missing or not in the expected order.
    100     MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
    101     if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
    102       return nullptr;
    103   }
    104 
    105   return F;
    106 }
    107 
    108 static StringRef
    109 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
    110   MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
    111   return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
    112 }
    113 
    114 static StringRef
    115 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
    116   MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
    117   return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
    118 }
    119 
    120 static MDVector
    121 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
    122   MDVector Res;
    123   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
    124     MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
    125     Res.push_back(Node->getOperand(OpIdx));
    126   }
    127   return Res;
    128 }
    129 
    130 static void
    131 PushArgMD(KernelArgMD &MD, const MDVector &V) {
    132   assert(V.size() == NumKernelArgMDNodes);
    133   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
    134     MD.ArgVector[i].push_back(V[i]);
    135   }
    136 }
    137 
    138 namespace {
    139 
    140 class R600OpenCLImageTypeLoweringPass : public ModulePass {
    141   static char ID;
    142 
    143   LLVMContext *Context;
    144   Type *Int32Type;
    145   Type *ImageSizeType;
    146   Type *ImageFormatType;
    147   SmallVector<Instruction *, 4> InstsToErase;
    148 
    149   bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
    150                         Argument &ImageSizeArg,
    151                         Argument &ImageFormatArg) {
    152     bool Modified = false;
    153 
    154     for (auto &Use : ImageArg.uses()) {
    155       auto Inst = dyn_cast<CallInst>(Use.getUser());
    156       if (!Inst) {
    157         continue;
    158       }
    159 
    160       Function *F = Inst->getCalledFunction();
    161       if (!F)
    162         continue;
    163 
    164       Value *Replacement = nullptr;
    165       StringRef Name = F->getName();
    166       if (Name.startswith(GetImageResourceIDFunc)) {
    167         Replacement = ConstantInt::get(Int32Type, ResourceID);
    168       } else if (Name.startswith(GetImageSizeFunc)) {
    169         Replacement = &ImageSizeArg;
    170       } else if (Name.startswith(GetImageFormatFunc)) {
    171         Replacement = &ImageFormatArg;
    172       } else {
    173         continue;
    174       }
    175 
    176       Inst->replaceAllUsesWith(Replacement);
    177       InstsToErase.push_back(Inst);
    178       Modified = true;
    179     }
    180 
    181     return Modified;
    182   }
    183 
    184   bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
    185     bool Modified = false;
    186 
    187     for (const auto &Use : SamplerArg.uses()) {
    188       auto Inst = dyn_cast<CallInst>(Use.getUser());
    189       if (!Inst) {
    190         continue;
    191       }
    192 
    193       Function *F = Inst->getCalledFunction();
    194       if (!F)
    195         continue;
    196 
    197       Value *Replacement = nullptr;
    198       StringRef Name = F->getName();
    199       if (Name == GetSamplerResourceIDFunc) {
    200         Replacement = ConstantInt::get(Int32Type, ResourceID);
    201       } else {
    202         continue;
    203       }
    204 
    205       Inst->replaceAllUsesWith(Replacement);
    206       InstsToErase.push_back(Inst);
    207       Modified = true;
    208     }
    209 
    210     return Modified;
    211   }
    212 
    213   bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
    214     uint32_t NumReadOnlyImageArgs = 0;
    215     uint32_t NumWriteOnlyImageArgs = 0;
    216     uint32_t NumSamplerArgs = 0;
    217 
    218     bool Modified = false;
    219     InstsToErase.clear();
    220     for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
    221       Argument &Arg = *ArgI;
    222       StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
    223 
    224       // Handle image types.
    225       if (IsImageType(Type)) {
    226         StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
    227         uint32_t ResourceID;
    228         if (AccessQual == "read_only") {
    229           ResourceID = NumReadOnlyImageArgs++;
    230         } else if (AccessQual == "write_only") {
    231           ResourceID = NumWriteOnlyImageArgs++;
    232         } else {
    233           llvm_unreachable("Wrong image access qualifier.");
    234         }
    235 
    236         Argument &SizeArg = *(++ArgI);
    237         Argument &FormatArg = *(++ArgI);
    238         Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
    239 
    240       // Handle sampler type.
    241       } else if (IsSamplerType(Type)) {
    242         uint32_t ResourceID = NumSamplerArgs++;
    243         Modified |= replaceSamplerUses(Arg, ResourceID);
    244       }
    245     }
    246     for (unsigned i = 0; i < InstsToErase.size(); ++i) {
    247       InstsToErase[i]->eraseFromParent();
    248     }
    249 
    250     return Modified;
    251   }
    252 
    253   std::tuple<Function *, MDNode *>
    254   addImplicitArgs(Function *F, MDNode *KernelMDNode) {
    255     bool Modified = false;
    256 
    257     FunctionType *FT = F->getFunctionType();
    258     SmallVector<Type *, 8> ArgTypes;
    259 
    260     // Metadata operands for new MDNode.
    261     KernelArgMD NewArgMDs;
    262     PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
    263 
    264     // Add implicit arguments to the signature.
    265     for (unsigned i = 0; i < FT->getNumParams(); ++i) {
    266       ArgTypes.push_back(FT->getParamType(i));
    267       MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
    268       PushArgMD(NewArgMDs, ArgMD);
    269 
    270       if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
    271         continue;
    272 
    273       // Add size implicit argument.
    274       ArgTypes.push_back(ImageSizeType);
    275       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
    276       PushArgMD(NewArgMDs, ArgMD);
    277 
    278       // Add format implicit argument.
    279       ArgTypes.push_back(ImageFormatType);
    280       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
    281       PushArgMD(NewArgMDs, ArgMD);
    282 
    283       Modified = true;
    284     }
    285     if (!Modified) {
    286       return std::make_tuple(nullptr, nullptr);
    287     }
    288 
    289     // Create function with new signature and clone the old body into it.
    290     auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
    291     auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
    292     ValueToValueMapTy VMap;
    293     auto NewFArgIt = NewF->arg_begin();
    294     for (auto &Arg: F->args()) {
    295       auto ArgName = Arg.getName();
    296       NewFArgIt->setName(ArgName);
    297       VMap[&Arg] = &(*NewFArgIt++);
    298       if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
    299         (NewFArgIt++)->setName(Twine("__size_") + ArgName);
    300         (NewFArgIt++)->setName(Twine("__format_") + ArgName);
    301       }
    302     }
    303     SmallVector<ReturnInst*, 8> Returns;
    304     CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
    305                       Returns);
    306 
    307     // Build new MDNode.
    308     SmallVector<Metadata *, 6> KernelMDArgs;
    309     KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
    310     for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
    311       KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
    312     MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
    313 
    314     return std::make_tuple(NewF, NewMDNode);
    315   }
    316 
    317   bool transformKernels(Module &M) {
    318     NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
    319     if (!KernelsMDNode)
    320       return false;
    321 
    322     bool Modified = false;
    323     for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
    324       MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
    325       Function *F = GetFunctionFromMDNode(KernelMDNode);
    326       if (!F)
    327         continue;
    328 
    329       Function *NewF;
    330       MDNode *NewMDNode;
    331       std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
    332       if (NewF) {
    333         // Replace old function and metadata with new ones.
    334         F->eraseFromParent();
    335         M.getFunctionList().push_back(NewF);
    336         M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
    337                               NewF->getAttributes());
    338         KernelsMDNode->setOperand(i, NewMDNode);
    339 
    340         F = NewF;
    341         KernelMDNode = NewMDNode;
    342         Modified = true;
    343       }
    344 
    345       Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
    346     }
    347 
    348     return Modified;
    349   }
    350 
    351 public:
    352   R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
    353 
    354   bool runOnModule(Module &M) override {
    355     Context = &M.getContext();
    356     Int32Type = Type::getInt32Ty(M.getContext());
    357     ImageSizeType = ArrayType::get(Int32Type, 3);
    358     ImageFormatType = ArrayType::get(Int32Type, 2);
    359 
    360     return transformKernels(M);
    361   }
    362 
    363   StringRef getPassName() const override {
    364     return "R600 OpenCL Image Type Pass";
    365   }
    366 };
    367 
    368 } // end anonymous namespace
    369 
    370 char R600OpenCLImageTypeLoweringPass::ID = 0;
    371 
    372 ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() {
    373   return new R600OpenCLImageTypeLoweringPass();
    374 }
    375