Home | History | Annotate | Line # | Download | only in Utils
      1 //===- AMDGPUEmitPrintf.cpp -----------------------------------------------===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // Utility function to lower a printf call into a series of device
     10 // library calls on the AMDGPU target.
     11 //
     12 // WARNING: This file knows about certain library functions. It recognizes them
     13 // by name, and hardwires knowledge of their semantics.
     14 //
     15 //===----------------------------------------------------------------------===//
     16 
     17 #include "llvm/Transforms/Utils/AMDGPUEmitPrintf.h"
     18 #include "llvm/ADT/SparseBitVector.h"
     19 #include "llvm/Analysis/ValueTracking.h"
     20 
     21 using namespace llvm;
     22 
     23 #define DEBUG_TYPE "amdgpu-emit-printf"
     24 
     25 static bool isCString(const Value *Arg) {
     26   auto Ty = Arg->getType();
     27   auto PtrTy = dyn_cast<PointerType>(Ty);
     28   if (!PtrTy)
     29     return false;
     30 
     31   auto IntTy = dyn_cast<IntegerType>(PtrTy->getElementType());
     32   if (!IntTy)
     33     return false;
     34 
     35   return IntTy->getBitWidth() == 8;
     36 }
     37 
     38 static Value *fitArgInto64Bits(IRBuilder<> &Builder, Value *Arg) {
     39   auto Int64Ty = Builder.getInt64Ty();
     40   auto Ty = Arg->getType();
     41 
     42   if (auto IntTy = dyn_cast<IntegerType>(Ty)) {
     43     switch (IntTy->getBitWidth()) {
     44     case 32:
     45       return Builder.CreateZExt(Arg, Int64Ty);
     46     case 64:
     47       return Arg;
     48     }
     49   }
     50 
     51   if (Ty->getTypeID() == Type::DoubleTyID) {
     52     return Builder.CreateBitCast(Arg, Int64Ty);
     53   }
     54 
     55   if (isa<PointerType>(Ty)) {
     56     return Builder.CreatePtrToInt(Arg, Int64Ty);
     57   }
     58 
     59   llvm_unreachable("unexpected type");
     60 }
     61 
     62 static Value *callPrintfBegin(IRBuilder<> &Builder, Value *Version) {
     63   auto Int64Ty = Builder.getInt64Ty();
     64   auto M = Builder.GetInsertBlock()->getModule();
     65   auto Fn = M->getOrInsertFunction("__ockl_printf_begin", Int64Ty, Int64Ty);
     66   return Builder.CreateCall(Fn, Version);
     67 }
     68 
     69 static Value *callAppendArgs(IRBuilder<> &Builder, Value *Desc, int NumArgs,
     70                              Value *Arg0, Value *Arg1, Value *Arg2, Value *Arg3,
     71                              Value *Arg4, Value *Arg5, Value *Arg6,
     72                              bool IsLast) {
     73   auto Int64Ty = Builder.getInt64Ty();
     74   auto Int32Ty = Builder.getInt32Ty();
     75   auto M = Builder.GetInsertBlock()->getModule();
     76   auto Fn = M->getOrInsertFunction("__ockl_printf_append_args", Int64Ty,
     77                                    Int64Ty, Int32Ty, Int64Ty, Int64Ty, Int64Ty,
     78                                    Int64Ty, Int64Ty, Int64Ty, Int64Ty, Int32Ty);
     79   auto IsLastValue = Builder.getInt32(IsLast);
     80   auto NumArgsValue = Builder.getInt32(NumArgs);
     81   return Builder.CreateCall(Fn, {Desc, NumArgsValue, Arg0, Arg1, Arg2, Arg3,
     82                                  Arg4, Arg5, Arg6, IsLastValue});
     83 }
     84 
     85 static Value *appendArg(IRBuilder<> &Builder, Value *Desc, Value *Arg,
     86                         bool IsLast) {
     87   auto Arg0 = fitArgInto64Bits(Builder, Arg);
     88   auto Zero = Builder.getInt64(0);
     89   return callAppendArgs(Builder, Desc, 1, Arg0, Zero, Zero, Zero, Zero, Zero,
     90                         Zero, IsLast);
     91 }
     92 
     93 // The device library does not provide strlen, so we build our own loop
     94 // here. While we are at it, we also include the terminating null in the length.
     95 static Value *getStrlenWithNull(IRBuilder<> &Builder, Value *Str) {
     96   auto *Prev = Builder.GetInsertBlock();
     97   Module *M = Prev->getModule();
     98 
     99   auto CharZero = Builder.getInt8(0);
    100   auto One = Builder.getInt64(1);
    101   auto Zero = Builder.getInt64(0);
    102   auto Int64Ty = Builder.getInt64Ty();
    103 
    104   // The length is either zero for a null pointer, or the computed value for an
    105   // actual string. We need a join block for a phi that represents the final
    106   // value.
    107   //
    108   //  Strictly speaking, the zero does not matter since
    109   // __ockl_printf_append_string_n ignores the length if the pointer is null.
    110   BasicBlock *Join = nullptr;
    111   if (Prev->getTerminator()) {
    112     Join = Prev->splitBasicBlock(Builder.GetInsertPoint(),
    113                                  "strlen.join");
    114     Prev->getTerminator()->eraseFromParent();
    115   } else {
    116     Join = BasicBlock::Create(M->getContext(), "strlen.join",
    117                               Prev->getParent());
    118   }
    119   BasicBlock *While =
    120       BasicBlock::Create(M->getContext(), "strlen.while",
    121                          Prev->getParent(), Join);
    122   BasicBlock *WhileDone = BasicBlock::Create(
    123       M->getContext(), "strlen.while.done",
    124       Prev->getParent(), Join);
    125 
    126   // Emit an early return for when the pointer is null.
    127   Builder.SetInsertPoint(Prev);
    128   auto CmpNull =
    129       Builder.CreateICmpEQ(Str, Constant::getNullValue(Str->getType()));
    130   BranchInst::Create(Join, While, CmpNull, Prev);
    131 
    132   // Entry to the while loop.
    133   Builder.SetInsertPoint(While);
    134 
    135   auto PtrPhi = Builder.CreatePHI(Str->getType(), 2);
    136   PtrPhi->addIncoming(Str, Prev);
    137   auto PtrNext = Builder.CreateGEP(PtrPhi, One);
    138   PtrPhi->addIncoming(PtrNext, While);
    139 
    140   // Condition for the while loop.
    141   auto Data = Builder.CreateLoad(Builder.getInt8Ty(), PtrPhi);
    142   auto Cmp = Builder.CreateICmpEQ(Data, CharZero);
    143   Builder.CreateCondBr(Cmp, WhileDone, While);
    144 
    145   // Add one to the computed length.
    146   Builder.SetInsertPoint(WhileDone, WhileDone->begin());
    147   auto Begin = Builder.CreatePtrToInt(Str, Int64Ty);
    148   auto End = Builder.CreatePtrToInt(PtrPhi, Int64Ty);
    149   auto Len = Builder.CreateSub(End, Begin);
    150   Len = Builder.CreateAdd(Len, One);
    151 
    152   // Final join.
    153   BranchInst::Create(Join, WhileDone);
    154   Builder.SetInsertPoint(Join, Join->begin());
    155   auto LenPhi = Builder.CreatePHI(Len->getType(), 2);
    156   LenPhi->addIncoming(Len, WhileDone);
    157   LenPhi->addIncoming(Zero, Prev);
    158 
    159   return LenPhi;
    160 }
    161 
    162 static Value *callAppendStringN(IRBuilder<> &Builder, Value *Desc, Value *Str,
    163                                 Value *Length, bool isLast) {
    164   auto Int64Ty = Builder.getInt64Ty();
    165   auto CharPtrTy = Builder.getInt8PtrTy();
    166   auto Int32Ty = Builder.getInt32Ty();
    167   auto M = Builder.GetInsertBlock()->getModule();
    168   auto Fn = M->getOrInsertFunction("__ockl_printf_append_string_n", Int64Ty,
    169                                    Int64Ty, CharPtrTy, Int64Ty, Int32Ty);
    170   auto IsLastInt32 = Builder.getInt32(isLast);
    171   return Builder.CreateCall(Fn, {Desc, Str, Length, IsLastInt32});
    172 }
    173 
    174 static Value *appendString(IRBuilder<> &Builder, Value *Desc, Value *Arg,
    175                            bool IsLast) {
    176   auto Length = getStrlenWithNull(Builder, Arg);
    177   return callAppendStringN(Builder, Desc, Arg, Length, IsLast);
    178 }
    179 
    180 static Value *processArg(IRBuilder<> &Builder, Value *Desc, Value *Arg,
    181                          bool SpecIsCString, bool IsLast) {
    182   if (SpecIsCString && isCString(Arg)) {
    183     return appendString(Builder, Desc, Arg, IsLast);
    184   }
    185   // If the format specifies a string but the argument is not, the frontend will
    186   // have printed a warning. We just rely on undefined behaviour and send the
    187   // argument anyway.
    188   return appendArg(Builder, Desc, Arg, IsLast);
    189 }
    190 
    191 // Scan the format string to locate all specifiers, and mark the ones that
    192 // specify a string, i.e, the "%s" specifier with optional '*' characters.
    193 static void locateCStrings(SparseBitVector<8> &BV, Value *Fmt) {
    194   StringRef Str;
    195   if (!getConstantStringInfo(Fmt, Str) || Str.empty())
    196     return;
    197 
    198   static const char ConvSpecifiers[] = "diouxXfFeEgGaAcspn";
    199   size_t SpecPos = 0;
    200   // Skip the first argument, the format string.
    201   unsigned ArgIdx = 1;
    202 
    203   while ((SpecPos = Str.find_first_of('%', SpecPos)) != StringRef::npos) {
    204     if (Str[SpecPos + 1] == '%') {
    205       SpecPos += 2;
    206       continue;
    207     }
    208     auto SpecEnd = Str.find_first_of(ConvSpecifiers, SpecPos);
    209     if (SpecEnd == StringRef::npos)
    210       return;
    211     auto Spec = Str.slice(SpecPos, SpecEnd + 1);
    212     ArgIdx += Spec.count('*');
    213     if (Str[SpecEnd] == 's') {
    214       BV.set(ArgIdx);
    215     }
    216     SpecPos = SpecEnd + 1;
    217     ++ArgIdx;
    218   }
    219 }
    220 
    221 Value *llvm::emitAMDGPUPrintfCall(IRBuilder<> &Builder,
    222                                   ArrayRef<Value *> Args) {
    223   auto NumOps = Args.size();
    224   assert(NumOps >= 1);
    225 
    226   auto Fmt = Args[0];
    227   SparseBitVector<8> SpecIsCString;
    228   locateCStrings(SpecIsCString, Fmt);
    229 
    230   auto Desc = callPrintfBegin(Builder, Builder.getIntN(64, 0));
    231   Desc = appendString(Builder, Desc, Fmt, NumOps == 1);
    232 
    233   // FIXME: This invokes hostcall once for each argument. We can pack up to
    234   // seven scalar printf arguments in a single hostcall. See the signature of
    235   // callAppendArgs().
    236   for (unsigned int i = 1; i != NumOps; ++i) {
    237     bool IsLast = i == NumOps - 1;
    238     bool IsCString = SpecIsCString.test(i);
    239     Desc = processArg(Builder, Desc, Args[i], IsCString, IsLast);
    240   }
    241 
    242   return Builder.CreateTrunc(Desc, Builder.getInt32Ty());
    243 }
    244