1/****************************************************************************
2 * Copyright (C) 2014-2015 Intel Corporation.   All Rights Reserved.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * @file builder_misc.cpp
24 *
25 * @brief Implementation for miscellaneous builder functions
26 *
27 * Notes:
28 *
29 ******************************************************************************/
#include "jit_pch.hpp"
#include "builder.h"
#include "common/rdtsc_buckets.h"

#include <cstdarg>
#include <cstring>
35
36extern "C" void CallPrint(const char* fmt, ...);
37
38namespace SwrJit
39{
40    //////////////////////////////////////////////////////////////////////////
41    /// @brief Convert an IEEE 754 32-bit single precision float to an
42    ///        16 bit float with 5 exponent bits and a variable
43    ///        number of mantissa bits.
44    /// @param val - 32-bit float
45    /// @todo Maybe move this outside of this file into a header?
46    static uint16_t ConvertFloat32ToFloat16(float val)
47    {
48        uint32_t sign, exp, mant;
49        uint32_t roundBits;
50
51        // Extract the sign, exponent, and mantissa
52        uint32_t uf = *(uint32_t*)&val;
53        sign        = (uf & 0x80000000) >> 31;
54        exp         = (uf & 0x7F800000) >> 23;
55        mant        = uf & 0x007FFFFF;
56
57        // Check for out of range
58        if (std::isnan(val))
59        {
60            exp  = 0x1F;
61            mant = 0x200;
62            sign = 1; // set the sign bit for NANs
63        }
64        else if (std::isinf(val))
65        {
66            exp  = 0x1f;
67            mant = 0x0;
68        }
69        else if (exp > (0x70 + 0x1E)) // Too big to represent -> max representable value
70        {
71            exp  = 0x1E;
72            mant = 0x3FF;
73        }
74        else if ((exp <= 0x70) && (exp >= 0x66)) // It's a denorm
75        {
76            mant |= 0x00800000;
77            for (; exp <= 0x70; mant >>= 1, exp++)
78                ;
79            exp  = 0;
80            mant = mant >> 13;
81        }
82        else if (exp < 0x66) // Too small to represent -> Zero
83        {
84            exp  = 0;
85            mant = 0;
86        }
87        else
88        {
89            // Saves bits that will be shifted off for rounding
90            roundBits = mant & 0x1FFFu;
91            // convert exponent and mantissa to 16 bit format
92            exp  = exp - 0x70;
93            mant = mant >> 13;
94
95            // Essentially RTZ, but round up if off by only 1 lsb
96            if (roundBits == 0x1FFFu)
97            {
98                mant++;
99                // check for overflow
100                if ((mant & 0xC00u) != 0)
101                    exp++;
102                // make sure only the needed bits are used
103                mant &= 0x3FF;
104            }
105        }
106
107        uint32_t tmpVal = (sign << 15) | (exp << 10) | mant;
108        return (uint16_t)tmpVal;
109    }
110
111    //////////////////////////////////////////////////////////////////////////
112    /// @brief Convert an IEEE 754 16-bit float to an 32-bit single precision
113    ///        float
114    /// @param val - 16-bit float
115    /// @todo Maybe move this outside of this file into a header?
116    static float ConvertFloat16ToFloat32(uint32_t val)
117    {
118        uint32_t result;
119        if ((val & 0x7fff) == 0)
120        {
121            result = ((uint32_t)(val & 0x8000)) << 16;
122        }
123        else if ((val & 0x7c00) == 0x7c00)
124        {
125            result = ((val & 0x3ff) == 0) ? 0x7f800000 : 0x7fc00000;
126            result |= ((uint32_t)val & 0x8000) << 16;
127        }
128        else
129        {
130            uint32_t sign = (val & 0x8000) << 16;
131            uint32_t mant = (val & 0x3ff) << 13;
132            uint32_t exp  = (val >> 10) & 0x1f;
133            if ((exp == 0) && (mant != 0)) // Adjust exponent and mantissa for denormals
134            {
135                mant <<= 1;
136                while (mant < (0x400 << 13))
137                {
138                    exp--;
139                    mant <<= 1;
140                }
141                mant &= (0x3ff << 13);
142            }
143            exp    = ((exp - 15 + 127) & 0xff) << 23;
144            result = sign | exp | mant;
145        }
146
147        return *(float*)&result;
148    }
149
150    Constant* Builder::C(bool i) { return ConstantInt::get(IRB()->getInt1Ty(), (i ? 1 : 0)); }
151
152    Constant* Builder::C(char i) { return ConstantInt::get(IRB()->getInt8Ty(), i); }
153
154    Constant* Builder::C(uint8_t i) { return ConstantInt::get(IRB()->getInt8Ty(), i); }
155
156    Constant* Builder::C(int i) { return ConstantInt::get(IRB()->getInt32Ty(), i); }
157
158    Constant* Builder::C(int64_t i) { return ConstantInt::get(IRB()->getInt64Ty(), i); }
159
160    Constant* Builder::C(uint16_t i) { return ConstantInt::get(mInt16Ty, i); }
161
162    Constant* Builder::C(uint32_t i) { return ConstantInt::get(IRB()->getInt32Ty(), i); }
163
164    Constant* Builder::C(uint64_t i) { return ConstantInt::get(IRB()->getInt64Ty(), i); }
165
166    Constant* Builder::C(float i) { return ConstantFP::get(IRB()->getFloatTy(), i); }
167
168    Constant* Builder::PRED(bool pred)
169    {
170        return ConstantInt::get(IRB()->getInt1Ty(), (pred ? 1 : 0));
171    }
172
173    Value* Builder::VIMMED1(int i)
174    {
175        return ConstantVector::getSplat(mVWidth, cast<ConstantInt>(C(i)));
176    }
177
178    Value* Builder::VIMMED1_16(int i)
179    {
180        return ConstantVector::getSplat(mVWidth16, cast<ConstantInt>(C(i)));
181    }
182
183    Value* Builder::VIMMED1(uint32_t i)
184    {
185        return ConstantVector::getSplat(mVWidth, cast<ConstantInt>(C(i)));
186    }
187
188    Value* Builder::VIMMED1_16(uint32_t i)
189    {
190        return ConstantVector::getSplat(mVWidth16, cast<ConstantInt>(C(i)));
191    }
192
193    Value* Builder::VIMMED1(float i)
194    {
195        return ConstantVector::getSplat(mVWidth, cast<ConstantFP>(C(i)));
196    }
197
198    Value* Builder::VIMMED1_16(float i)
199    {
200        return ConstantVector::getSplat(mVWidth16, cast<ConstantFP>(C(i)));
201    }
202
203    Value* Builder::VIMMED1(bool i)
204    {
205        return ConstantVector::getSplat(mVWidth, cast<ConstantInt>(C(i)));
206    }
207
208    Value* Builder::VIMMED1_16(bool i)
209    {
210        return ConstantVector::getSplat(mVWidth16, cast<ConstantInt>(C(i)));
211    }
212
213    Value* Builder::VUNDEF_IPTR() { return UndefValue::get(VectorType::get(mInt32PtrTy, mVWidth)); }
214
215    Value* Builder::VUNDEF(Type* t) { return UndefValue::get(VectorType::get(t, mVWidth)); }
216
217    Value* Builder::VUNDEF_I() { return UndefValue::get(VectorType::get(mInt32Ty, mVWidth)); }
218
219    Value* Builder::VUNDEF_I_16() { return UndefValue::get(VectorType::get(mInt32Ty, mVWidth16)); }
220
221    Value* Builder::VUNDEF_F() { return UndefValue::get(VectorType::get(mFP32Ty, mVWidth)); }
222
223    Value* Builder::VUNDEF_F_16() { return UndefValue::get(VectorType::get(mFP32Ty, mVWidth16)); }
224
225    Value* Builder::VUNDEF(Type* ty, uint32_t size)
226    {
227        return UndefValue::get(VectorType::get(ty, size));
228    }
229
230    Value* Builder::VBROADCAST(Value* src, const llvm::Twine& name)
231    {
232        // check if src is already a vector
233        if (src->getType()->isVectorTy())
234        {
235            return src;
236        }
237
238        return VECTOR_SPLAT(mVWidth, src, name);
239    }
240
241    Value* Builder::VBROADCAST_16(Value* src)
242    {
243        // check if src is already a vector
244        if (src->getType()->isVectorTy())
245        {
246            return src;
247        }
248
249        return VECTOR_SPLAT(mVWidth16, src);
250    }
251
252    uint32_t Builder::IMMED(Value* v)
253    {
254        SWR_ASSERT(isa<ConstantInt>(v));
255        ConstantInt* pValConst = cast<ConstantInt>(v);
256        return pValConst->getZExtValue();
257    }
258
259    int32_t Builder::S_IMMED(Value* v)
260    {
261        SWR_ASSERT(isa<ConstantInt>(v));
262        ConstantInt* pValConst = cast<ConstantInt>(v);
263        return pValConst->getSExtValue();
264    }
265
266    CallInst* Builder::CALL(Value*                               Callee,
267                            const std::initializer_list<Value*>& argsList,
268                            const llvm::Twine&                   name)
269    {
270        std::vector<Value*> args;
271        for (auto arg : argsList)
272            args.push_back(arg);
273        return CALLA(Callee, args, name);
274    }
275
276    CallInst* Builder::CALL(Value* Callee, Value* arg)
277    {
278        std::vector<Value*> args;
279        args.push_back(arg);
280        return CALLA(Callee, args);
281    }
282
283    CallInst* Builder::CALL2(Value* Callee, Value* arg1, Value* arg2)
284    {
285        std::vector<Value*> args;
286        args.push_back(arg1);
287        args.push_back(arg2);
288        return CALLA(Callee, args);
289    }
290
291    CallInst* Builder::CALL3(Value* Callee, Value* arg1, Value* arg2, Value* arg3)
292    {
293        std::vector<Value*> args;
294        args.push_back(arg1);
295        args.push_back(arg2);
296        args.push_back(arg3);
297        return CALLA(Callee, args);
298    }
299
300    Value* Builder::VRCP(Value* va, const llvm::Twine& name)
301    {
302        return FDIV(VIMMED1(1.0f), va, name); // 1 / a
303    }
304
305    Value* Builder::VPLANEPS(Value* vA, Value* vB, Value* vC, Value*& vX, Value*& vY)
306    {
307        Value* vOut = FMADDPS(vA, vX, vC);
308        vOut        = FMADDPS(vB, vY, vOut);
309        return vOut;
310    }
311
312    //////////////////////////////////////////////////////////////////////////
313    /// @brief insert a JIT call to CallPrint
314    /// - outputs formatted string to both stdout and VS output window
315    /// - DEBUG builds only
316    /// Usage example:
317    ///   PRINT("index %d = 0x%p\n",{C(lane), pIndex});
318    ///   where C(lane) creates a constant value to print, and pIndex is the Value*
319    ///   result from a GEP, printing out the pointer to memory
320    /// @param printStr - constant string to print, which includes format specifiers
321    /// @param printArgs - initializer list of Value*'s to print to std out
322    CallInst* Builder::PRINT(const std::string&                   printStr,
323                             const std::initializer_list<Value*>& printArgs)
324    {
325        // push the arguments to CallPrint into a vector
326        std::vector<Value*> printCallArgs;
327        // save room for the format string.  we still need to modify it for vectors
328        printCallArgs.resize(1);
329
330        // search through the format string for special processing
331        size_t      pos = 0;
332        std::string tempStr(printStr);
333        pos    = tempStr.find('%', pos);
334        auto v = printArgs.begin();
335
336        while ((pos != std::string::npos) && (v != printArgs.end()))
337        {
338            Value* pArg  = *v;
339            Type*  pType = pArg->getType();
340
341            if (pType->isVectorTy())
342            {
343                Type* pContainedType = pType->getContainedType(0);
344
345                if (toupper(tempStr[pos + 1]) == 'X')
346                {
347                    tempStr[pos]     = '0';
348                    tempStr[pos + 1] = 'x';
349                    tempStr.insert(pos + 2, "%08X ");
350                    pos += 7;
351
352                    printCallArgs.push_back(VEXTRACT(pArg, C(0)));
353
354                    std::string vectorFormatStr;
355                    for (uint32_t i = 1; i < pType->getVectorNumElements(); ++i)
356                    {
357                        vectorFormatStr += "0x%08X ";
358                        printCallArgs.push_back(VEXTRACT(pArg, C(i)));
359                    }
360
361                    tempStr.insert(pos, vectorFormatStr);
362                    pos += vectorFormatStr.size();
363                }
364                else if ((tempStr[pos + 1] == 'f') && (pContainedType->isFloatTy()))
365                {
366                    uint32_t i = 0;
367                    for (; i < (pArg->getType()->getVectorNumElements()) - 1; i++)
368                    {
369                        tempStr.insert(pos, std::string("%f "));
370                        pos += 3;
371                        printCallArgs.push_back(
372                            FP_EXT(VEXTRACT(pArg, C(i)), Type::getDoubleTy(JM()->mContext)));
373                    }
374                    printCallArgs.push_back(
375                        FP_EXT(VEXTRACT(pArg, C(i)), Type::getDoubleTy(JM()->mContext)));
376                }
377                else if ((tempStr[pos + 1] == 'd') && (pContainedType->isIntegerTy()))
378                {
379                    uint32_t i = 0;
380                    for (; i < (pArg->getType()->getVectorNumElements()) - 1; i++)
381                    {
382                        tempStr.insert(pos, std::string("%d "));
383                        pos += 3;
384                        printCallArgs.push_back(
385                            S_EXT(VEXTRACT(pArg, C(i)), Type::getInt32Ty(JM()->mContext)));
386                    }
387                    printCallArgs.push_back(
388                        S_EXT(VEXTRACT(pArg, C(i)), Type::getInt32Ty(JM()->mContext)));
389                }
390                else if ((tempStr[pos + 1] == 'u') && (pContainedType->isIntegerTy()))
391                {
392                    uint32_t i = 0;
393                    for (; i < (pArg->getType()->getVectorNumElements()) - 1; i++)
394                    {
395                        tempStr.insert(pos, std::string("%d "));
396                        pos += 3;
397                        printCallArgs.push_back(
398                            Z_EXT(VEXTRACT(pArg, C(i)), Type::getInt32Ty(JM()->mContext)));
399                    }
400                    printCallArgs.push_back(
401                        Z_EXT(VEXTRACT(pArg, C(i)), Type::getInt32Ty(JM()->mContext)));
402                }
403            }
404            else
405            {
406                if (toupper(tempStr[pos + 1]) == 'X')
407                {
408                    tempStr[pos] = '0';
409                    tempStr.insert(pos + 1, "x%08");
410                    printCallArgs.push_back(pArg);
411                    pos += 3;
412                }
413                // for %f we need to cast float Values to doubles so that they print out correctly
414                else if ((tempStr[pos + 1] == 'f') && (pType->isFloatTy()))
415                {
416                    printCallArgs.push_back(FP_EXT(pArg, Type::getDoubleTy(JM()->mContext)));
417                    pos++;
418                }
419                else
420                {
421                    printCallArgs.push_back(pArg);
422                }
423            }
424
425            // advance to the next arguement
426            v++;
427            pos = tempStr.find('%', ++pos);
428        }
429
430        // create global variable constant string
431        Constant*       constString = ConstantDataArray::getString(JM()->mContext, tempStr, true);
432        GlobalVariable* gvPtr       = new GlobalVariable(
433            constString->getType(), true, GlobalValue::InternalLinkage, constString, "printStr");
434        JM()->mpCurrentModule->getGlobalList().push_back(gvPtr);
435
436        // get a pointer to the first character in the constant string array
437        std::vector<Constant*> geplist{C(0), C(0)};
438        Constant* strGEP = ConstantExpr::getGetElementPtr(nullptr, gvPtr, geplist, false);
439
440        // insert the pointer to the format string in the argument vector
441        printCallArgs[0] = strGEP;
442
443        // get pointer to CallPrint function and insert decl into the module if needed
444        std::vector<Type*> args;
445        args.push_back(PointerType::get(mInt8Ty, 0));
446        FunctionType* callPrintTy = FunctionType::get(Type::getVoidTy(JM()->mContext), args, true);
447        Function*     callPrintFn =
448#if LLVM_VERSION_MAJOR >= 9
449            cast<Function>(JM()->mpCurrentModule->getOrInsertFunction("CallPrint", callPrintTy).getCallee());
450#else
451            cast<Function>(JM()->mpCurrentModule->getOrInsertFunction("CallPrint", callPrintTy));
452#endif
453
454        // if we haven't yet added the symbol to the symbol table
455        if ((sys::DynamicLibrary::SearchForAddressOfSymbol("CallPrint")) == nullptr)
456        {
457            sys::DynamicLibrary::AddSymbol("CallPrint", (void*)&CallPrint);
458        }
459
460        // insert a call to CallPrint
461        return CALLA(callPrintFn, printCallArgs);
462    }
463
464    //////////////////////////////////////////////////////////////////////////
465    /// @brief Wrapper around PRINT with initializer list.
466    CallInst* Builder::PRINT(const std::string& printStr) { return PRINT(printStr, {}); }
467
468    Value* Builder::EXTRACT_16(Value* x, uint32_t imm)
469    {
470        if (imm == 0)
471        {
472            return VSHUFFLE(x, UndefValue::get(x->getType()), {0, 1, 2, 3, 4, 5, 6, 7});
473        }
474        else
475        {
476            return VSHUFFLE(x, UndefValue::get(x->getType()), {8, 9, 10, 11, 12, 13, 14, 15});
477        }
478    }
479
480    Value* Builder::JOIN_16(Value* a, Value* b)
481    {
482        return VSHUFFLE(a, b, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
483    }
484
485    //////////////////////////////////////////////////////////////////////////
486    /// @brief convert x86 <N x float> mask to llvm <N x i1> mask
487    Value* Builder::MASK(Value* vmask)
488    {
489        Value* src = BITCAST(vmask, mSimdInt32Ty);
490        return ICMP_SLT(src, VIMMED1(0));
491    }
492
493    Value* Builder::MASK_16(Value* vmask)
494    {
495        Value* src = BITCAST(vmask, mSimd16Int32Ty);
496        return ICMP_SLT(src, VIMMED1_16(0));
497    }
498
499    //////////////////////////////////////////////////////////////////////////
500    /// @brief convert llvm <N x i1> mask to x86 <N x i32> mask
501    Value* Builder::VMASK(Value* mask) { return S_EXT(mask, mSimdInt32Ty); }
502
503    Value* Builder::VMASK_16(Value* mask) { return S_EXT(mask, mSimd16Int32Ty); }
504
505    /// @brief Convert <Nxi1> llvm mask to integer
506    Value* Builder::VMOVMSK(Value* mask)
507    {
508        SWR_ASSERT(mask->getType()->getVectorElementType() == mInt1Ty);
509        uint32_t numLanes = mask->getType()->getVectorNumElements();
510        Value*   i32Result;
511        if (numLanes == 8)
512        {
513            i32Result = BITCAST(mask, mInt8Ty);
514        }
515        else if (numLanes == 16)
516        {
517            i32Result = BITCAST(mask, mInt16Ty);
518        }
519        else
520        {
521            SWR_ASSERT("Unsupported vector width");
522            i32Result = BITCAST(mask, mInt8Ty);
523        }
524        return Z_EXT(i32Result, mInt32Ty);
525    }
526
527    //////////////////////////////////////////////////////////////////////////
528    /// @brief Generate a VPSHUFB operation in LLVM IR.  If not
529    /// supported on the underlying platform, emulate it
530    /// @param a - 256bit SIMD(32x8bit) of 8bit integer values
531    /// @param b - 256bit SIMD(32x8bit) of 8bit integer mask values
532    /// Byte masks in lower 128 lane of b selects 8 bit values from lower
533    /// 128bits of a, and vice versa for the upper lanes.  If the mask
534    /// value is negative, '0' is inserted.
535    Value* Builder::PSHUFB(Value* a, Value* b)
536    {
537        Value* res;
538        // use avx2 pshufb instruction if available
539        if (JM()->mArch.AVX2())
540        {
541            res = VPSHUFB(a, b);
542        }
543        else
544        {
545            Constant* cB = dyn_cast<Constant>(b);
546            // number of 8 bit elements in b
547            uint32_t numElms = cast<VectorType>(cB->getType())->getNumElements();
548            // output vector
549            Value* vShuf = UndefValue::get(VectorType::get(mInt8Ty, numElms));
550
551            // insert an 8 bit value from the high and low lanes of a per loop iteration
552            numElms /= 2;
553            for (uint32_t i = 0; i < numElms; i++)
554            {
555                ConstantInt* cLow128b  = cast<ConstantInt>(cB->getAggregateElement(i));
556                ConstantInt* cHigh128b = cast<ConstantInt>(cB->getAggregateElement(i + numElms));
557
558                // extract values from constant mask
559                char valLow128bLane  = (char)(cLow128b->getSExtValue());
560                char valHigh128bLane = (char)(cHigh128b->getSExtValue());
561
562                Value* insertValLow128b;
563                Value* insertValHigh128b;
564
565                // if the mask value is negative, insert a '0' in the respective output position
566                // otherwise, lookup the value at mask position (bits 3..0 of the respective mask
567                // byte) in a and insert in output vector
568                insertValLow128b =
569                    (valLow128bLane < 0) ? C((char)0) : VEXTRACT(a, C((valLow128bLane & 0xF)));
570                insertValHigh128b = (valHigh128bLane < 0)
571                                        ? C((char)0)
572                                        : VEXTRACT(a, C((valHigh128bLane & 0xF) + numElms));
573
574                vShuf = VINSERT(vShuf, insertValLow128b, i);
575                vShuf = VINSERT(vShuf, insertValHigh128b, (i + numElms));
576            }
577            res = vShuf;
578        }
579        return res;
580    }
581
582    //////////////////////////////////////////////////////////////////////////
583    /// @brief Generate a VPSHUFB operation (sign extend 8 8bit values to 32
584    /// bits)in LLVM IR.  If not supported on the underlying platform, emulate it
585    /// @param a - 128bit SIMD lane(16x8bit) of 8bit integer values.  Only
586    /// lower 8 values are used.
587    Value* Builder::PMOVSXBD(Value* a)
588    {
589        // VPMOVSXBD output type
590        Type* v8x32Ty = VectorType::get(mInt32Ty, 8);
591        // Extract 8 values from 128bit lane and sign extend
592        return S_EXT(VSHUFFLE(a, a, C<int>({0, 1, 2, 3, 4, 5, 6, 7})), v8x32Ty);
593    }
594
595    //////////////////////////////////////////////////////////////////////////
596    /// @brief Generate a VPSHUFB operation (sign extend 8 16bit values to 32
597    /// bits)in LLVM IR.  If not supported on the underlying platform, emulate it
598    /// @param a - 128bit SIMD lane(8x16bit) of 16bit integer values.
599    Value* Builder::PMOVSXWD(Value* a)
600    {
601        // VPMOVSXWD output type
602        Type* v8x32Ty = VectorType::get(mInt32Ty, 8);
603        // Extract 8 values from 128bit lane and sign extend
604        return S_EXT(VSHUFFLE(a, a, C<int>({0, 1, 2, 3, 4, 5, 6, 7})), v8x32Ty);
605    }
606
607    //////////////////////////////////////////////////////////////////////////
608    /// @brief Generate a VCVTPH2PS operation (float16->float32 conversion)
609    /// in LLVM IR.  If not supported on the underlying platform, emulate it
610    /// @param a - 128bit SIMD lane(8x16bit) of float16 in int16 format.
611    Value* Builder::CVTPH2PS(Value* a, const llvm::Twine& name)
612    {
613        if (JM()->mArch.F16C())
614        {
615            return VCVTPH2PS(a, name);
616        }
617        else
618        {
619            FunctionType* pFuncTy   = FunctionType::get(mFP32Ty, mInt16Ty);
620            Function*     pCvtPh2Ps = cast<Function>(
621#if LLVM_VERSION_MAJOR >= 9
622                JM()->mpCurrentModule->getOrInsertFunction("ConvertFloat16ToFloat32", pFuncTy).getCallee());
623#else
624                JM()->mpCurrentModule->getOrInsertFunction("ConvertFloat16ToFloat32", pFuncTy));
625#endif
626
627            if (sys::DynamicLibrary::SearchForAddressOfSymbol("ConvertFloat16ToFloat32") == nullptr)
628            {
629                sys::DynamicLibrary::AddSymbol("ConvertFloat16ToFloat32",
630                                               (void*)&ConvertFloat16ToFloat32);
631            }
632
633            Value* pResult = UndefValue::get(mSimdFP32Ty);
634            for (uint32_t i = 0; i < mVWidth; ++i)
635            {
636                Value* pSrc  = VEXTRACT(a, C(i));
637                Value* pConv = CALL(pCvtPh2Ps, std::initializer_list<Value*>{pSrc});
638                pResult      = VINSERT(pResult, pConv, C(i));
639            }
640
641            pResult->setName(name);
642            return pResult;
643        }
644    }
645
646    //////////////////////////////////////////////////////////////////////////
647    /// @brief Generate a VCVTPS2PH operation (float32->float16 conversion)
648    /// in LLVM IR.  If not supported on the underlying platform, emulate it
649    /// @param a - 128bit SIMD lane(8x16bit) of float16 in int16 format.
650    Value* Builder::CVTPS2PH(Value* a, Value* rounding)
651    {
652        if (JM()->mArch.F16C())
653        {
654            return VCVTPS2PH(a, rounding);
655        }
656        else
657        {
658            // call scalar C function for now
659            FunctionType* pFuncTy   = FunctionType::get(mInt16Ty, mFP32Ty);
660            Function*     pCvtPs2Ph = cast<Function>(
661#if LLVM_VERSION_MAJOR >= 9
662                JM()->mpCurrentModule->getOrInsertFunction("ConvertFloat32ToFloat16", pFuncTy).getCallee());
663#else
664                JM()->mpCurrentModule->getOrInsertFunction("ConvertFloat32ToFloat16", pFuncTy));
665#endif
666
667            if (sys::DynamicLibrary::SearchForAddressOfSymbol("ConvertFloat32ToFloat16") == nullptr)
668            {
669                sys::DynamicLibrary::AddSymbol("ConvertFloat32ToFloat16",
670                                               (void*)&ConvertFloat32ToFloat16);
671            }
672
673            Value* pResult = UndefValue::get(mSimdInt16Ty);
674            for (uint32_t i = 0; i < mVWidth; ++i)
675            {
676                Value* pSrc  = VEXTRACT(a, C(i));
677                Value* pConv = CALL(pCvtPs2Ph, std::initializer_list<Value*>{pSrc});
678                pResult      = VINSERT(pResult, pConv, C(i));
679            }
680
681            return pResult;
682        }
683    }
684
685    Value* Builder::PMAXSD(Value* a, Value* b)
686    {
687        Value* cmp = ICMP_SGT(a, b);
688        return SELECT(cmp, a, b);
689    }
690
691    Value* Builder::PMINSD(Value* a, Value* b)
692    {
693        Value* cmp = ICMP_SLT(a, b);
694        return SELECT(cmp, a, b);
695    }
696
697    Value* Builder::PMAXUD(Value* a, Value* b)
698    {
699        Value* cmp = ICMP_UGT(a, b);
700        return SELECT(cmp, a, b);
701    }
702
703    Value* Builder::PMINUD(Value* a, Value* b)
704    {
705        Value* cmp = ICMP_ULT(a, b);
706        return SELECT(cmp, a, b);
707    }
708
709    // Helper function to create alloca in entry block of function
710    Value* Builder::CreateEntryAlloca(Function* pFunc, Type* pType)
711    {
712        auto saveIP = IRB()->saveIP();
713        IRB()->SetInsertPoint(&pFunc->getEntryBlock(), pFunc->getEntryBlock().begin());
714        Value* pAlloca = ALLOCA(pType);
715        if (saveIP.isSet())
716            IRB()->restoreIP(saveIP);
717        return pAlloca;
718    }
719
720    Value* Builder::CreateEntryAlloca(Function* pFunc, Type* pType, Value* pArraySize)
721    {
722        auto saveIP = IRB()->saveIP();
723        IRB()->SetInsertPoint(&pFunc->getEntryBlock(), pFunc->getEntryBlock().begin());
724        Value* pAlloca = ALLOCA(pType, pArraySize);
725        if (saveIP.isSet())
726            IRB()->restoreIP(saveIP);
727        return pAlloca;
728    }
729
730    Value* Builder::VABSPS(Value* a)
731    {
732        Value* asInt  = BITCAST(a, mSimdInt32Ty);
733        Value* result = BITCAST(AND(asInt, VIMMED1(0x7fffffff)), mSimdFP32Ty);
734        return result;
735    }
736
737    Value* Builder::ICLAMP(Value* src, Value* low, Value* high, const llvm::Twine& name)
738    {
739        Value* lowCmp = ICMP_SLT(src, low);
740        Value* ret    = SELECT(lowCmp, low, src);
741
742        Value* highCmp = ICMP_SGT(ret, high);
743        ret            = SELECT(highCmp, high, ret, name);
744
745        return ret;
746    }
747
748    Value* Builder::FCLAMP(Value* src, Value* low, Value* high)
749    {
750        Value* lowCmp = FCMP_OLT(src, low);
751        Value* ret    = SELECT(lowCmp, low, src);
752
753        Value* highCmp = FCMP_OGT(ret, high);
754        ret            = SELECT(highCmp, high, ret);
755
756        return ret;
757    }
758
759    Value* Builder::FCLAMP(Value* src, float low, float high)
760    {
761        Value* result = VMAXPS(src, VIMMED1(low));
762        result        = VMINPS(result, VIMMED1(high));
763
764        return result;
765    }
766
767    Value* Builder::FMADDPS(Value* a, Value* b, Value* c)
768    {
769        Value* vOut;
770        // This maps to LLVM fmuladd intrinsic
771        vOut = VFMADDPS(a, b, c);
772        return vOut;
773    }
774
775    //////////////////////////////////////////////////////////////////////////
776    /// @brief pop count on vector mask (e.g. <8 x i1>)
777    Value* Builder::VPOPCNT(Value* a) { return POPCNT(VMOVMSK(a)); }
778
779    //////////////////////////////////////////////////////////////////////////
780    /// @brief Float / Fixed-point conversions
781    //////////////////////////////////////////////////////////////////////////
782    Value* Builder::VCVT_F32_FIXED_SI(Value*             vFloat,
783                                      uint32_t           numIntBits,
784                                      uint32_t           numFracBits,
785                                      const llvm::Twine& name)
786    {
787        SWR_ASSERT((numIntBits + numFracBits) <= 32, "Can only handle 32-bit fixed-point values");
788        Value* fixed = nullptr;
789        {
790            // Do round to nearest int on fractional bits first
791            // Not entirely perfect for negative numbers, but close enough
792            vFloat = VROUND(FMUL(vFloat, VIMMED1(float(1 << numFracBits))),
793                            C(_MM_FROUND_TO_NEAREST_INT));
794            vFloat = FMUL(vFloat, VIMMED1(1.0f / float(1 << numFracBits)));
795
796            // TODO: Handle INF, NAN, overflow / underflow, etc.
797
798            Value* vSgn      = FCMP_OLT(vFloat, VIMMED1(0.0f));
799            Value* vFloatInt = BITCAST(vFloat, mSimdInt32Ty);
800            Value* vFixed    = AND(vFloatInt, VIMMED1((1 << 23) - 1));
801            vFixed           = OR(vFixed, VIMMED1(1 << 23));
802            vFixed           = SELECT(vSgn, NEG(vFixed), vFixed);
803
804            Value* vExp = LSHR(SHL(vFloatInt, VIMMED1(1)), VIMMED1(24));
805            vExp        = SUB(vExp, VIMMED1(127));
806
807            Value* vExtraBits = SUB(VIMMED1(23 - numFracBits), vExp);
808
809            fixed = ASHR(vFixed, vExtraBits, name);
810        }
811
812        return fixed;
813    }
814
815    Value* Builder::VCVT_FIXED_SI_F32(Value*             vFixed,
816                                      uint32_t           numIntBits,
817                                      uint32_t           numFracBits,
818                                      const llvm::Twine& name)
819    {
820        SWR_ASSERT((numIntBits + numFracBits) <= 32, "Can only handle 32-bit fixed-point values");
821        uint32_t extraBits = 32 - numIntBits - numFracBits;
822        if (numIntBits && extraBits)
823        {
824            // Sign extend
825            Value* shftAmt = VIMMED1(extraBits);
826            vFixed         = ASHR(SHL(vFixed, shftAmt), shftAmt);
827        }
828
829        Value* fVal  = VIMMED1(0.0f);
830        Value* fFrac = VIMMED1(0.0f);
831        if (numIntBits)
832        {
833            fVal = SI_TO_FP(ASHR(vFixed, VIMMED1(numFracBits)), mSimdFP32Ty, name);
834        }
835
836        if (numFracBits)
837        {
838            fFrac = UI_TO_FP(AND(vFixed, VIMMED1((1 << numFracBits) - 1)), mSimdFP32Ty);
839            fFrac = FDIV(fFrac, VIMMED1(float(1 << numFracBits)), name);
840        }
841
842        return FADD(fVal, fFrac, name);
843    }
844
845    Value* Builder::VCVT_F32_FIXED_UI(Value*             vFloat,
846                                      uint32_t           numIntBits,
847                                      uint32_t           numFracBits,
848                                      const llvm::Twine& name)
849    {
850        SWR_ASSERT((numIntBits + numFracBits) <= 32, "Can only handle 32-bit fixed-point values");
851        Value* fixed = nullptr;
852        // KNOB_SIM_FAST_MATH?  Below works correctly from a precision
853        // standpoint...
854        {
855            fixed = FP_TO_UI(VROUND(FMUL(vFloat, VIMMED1(float(1 << numFracBits))),
856                                    C(_MM_FROUND_TO_NEAREST_INT)),
857                             mSimdInt32Ty);
858        }
859        return fixed;
860    }
861
862    Value* Builder::VCVT_FIXED_UI_F32(Value*             vFixed,
863                                      uint32_t           numIntBits,
864                                      uint32_t           numFracBits,
865                                      const llvm::Twine& name)
866    {
867        SWR_ASSERT((numIntBits + numFracBits) <= 32, "Can only handle 32-bit fixed-point values");
868        uint32_t extraBits = 32 - numIntBits - numFracBits;
869        if (numIntBits && extraBits)
870        {
871            // Sign extend
872            Value* shftAmt = VIMMED1(extraBits);
873            vFixed         = ASHR(SHL(vFixed, shftAmt), shftAmt);
874        }
875
876        Value* fVal  = VIMMED1(0.0f);
877        Value* fFrac = VIMMED1(0.0f);
878        if (numIntBits)
879        {
880            fVal = UI_TO_FP(LSHR(vFixed, VIMMED1(numFracBits)), mSimdFP32Ty, name);
881        }
882
883        if (numFracBits)
884        {
885            fFrac = UI_TO_FP(AND(vFixed, VIMMED1((1 << numFracBits) - 1)), mSimdFP32Ty);
886            fFrac = FDIV(fFrac, VIMMED1(float(1 << numFracBits)), name);
887        }
888
889        return FADD(fVal, fFrac, name);
890    }
891
892    //////////////////////////////////////////////////////////////////////////
893    /// @brief C functions called by LLVM IR
894    //////////////////////////////////////////////////////////////////////////
895
896    Value* Builder::VEXTRACTI128(Value* a, Constant* imm8)
897    {
898        bool                      flag = !imm8->isZeroValue();
899        SmallVector<Constant*, 8> idx;
900        for (unsigned i = 0; i < mVWidth / 2; i++)
901        {
902            idx.push_back(C(flag ? i + mVWidth / 2 : i));
903        }
904        return VSHUFFLE(a, VUNDEF_I(), ConstantVector::get(idx));
905    }
906
907    Value* Builder::VINSERTI128(Value* a, Value* b, Constant* imm8)
908    {
909        bool                      flag = !imm8->isZeroValue();
910        SmallVector<Constant*, 8> idx;
911        for (unsigned i = 0; i < mVWidth; i++)
912        {
913            idx.push_back(C(i));
914        }
915        Value* inter = VSHUFFLE(b, VUNDEF_I(), ConstantVector::get(idx));
916
917        SmallVector<Constant*, 8> idx2;
918        for (unsigned i = 0; i < mVWidth / 2; i++)
919        {
920            idx2.push_back(C(flag ? i : i + mVWidth));
921        }
922        for (unsigned i = mVWidth / 2; i < mVWidth; i++)
923        {
924            idx2.push_back(C(flag ? i + mVWidth / 2 : i));
925        }
926        return VSHUFFLE(a, inter, ConstantVector::get(idx2));
927    }
928
    // rdtsc bucket helpers: emit calls to the BucketManager C functions
930    void Builder::RDTSC_START(Value* pBucketMgr, Value* pId)
931    {
932        // @todo due to an issue with thread local storage propagation in llvm, we can only safely
933        // call into buckets framework when single threaded
934        if (KNOB_SINGLE_THREADED)
935        {
936            std::vector<Type*> args{
937                PointerType::get(mInt32Ty, 0), // pBucketMgr
938                mInt32Ty                       // id
939            };
940
941            FunctionType* pFuncTy = FunctionType::get(Type::getVoidTy(JM()->mContext), args, false);
942            Function*     pFunc   = cast<Function>(
943#if LLVM_VERSION_MAJOR >= 9
944                JM()->mpCurrentModule->getOrInsertFunction("BucketManager_StartBucket", pFuncTy).getCallee());
945#else
946                JM()->mpCurrentModule->getOrInsertFunction("BucketManager_StartBucket", pFuncTy));
947#endif
948            if (sys::DynamicLibrary::SearchForAddressOfSymbol("BucketManager_StartBucket") ==
949                nullptr)
950            {
951                sys::DynamicLibrary::AddSymbol("BucketManager_StartBucket",
952                                               (void*)&BucketManager_StartBucket);
953            }
954
955            CALL(pFunc, {pBucketMgr, pId});
956        }
957    }
958
959    void Builder::RDTSC_STOP(Value* pBucketMgr, Value* pId)
960    {
961        // @todo due to an issue with thread local storage propagation in llvm, we can only safely
962        // call into buckets framework when single threaded
963        if (KNOB_SINGLE_THREADED)
964        {
965            std::vector<Type*> args{
966                PointerType::get(mInt32Ty, 0), // pBucketMgr
967                mInt32Ty                       // id
968            };
969
970            FunctionType* pFuncTy = FunctionType::get(Type::getVoidTy(JM()->mContext), args, false);
971            Function*     pFunc   = cast<Function>(
972#if LLVM_VERSION_MAJOR >=9
973                JM()->mpCurrentModule->getOrInsertFunction("BucketManager_StopBucket", pFuncTy).getCallee());
974#else
975                JM()->mpCurrentModule->getOrInsertFunction("BucketManager_StopBucket", pFuncTy));
976#endif
977            if (sys::DynamicLibrary::SearchForAddressOfSymbol("BucketManager_StopBucket") ==
978                nullptr)
979            {
980                sys::DynamicLibrary::AddSymbol("BucketManager_StopBucket",
981                                               (void*)&BucketManager_StopBucket);
982            }
983
984            CALL(pFunc, {pBucketMgr, pId});
985        }
986    }
987
988    uint32_t Builder::GetTypeSize(Type* pType)
989    {
990        if (pType->isStructTy())
991        {
992            uint32_t numElems = pType->getStructNumElements();
993            Type*    pElemTy  = pType->getStructElementType(0);
994            return numElems * GetTypeSize(pElemTy);
995        }
996
997        if (pType->isArrayTy())
998        {
999            uint32_t numElems = pType->getArrayNumElements();
1000            Type*    pElemTy  = pType->getArrayElementType();
1001            return numElems * GetTypeSize(pElemTy);
1002        }
1003
1004        if (pType->isIntegerTy())
1005        {
1006            uint32_t bitSize = pType->getIntegerBitWidth();
1007            return bitSize / 8;
1008        }
1009
1010        if (pType->isFloatTy())
1011        {
1012            return 4;
1013        }
1014
1015        if (pType->isHalfTy())
1016        {
1017            return 2;
1018        }
1019
1020        if (pType->isDoubleTy())
1021        {
1022            return 8;
1023        }
1024
1025        SWR_ASSERT(false, "Unimplemented type.");
1026        return 0;
1027    }
1028} // namespace SwrJit
1029