Home | History | Annotate | Line # | Download | only in NVPTX
      1 //===- NVVMIntrRange.cpp - Set !range metadata for NVVM intrinsics --------===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // This pass adds appropriate !range metadata for calls to NVVM
     10 // intrinsics that return a limited range of values.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #include "NVPTX.h"
     15 #include "llvm/IR/Constants.h"
     16 #include "llvm/IR/InstIterator.h"
     17 #include "llvm/IR/Instructions.h"
     18 #include "llvm/IR/Intrinsics.h"
     19 #include "llvm/IR/IntrinsicsNVPTX.h"
     20 #include "llvm/IR/PassManager.h"
     21 #include "llvm/Support/CommandLine.h"
     22 
     23 using namespace llvm;
     24 
     25 #define DEBUG_TYPE "nvvm-intr-range"
     26 
     27 namespace llvm { void initializeNVVMIntrRangePass(PassRegistry &); }
     28 
     29 // Add !range metadata based on limits of given SM variant.
     30 static cl::opt<unsigned> NVVMIntrRangeSM("nvvm-intr-range-sm", cl::init(20),
     31                                          cl::Hidden, cl::desc("SM variant"));
     32 
     33 namespace {
     34 class NVVMIntrRange : public FunctionPass {
     35  private:
     36    unsigned SmVersion;
     37 
     38  public:
     39    static char ID;
     40    NVVMIntrRange() : NVVMIntrRange(NVVMIntrRangeSM) {}
     41    NVVMIntrRange(unsigned int SmVersion)
     42        : FunctionPass(ID), SmVersion(SmVersion) {
     43 
     44      initializeNVVMIntrRangePass(*PassRegistry::getPassRegistry());
     45    }
     46 
     47    bool runOnFunction(Function &) override;
     48 };
     49 }
     50 
     51 FunctionPass *llvm::createNVVMIntrRangePass(unsigned int SmVersion) {
     52   return new NVVMIntrRange(SmVersion);
     53 }
     54 
     55 char NVVMIntrRange::ID = 0;
     56 INITIALIZE_PASS(NVVMIntrRange, "nvvm-intr-range",
     57                 "Add !range metadata to NVVM intrinsics.", false, false)
     58 
     59 // Adds the passed-in [Low,High) range information as metadata to the
     60 // passed-in call instruction.
     61 static bool addRangeMetadata(uint64_t Low, uint64_t High, CallInst *C) {
     62   // This call already has range metadata, nothing to do.
     63   if (C->getMetadata(LLVMContext::MD_range))
     64     return false;
     65 
     66   LLVMContext &Context = C->getParent()->getContext();
     67   IntegerType *Int32Ty = Type::getInt32Ty(Context);
     68   Metadata *LowAndHigh[] = {
     69       ConstantAsMetadata::get(ConstantInt::get(Int32Ty, Low)),
     70       ConstantAsMetadata::get(ConstantInt::get(Int32Ty, High))};
     71   C->setMetadata(LLVMContext::MD_range, MDNode::get(Context, LowAndHigh));
     72   return true;
     73 }
     74 
     75 static bool runNVVMIntrRange(Function &F, unsigned SmVersion) {
     76   struct {
     77     unsigned x, y, z;
     78   } MaxBlockSize, MaxGridSize;
     79   MaxBlockSize.x = 1024;
     80   MaxBlockSize.y = 1024;
     81   MaxBlockSize.z = 64;
     82 
     83   MaxGridSize.x = SmVersion >= 30 ? 0x7fffffff : 0xffff;
     84   MaxGridSize.y = 0xffff;
     85   MaxGridSize.z = 0xffff;
     86 
     87   // Go through the calls in this function.
     88   bool Changed = false;
     89   for (Instruction &I : instructions(F)) {
     90     CallInst *Call = dyn_cast<CallInst>(&I);
     91     if (!Call)
     92       continue;
     93 
     94     if (Function *Callee = Call->getCalledFunction()) {
     95       switch (Callee->getIntrinsicID()) {
     96       // Index within block
     97       case Intrinsic::nvvm_read_ptx_sreg_tid_x:
     98         Changed |= addRangeMetadata(0, MaxBlockSize.x, Call);
     99         break;
    100       case Intrinsic::nvvm_read_ptx_sreg_tid_y:
    101         Changed |= addRangeMetadata(0, MaxBlockSize.y, Call);
    102         break;
    103       case Intrinsic::nvvm_read_ptx_sreg_tid_z:
    104         Changed |= addRangeMetadata(0, MaxBlockSize.z, Call);
    105         break;
    106 
    107       // Block size
    108       case Intrinsic::nvvm_read_ptx_sreg_ntid_x:
    109         Changed |= addRangeMetadata(1, MaxBlockSize.x+1, Call);
    110         break;
    111       case Intrinsic::nvvm_read_ptx_sreg_ntid_y:
    112         Changed |= addRangeMetadata(1, MaxBlockSize.y+1, Call);
    113         break;
    114       case Intrinsic::nvvm_read_ptx_sreg_ntid_z:
    115         Changed |= addRangeMetadata(1, MaxBlockSize.z+1, Call);
    116         break;
    117 
    118       // Index within grid
    119       case Intrinsic::nvvm_read_ptx_sreg_ctaid_x:
    120         Changed |= addRangeMetadata(0, MaxGridSize.x, Call);
    121         break;
    122       case Intrinsic::nvvm_read_ptx_sreg_ctaid_y:
    123         Changed |= addRangeMetadata(0, MaxGridSize.y, Call);
    124         break;
    125       case Intrinsic::nvvm_read_ptx_sreg_ctaid_z:
    126         Changed |= addRangeMetadata(0, MaxGridSize.z, Call);
    127         break;
    128 
    129       // Grid size
    130       case Intrinsic::nvvm_read_ptx_sreg_nctaid_x:
    131         Changed |= addRangeMetadata(1, MaxGridSize.x+1, Call);
    132         break;
    133       case Intrinsic::nvvm_read_ptx_sreg_nctaid_y:
    134         Changed |= addRangeMetadata(1, MaxGridSize.y+1, Call);
    135         break;
    136       case Intrinsic::nvvm_read_ptx_sreg_nctaid_z:
    137         Changed |= addRangeMetadata(1, MaxGridSize.z+1, Call);
    138         break;
    139 
    140       // warp size is constant 32.
    141       case Intrinsic::nvvm_read_ptx_sreg_warpsize:
    142         Changed |= addRangeMetadata(32, 32+1, Call);
    143         break;
    144 
    145       // Lane ID is [0..warpsize)
    146       case Intrinsic::nvvm_read_ptx_sreg_laneid:
    147         Changed |= addRangeMetadata(0, 32, Call);
    148         break;
    149 
    150       default:
    151         break;
    152       }
    153     }
    154   }
    155 
    156   return Changed;
    157 }
    158 
    159 bool NVVMIntrRange::runOnFunction(Function &F) {
    160   return runNVVMIntrRange(F, SmVersion);
    161 }
    162 
    163 NVVMIntrRangePass::NVVMIntrRangePass() : NVVMIntrRangePass(NVVMIntrRangeSM) {}
    164 
    165 PreservedAnalyses NVVMIntrRangePass::run(Function &F,
    166                                          FunctionAnalysisManager &AM) {
    167   return runNVVMIntrRange(F, SmVersion) ? PreservedAnalyses::none()
    168                                         : PreservedAnalyses::all();
    169 }
    170