1/****************************************************************************
2 * Copyright (C) 2014-2015 Intel Corporation.   All Rights Reserved.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * @file builder_misc.cpp
24 *
25 * @brief Implementation for miscellaneous builder functions
26 *
27 * Notes:
28 *
29 ******************************************************************************/
#include "jit_pch.hpp"
#include "builder.h"
#include "common/rdtsc_buckets.h"

#include <cstdarg>
#include <cstring>
35
36extern "C" void CallPrint(const char* fmt, ...);
37
38namespace SwrJit
39{
40    //////////////////////////////////////////////////////////////////////////
41    /// @brief Convert an IEEE 754 32-bit single precision float to an
42    ///        16 bit float with 5 exponent bits and a variable
43    ///        number of mantissa bits.
44    /// @param val - 32-bit float
45    /// @todo Maybe move this outside of this file into a header?
46    static uint16_t ConvertFloat32ToFloat16(float val)
47    {
48        uint32_t sign, exp, mant;
49        uint32_t roundBits;
50
51        // Extract the sign, exponent, and mantissa
52        uint32_t uf = *(uint32_t*)&val;
53        sign        = (uf & 0x80000000) >> 31;
54        exp         = (uf & 0x7F800000) >> 23;
55        mant        = uf & 0x007FFFFF;
56
57        // Check for out of range
58        if (std::isnan(val))
59        {
60            exp  = 0x1F;
61            mant = 0x200;
62            sign = 1; // set the sign bit for NANs
63        }
64        else if (std::isinf(val))
65        {
66            exp  = 0x1f;
67            mant = 0x0;
68        }
69        else if (exp > (0x70 + 0x1E)) // Too big to represent -> max representable value
70        {
71            exp  = 0x1E;
72            mant = 0x3FF;
73        }
74        else if ((exp <= 0x70) && (exp >= 0x66)) // It's a denorm
75        {
76            mant |= 0x00800000;
77            for (; exp <= 0x70; mant >>= 1, exp++)
78                ;
79            exp  = 0;
80            mant = mant >> 13;
81        }
82        else if (exp < 0x66) // Too small to represent -> Zero
83        {
84            exp  = 0;
85            mant = 0;
86        }
87        else
88        {
89            // Saves bits that will be shifted off for rounding
90            roundBits = mant & 0x1FFFu;
91            // convert exponent and mantissa to 16 bit format
92            exp  = exp - 0x70;
93            mant = mant >> 13;
94
95            // Essentially RTZ, but round up if off by only 1 lsb
96            if (roundBits == 0x1FFFu)
97            {
98                mant++;
99                // check for overflow
100                if ((mant & 0xC00u) != 0)
101                    exp++;
102                // make sure only the needed bits are used
103                mant &= 0x3FF;
104            }
105        }
106
107        uint32_t tmpVal = (sign << 15) | (exp << 10) | mant;
108        return (uint16_t)tmpVal;
109    }
110
111    Constant* Builder::C(bool i) { return ConstantInt::get(IRB()->getInt1Ty(), (i ? 1 : 0)); }
112
113    Constant* Builder::C(char i) { return ConstantInt::get(IRB()->getInt8Ty(), i); }
114
115    Constant* Builder::C(uint8_t i) { return ConstantInt::get(IRB()->getInt8Ty(), i); }
116
117    Constant* Builder::C(int i) { return ConstantInt::get(IRB()->getInt32Ty(), i); }
118
119    Constant* Builder::C(int64_t i) { return ConstantInt::get(IRB()->getInt64Ty(), i); }
120
121    Constant* Builder::C(uint16_t i) { return ConstantInt::get(mInt16Ty, i); }
122
123    Constant* Builder::C(uint32_t i) { return ConstantInt::get(IRB()->getInt32Ty(), i); }
124
125    Constant* Builder::C(uint64_t i) { return ConstantInt::get(IRB()->getInt64Ty(), i); }
126
127    Constant* Builder::C(float i) { return ConstantFP::get(IRB()->getFloatTy(), i); }
128
129    Constant* Builder::PRED(bool pred)
130    {
131        return ConstantInt::get(IRB()->getInt1Ty(), (pred ? 1 : 0));
132    }
133
134    Value* Builder::VIMMED1(uint64_t i)
135    {
136#if LLVM_VERSION_MAJOR <= 10
137        return ConstantVector::getSplat(mVWidth, cast<ConstantInt>(C(i)));
138#elif LLVM_VERSION_MAJOR == 11
139        return ConstantVector::getSplat(ElementCount(mVWidth, false), cast<ConstantInt>(C(i)));
140#else
141        return ConstantVector::getSplat(ElementCount::get(mVWidth, false), cast<ConstantInt>(C(i)));
142#endif
143    }
144
145    Value* Builder::VIMMED1_16(uint64_t i)
146    {
147#if LLVM_VERSION_MAJOR <= 10
148        return ConstantVector::getSplat(mVWidth16, cast<ConstantInt>(C(i)));
149#elif LLVM_VERSION_MAJOR == 11
150        return ConstantVector::getSplat(ElementCount(mVWidth16, false), cast<ConstantInt>(C(i)));
151#else
152        return ConstantVector::getSplat(ElementCount::get(mVWidth16, false), cast<ConstantInt>(C(i)));
153#endif
154    }
155
156    Value* Builder::VIMMED1(int i)
157    {
158#if LLVM_VERSION_MAJOR <= 10
159        return ConstantVector::getSplat(mVWidth, cast<ConstantInt>(C(i)));
160#elif LLVM_VERSION_MAJOR == 11
161        return ConstantVector::getSplat(ElementCount(mVWidth, false), cast<ConstantInt>(C(i)));
162#else
163        return ConstantVector::getSplat(ElementCount::get(mVWidth, false), cast<ConstantInt>(C(i)));
164#endif
165    }
166
167    Value* Builder::VIMMED1_16(int i)
168    {
169#if LLVM_VERSION_MAJOR <= 10
170        return ConstantVector::getSplat(mVWidth16, cast<ConstantInt>(C(i)));
171#elif LLVM_VERSION_MAJOR == 11
172        return ConstantVector::getSplat(ElementCount(mVWidth16, false), cast<ConstantInt>(C(i)));
173#else
174        return ConstantVector::getSplat(ElementCount::get(mVWidth16, false), cast<ConstantInt>(C(i)));
175#endif
176    }
177
178    Value* Builder::VIMMED1(uint32_t i)
179    {
180#if LLVM_VERSION_MAJOR <= 10
181        return ConstantVector::getSplat(mVWidth, cast<ConstantInt>(C(i)));
182#elif LLVM_VERSION_MAJOR == 11
183        return ConstantVector::getSplat(ElementCount(mVWidth, false), cast<ConstantInt>(C(i)));
184#else
185        return ConstantVector::getSplat(ElementCount::get(mVWidth, false), cast<ConstantInt>(C(i)));
186#endif
187    }
188
189    Value* Builder::VIMMED1_16(uint32_t i)
190    {
191#if LLVM_VERSION_MAJOR <= 10
192        return ConstantVector::getSplat(mVWidth16, cast<ConstantInt>(C(i)));
193#elif LLVM_VERSION_MAJOR == 11
194        return ConstantVector::getSplat(ElementCount(mVWidth16, false), cast<ConstantInt>(C(i)));
195#else
196        return ConstantVector::getSplat(ElementCount::get(mVWidth16, false), cast<ConstantInt>(C(i)));
197#endif
198    }
199
200    Value* Builder::VIMMED1(float i)
201    {
202#if LLVM_VERSION_MAJOR <= 10
203        return ConstantVector::getSplat(mVWidth, cast<ConstantFP>(C(i)));
204#elif LLVM_VERSION_MAJOR == 11
205        return ConstantVector::getSplat(ElementCount(mVWidth, false), cast<ConstantFP>(C(i)));
206#else
207        return ConstantVector::getSplat(ElementCount::get(mVWidth, false), cast<ConstantFP>(C(i)));
208#endif
209    }
210
211    Value* Builder::VIMMED1_16(float i)
212    {
213#if LLVM_VERSION_MAJOR <= 10
214        return ConstantVector::getSplat(mVWidth16, cast<ConstantFP>(C(i)));
215#elif LLVM_VERSION_MAJOR == 11
216        return ConstantVector::getSplat(ElementCount(mVWidth16, false), cast<ConstantFP>(C(i)));
217#else
218        return ConstantVector::getSplat(ElementCount::get(mVWidth16, false), cast<ConstantFP>(C(i)));
219#endif
220    }
221
222    Value* Builder::VIMMED1(bool i)
223    {
224#if LLVM_VERSION_MAJOR <= 10
225        return ConstantVector::getSplat(mVWidth, cast<ConstantInt>(C(i)));
226#elif LLVM_VERSION_MAJOR == 11
227        return ConstantVector::getSplat(ElementCount(mVWidth, false), cast<ConstantInt>(C(i)));
228#else
229        return ConstantVector::getSplat(ElementCount::get(mVWidth, false), cast<ConstantInt>(C(i)));
230#endif
231    }
232
233    Value* Builder::VIMMED1_16(bool i)
234    {
235#if LLVM_VERSION_MAJOR <= 10
236        return ConstantVector::getSplat(mVWidth16, cast<ConstantInt>(C(i)));
237#elif LLVM_VERSION_MAJOR == 11
238        return ConstantVector::getSplat(ElementCount(mVWidth16, false), cast<ConstantInt>(C(i)));
239#else
240        return ConstantVector::getSplat(ElementCount::get(mVWidth16, false), cast<ConstantInt>(C(i)));
241#endif
242    }
243
244    Value* Builder::VUNDEF_IPTR() { return UndefValue::get(getVectorType(mInt32PtrTy, mVWidth)); }
245
246    Value* Builder::VUNDEF(Type* t) { return UndefValue::get(getVectorType(t, mVWidth)); }
247
248    Value* Builder::VUNDEF_I() { return UndefValue::get(getVectorType(mInt32Ty, mVWidth)); }
249
250    Value* Builder::VUNDEF_I_16() { return UndefValue::get(getVectorType(mInt32Ty, mVWidth16)); }
251
252    Value* Builder::VUNDEF_F() { return UndefValue::get(getVectorType(mFP32Ty, mVWidth)); }
253
254    Value* Builder::VUNDEF_F_16() { return UndefValue::get(getVectorType(mFP32Ty, mVWidth16)); }
255
256    Value* Builder::VUNDEF(Type* ty, uint32_t size)
257    {
258        return UndefValue::get(getVectorType(ty, size));
259    }
260
261    Value* Builder::VBROADCAST(Value* src, const llvm::Twine& name)
262    {
263        // check if src is already a vector
264        if (src->getType()->isVectorTy())
265        {
266            return src;
267        }
268
269        return VECTOR_SPLAT(mVWidth, src, name);
270    }
271
272    Value* Builder::VBROADCAST_16(Value* src)
273    {
274        // check if src is already a vector
275        if (src->getType()->isVectorTy())
276        {
277            return src;
278        }
279
280        return VECTOR_SPLAT(mVWidth16, src);
281    }
282
283    uint32_t Builder::IMMED(Value* v)
284    {
285        SWR_ASSERT(isa<ConstantInt>(v));
286        ConstantInt* pValConst = cast<ConstantInt>(v);
287        return pValConst->getZExtValue();
288    }
289
290    int32_t Builder::S_IMMED(Value* v)
291    {
292        SWR_ASSERT(isa<ConstantInt>(v));
293        ConstantInt* pValConst = cast<ConstantInt>(v);
294        return pValConst->getSExtValue();
295    }
296
297    CallInst* Builder::CALL(Value*                               Callee,
298                            const std::initializer_list<Value*>& argsList,
299                            const llvm::Twine&                   name)
300    {
301        std::vector<Value*> args;
302        for (auto arg : argsList)
303            args.push_back(arg);
304#if LLVM_VERSION_MAJOR >= 11
305        // see comment to CALLA(Callee) function in the header
306        return CALLA(FunctionCallee(cast<Function>(Callee)), args, name);
307#else
308        return CALLA(Callee, args, name);
309#endif
310    }
311
312    CallInst* Builder::CALL(Value* Callee, Value* arg)
313    {
314        std::vector<Value*> args;
315        args.push_back(arg);
316#if LLVM_VERSION_MAJOR >= 11
317        // see comment to CALLA(Callee) function in the header
318        return CALLA(FunctionCallee(cast<Function>(Callee)), args);
319#else
320        return CALLA(Callee, args);
321#endif
322    }
323
324    CallInst* Builder::CALL2(Value* Callee, Value* arg1, Value* arg2)
325    {
326        std::vector<Value*> args;
327        args.push_back(arg1);
328        args.push_back(arg2);
329#if LLVM_VERSION_MAJOR >= 11
330        // see comment to CALLA(Callee) function in the header
331        return CALLA(FunctionCallee(cast<Function>(Callee)), args);
332#else
333        return CALLA(Callee, args);
334#endif
335    }
336
337    CallInst* Builder::CALL3(Value* Callee, Value* arg1, Value* arg2, Value* arg3)
338    {
339        std::vector<Value*> args;
340        args.push_back(arg1);
341        args.push_back(arg2);
342        args.push_back(arg3);
343#if LLVM_VERSION_MAJOR >= 11
344        // see comment to CALLA(Callee) function in the header
345        return CALLA(FunctionCallee(cast<Function>(Callee)), args);
346#else
347        return CALLA(Callee, args);
348#endif
349    }
350
351    Value* Builder::VRCP(Value* va, const llvm::Twine& name)
352    {
353        return FDIV(VIMMED1(1.0f), va, name); // 1 / a
354    }
355
356    Value* Builder::VPLANEPS(Value* vA, Value* vB, Value* vC, Value*& vX, Value*& vY)
357    {
358        Value* vOut = FMADDPS(vA, vX, vC);
359        vOut        = FMADDPS(vB, vY, vOut);
360        return vOut;
361    }
362
363    //////////////////////////////////////////////////////////////////////////
364    /// @brief insert a JIT call to CallPrint
365    /// - outputs formatted string to both stdout and VS output window
366    /// - DEBUG builds only
367    /// Usage example:
368    ///   PRINT("index %d = 0x%p\n",{C(lane), pIndex});
369    ///   where C(lane) creates a constant value to print, and pIndex is the Value*
370    ///   result from a GEP, printing out the pointer to memory
371    /// @param printStr - constant string to print, which includes format specifiers
372    /// @param printArgs - initializer list of Value*'s to print to std out
373    CallInst* Builder::PRINT(const std::string&                   printStr,
374                             const std::initializer_list<Value*>& printArgs)
375    {
376        // push the arguments to CallPrint into a vector
377        std::vector<Value*> printCallArgs;
378        // save room for the format string.  we still need to modify it for vectors
379        printCallArgs.resize(1);
380
381        // search through the format string for special processing
382        size_t      pos = 0;
383        std::string tempStr(printStr);
384        pos    = tempStr.find('%', pos);
385        auto v = printArgs.begin();
386
387        while ((pos != std::string::npos) && (v != printArgs.end()))
388        {
389            Value* pArg  = *v;
390            Type*  pType = pArg->getType();
391
392            if (pType->isVectorTy())
393            {
394                Type* pContainedType = pType->getContainedType(0);
395#if LLVM_VERSION_MAJOR >= 12
396                FixedVectorType* pVectorType = cast<FixedVectorType>(pType);
397#elif LLVM_VERSION_MAJOR >= 11
398                VectorType* pVectorType = cast<VectorType>(pType);
399#endif
400                if (toupper(tempStr[pos + 1]) == 'X')
401                {
402                    tempStr[pos]     = '0';
403                    tempStr[pos + 1] = 'x';
404                    tempStr.insert(pos + 2, "%08X ");
405                    pos += 7;
406
407                    printCallArgs.push_back(VEXTRACT(pArg, C(0)));
408
409                    std::string vectorFormatStr;
410#if LLVM_VERSION_MAJOR >= 11
411                    for (uint32_t i = 1; i < pVectorType->getNumElements(); ++i)
412#else
413                    for (uint32_t i = 1; i < pType->getVectorNumElements(); ++i)
414#endif
415                    {
416                        vectorFormatStr += "0x%08X ";
417                        printCallArgs.push_back(VEXTRACT(pArg, C(i)));
418                    }
419
420                    tempStr.insert(pos, vectorFormatStr);
421                    pos += vectorFormatStr.size();
422                }
423                else if ((tempStr[pos + 1] == 'f') && (pContainedType->isFloatTy()))
424                {
425                    uint32_t i = 0;
426#if LLVM_VERSION_MAJOR >= 11
427                    for (; i < pVectorType->getNumElements() - 1; i++)
428#else
429                    for (; i < pType->getVectorNumElements() - 1; i++)
430#endif
431                    {
432                        tempStr.insert(pos, std::string("%f "));
433                        pos += 3;
434                        printCallArgs.push_back(
435                            FP_EXT(VEXTRACT(pArg, C(i)), Type::getDoubleTy(JM()->mContext)));
436                    }
437                    printCallArgs.push_back(
438                        FP_EXT(VEXTRACT(pArg, C(i)), Type::getDoubleTy(JM()->mContext)));
439                }
440                else if ((tempStr[pos + 1] == 'd') && (pContainedType->isIntegerTy()))
441                {
442                    uint32_t i = 0;
443#if LLVM_VERSION_MAJOR >= 11
444                    for (; i < pVectorType->getNumElements() - 1; i++)
445#else
446                    for (; i < pType->getVectorNumElements() - 1; i++)
447#endif
448                    {
449                        tempStr.insert(pos, std::string("%d "));
450                        pos += 3;
451                        printCallArgs.push_back(
452                            S_EXT(VEXTRACT(pArg, C(i)), Type::getInt32Ty(JM()->mContext)));
453                    }
454                    printCallArgs.push_back(
455                        S_EXT(VEXTRACT(pArg, C(i)), Type::getInt32Ty(JM()->mContext)));
456                }
457                else if ((tempStr[pos + 1] == 'u') && (pContainedType->isIntegerTy()))
458                {
459                    uint32_t i = 0;
460#if LLVM_VERSION_MAJOR >= 11
461                    for (; i < pVectorType->getNumElements() - 1; i++)
462#else
463                    for (; i < pType->getVectorNumElements() - 1; i++)
464#endif
465                    {
466                        tempStr.insert(pos, std::string("%d "));
467                        pos += 3;
468                        printCallArgs.push_back(
469                            Z_EXT(VEXTRACT(pArg, C(i)), Type::getInt32Ty(JM()->mContext)));
470                    }
471                    printCallArgs.push_back(
472                        Z_EXT(VEXTRACT(pArg, C(i)), Type::getInt32Ty(JM()->mContext)));
473                }
474            }
475            else
476            {
477                if (toupper(tempStr[pos + 1]) == 'X')
478                {
479                    tempStr[pos] = '0';
480                    tempStr.insert(pos + 1, "x%08");
481                    printCallArgs.push_back(pArg);
482                    pos += 3;
483                }
484                // for %f we need to cast float Values to doubles so that they print out correctly
485                else if ((tempStr[pos + 1] == 'f') && (pType->isFloatTy()))
486                {
487                    printCallArgs.push_back(FP_EXT(pArg, Type::getDoubleTy(JM()->mContext)));
488                    pos++;
489                }
490                else
491                {
492                    printCallArgs.push_back(pArg);
493                }
494            }
495
496            // advance to the next argument
497            v++;
498            pos = tempStr.find('%', ++pos);
499        }
500
501        // create global variable constant string
502        Constant*       constString = ConstantDataArray::getString(JM()->mContext, tempStr, true);
503        GlobalVariable* gvPtr       = new GlobalVariable(
504            constString->getType(), true, GlobalValue::InternalLinkage, constString, "printStr");
505        JM()->mpCurrentModule->getGlobalList().push_back(gvPtr);
506
507        // get a pointer to the first character in the constant string array
508        std::vector<Constant*> geplist{C(0), C(0)};
509        Constant* strGEP = ConstantExpr::getGetElementPtr(nullptr, gvPtr, geplist, false);
510
511        // insert the pointer to the format string in the argument vector
512        printCallArgs[0] = strGEP;
513
514        // get pointer to CallPrint function and insert decl into the module if needed
515        std::vector<Type*> args;
516        args.push_back(PointerType::get(mInt8Ty, 0));
517        FunctionType* callPrintTy = FunctionType::get(Type::getVoidTy(JM()->mContext), args, true);
518        Function*     callPrintFn =
519#if LLVM_VERSION_MAJOR >= 9
520            cast<Function>(JM()->mpCurrentModule->getOrInsertFunction("CallPrint", callPrintTy).getCallee());
521#else
522            cast<Function>(JM()->mpCurrentModule->getOrInsertFunction("CallPrint", callPrintTy));
523#endif
524
525        // if we haven't yet added the symbol to the symbol table
526        if ((sys::DynamicLibrary::SearchForAddressOfSymbol("CallPrint")) == nullptr)
527        {
528            sys::DynamicLibrary::AddSymbol("CallPrint", (void*)&CallPrint);
529        }
530
531        // insert a call to CallPrint
532        return CALLA(callPrintFn, printCallArgs);
533    }
534
535    //////////////////////////////////////////////////////////////////////////
536    /// @brief Wrapper around PRINT with initializer list.
537    CallInst* Builder::PRINT(const std::string& printStr) { return PRINT(printStr, {}); }
538
539    Value* Builder::EXTRACT_16(Value* x, uint32_t imm)
540    {
541        if (imm == 0)
542        {
543            return VSHUFFLE(x, UndefValue::get(x->getType()), {0, 1, 2, 3, 4, 5, 6, 7});
544        }
545        else
546        {
547            return VSHUFFLE(x, UndefValue::get(x->getType()), {8, 9, 10, 11, 12, 13, 14, 15});
548        }
549    }
550
551    Value* Builder::JOIN_16(Value* a, Value* b)
552    {
553        return VSHUFFLE(a, b, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
554    }
555
556    //////////////////////////////////////////////////////////////////////////
557    /// @brief convert x86 <N x float> mask to llvm <N x i1> mask
558    Value* Builder::MASK(Value* vmask)
559    {
560        Value* src = BITCAST(vmask, mSimdInt32Ty);
561        return ICMP_SLT(src, VIMMED1(0));
562    }
563
564    Value* Builder::MASK_16(Value* vmask)
565    {
566        Value* src = BITCAST(vmask, mSimd16Int32Ty);
567        return ICMP_SLT(src, VIMMED1_16(0));
568    }
569
570    //////////////////////////////////////////////////////////////////////////
571    /// @brief convert llvm <N x i1> mask to x86 <N x i32> mask
572    Value* Builder::VMASK(Value* mask) { return S_EXT(mask, mSimdInt32Ty); }
573
574    Value* Builder::VMASK_16(Value* mask) { return S_EXT(mask, mSimd16Int32Ty); }
575
576    /// @brief Convert <Nxi1> llvm mask to integer
577    Value* Builder::VMOVMSK(Value* mask)
578    {
579#if LLVM_VERSION_MAJOR >= 11
580#if LLVM_VERSION_MAJOR >= 12
581        FixedVectorType* pVectorType = cast<FixedVectorType>(mask->getType());
582#else
583        VectorType* pVectorType = cast<VectorType>(mask->getType());
584#endif
585        SWR_ASSERT(pVectorType->getElementType() == mInt1Ty);
586        uint32_t numLanes = pVectorType->getNumElements();
587#else
588        SWR_ASSERT(mask->getType()->getVectorElementType() == mInt1Ty);
589        uint32_t numLanes = mask->getType()->getVectorNumElements();
590#endif
591        Value*   i32Result;
592        if (numLanes == 8)
593        {
594            i32Result = BITCAST(mask, mInt8Ty);
595        }
596        else if (numLanes == 16)
597        {
598            i32Result = BITCAST(mask, mInt16Ty);
599        }
600        else
601        {
602            SWR_ASSERT("Unsupported vector width");
603            i32Result = BITCAST(mask, mInt8Ty);
604        }
605        return Z_EXT(i32Result, mInt32Ty);
606    }
607
608    //////////////////////////////////////////////////////////////////////////
609    /// @brief Generate a VPSHUFB operation in LLVM IR.  If not
610    /// supported on the underlying platform, emulate it
611    /// @param a - 256bit SIMD(32x8bit) of 8bit integer values
612    /// @param b - 256bit SIMD(32x8bit) of 8bit integer mask values
613    /// Byte masks in lower 128 lane of b selects 8 bit values from lower
614    /// 128bits of a, and vice versa for the upper lanes.  If the mask
615    /// value is negative, '0' is inserted.
616    Value* Builder::PSHUFB(Value* a, Value* b)
617    {
618        Value* res;
619        // use avx2 pshufb instruction if available
620        if (JM()->mArch.AVX2())
621        {
622            res = VPSHUFB(a, b);
623        }
624        else
625        {
626            Constant* cB = dyn_cast<Constant>(b);
627            assert(cB != nullptr);
628            // number of 8 bit elements in b
629#if LLVM_VERSION_MAJOR >= 12
630            uint32_t numElms = cast<FixedVectorType>(cB->getType())->getNumElements();
631#else
632            uint32_t numElms = cast<VectorType>(cB->getType())->getNumElements();
633#endif
634            // output vector
635            Value* vShuf = UndefValue::get(getVectorType(mInt8Ty, numElms));
636
637            // insert an 8 bit value from the high and low lanes of a per loop iteration
638            numElms /= 2;
639            for (uint32_t i = 0; i < numElms; i++)
640            {
641                ConstantInt* cLow128b  = cast<ConstantInt>(cB->getAggregateElement(i));
642                ConstantInt* cHigh128b = cast<ConstantInt>(cB->getAggregateElement(i + numElms));
643
644                // extract values from constant mask
645                char valLow128bLane  = (char)(cLow128b->getSExtValue());
646                char valHigh128bLane = (char)(cHigh128b->getSExtValue());
647
648                Value* insertValLow128b;
649                Value* insertValHigh128b;
650
651                // if the mask value is negative, insert a '0' in the respective output position
652                // otherwise, lookup the value at mask position (bits 3..0 of the respective mask
653                // byte) in a and insert in output vector
654                insertValLow128b =
655                    (valLow128bLane < 0) ? C((char)0) : VEXTRACT(a, C((valLow128bLane & 0xF)));
656                insertValHigh128b = (valHigh128bLane < 0)
657                                        ? C((char)0)
658                                        : VEXTRACT(a, C((valHigh128bLane & 0xF) + numElms));
659
660                vShuf = VINSERT(vShuf, insertValLow128b, i);
661                vShuf = VINSERT(vShuf, insertValHigh128b, (i + numElms));
662            }
663            res = vShuf;
664        }
665        return res;
666    }
667
668    //////////////////////////////////////////////////////////////////////////
669    /// @brief Generate a VPSHUFB operation (sign extend 8 8bit values to 32
670    /// bits)in LLVM IR.  If not supported on the underlying platform, emulate it
671    /// @param a - 128bit SIMD lane(16x8bit) of 8bit integer values.  Only
672    /// lower 8 values are used.
673    Value* Builder::PMOVSXBD(Value* a)
674    {
675        // VPMOVSXBD output type
676        Type* v8x32Ty = getVectorType(mInt32Ty, 8);
677        // Extract 8 values from 128bit lane and sign extend
678        return S_EXT(VSHUFFLE(a, a, C<int>({0, 1, 2, 3, 4, 5, 6, 7})), v8x32Ty);
679    }
680
681    //////////////////////////////////////////////////////////////////////////
682    /// @brief Generate a VPSHUFB operation (sign extend 8 16bit values to 32
683    /// bits)in LLVM IR.  If not supported on the underlying platform, emulate it
684    /// @param a - 128bit SIMD lane(8x16bit) of 16bit integer values.
685    Value* Builder::PMOVSXWD(Value* a)
686    {
687        // VPMOVSXWD output type
688        Type* v8x32Ty = getVectorType(mInt32Ty, 8);
689        // Extract 8 values from 128bit lane and sign extend
690        return S_EXT(VSHUFFLE(a, a, C<int>({0, 1, 2, 3, 4, 5, 6, 7})), v8x32Ty);
691    }
692
693    //////////////////////////////////////////////////////////////////////////
694    /// @brief Generate a VCVTPH2PS operation (float16->float32 conversion)
695    /// in LLVM IR.  If not supported on the underlying platform, emulate it
696    /// @param a - 128bit SIMD lane(8x16bit) of float16 in int16 format.
697    Value* Builder::CVTPH2PS(Value* a, const llvm::Twine& name)
698    {
699        // Bitcast Nxint16 to Nxhalf
700#if LLVM_VERSION_MAJOR >= 12
701        uint32_t numElems = cast<FixedVectorType>(a->getType())->getNumElements();
702#elif LLVM_VERSION_MAJOR >= 11
703        uint32_t numElems = cast<VectorType>(a->getType())->getNumElements();
704#else
705        uint32_t numElems = a->getType()->getVectorNumElements();
706#endif
707        Value*   input    = BITCAST(a, getVectorType(mFP16Ty, numElems));
708
709        return FP_EXT(input, getVectorType(mFP32Ty, numElems), name);
710    }
711
712    //////////////////////////////////////////////////////////////////////////
713    /// @brief Generate a VCVTPS2PH operation (float32->float16 conversion)
714    /// in LLVM IR.  If not supported on the underlying platform, emulate it
715    /// @param a - 128bit SIMD lane(8x16bit) of float16 in int16 format.
716    Value* Builder::CVTPS2PH(Value* a, Value* rounding)
717    {
718        if (JM()->mArch.F16C())
719        {
720            return VCVTPS2PH(a, rounding);
721        }
722        else
723        {
724            // call scalar C function for now
725            FunctionType* pFuncTy   = FunctionType::get(mInt16Ty, mFP32Ty);
726            Function*     pCvtPs2Ph = cast<Function>(
727#if LLVM_VERSION_MAJOR >= 9
728                JM()->mpCurrentModule->getOrInsertFunction("ConvertFloat32ToFloat16", pFuncTy).getCallee());
729#else
730                JM()->mpCurrentModule->getOrInsertFunction("ConvertFloat32ToFloat16", pFuncTy));
731#endif
732
733            if (sys::DynamicLibrary::SearchForAddressOfSymbol("ConvertFloat32ToFloat16") == nullptr)
734            {
735                sys::DynamicLibrary::AddSymbol("ConvertFloat32ToFloat16",
736                                               (void*)&ConvertFloat32ToFloat16);
737            }
738
739            Value* pResult = UndefValue::get(mSimdInt16Ty);
740            for (uint32_t i = 0; i < mVWidth; ++i)
741            {
742                Value* pSrc  = VEXTRACT(a, C(i));
743                Value* pConv = CALL(pCvtPs2Ph, std::initializer_list<Value*>{pSrc});
744                pResult      = VINSERT(pResult, pConv, C(i));
745            }
746
747            return pResult;
748        }
749    }
750
751    Value* Builder::PMAXSD(Value* a, Value* b)
752    {
753        Value* cmp = ICMP_SGT(a, b);
754        return SELECT(cmp, a, b);
755    }
756
757    Value* Builder::PMINSD(Value* a, Value* b)
758    {
759        Value* cmp = ICMP_SLT(a, b);
760        return SELECT(cmp, a, b);
761    }
762
763    Value* Builder::PMAXUD(Value* a, Value* b)
764    {
765        Value* cmp = ICMP_UGT(a, b);
766        return SELECT(cmp, a, b);
767    }
768
769    Value* Builder::PMINUD(Value* a, Value* b)
770    {
771        Value* cmp = ICMP_ULT(a, b);
772        return SELECT(cmp, a, b);
773    }
774
775    // Helper function to create alloca in entry block of function
776    Value* Builder::CreateEntryAlloca(Function* pFunc, Type* pType)
777    {
778        auto saveIP = IRB()->saveIP();
779        IRB()->SetInsertPoint(&pFunc->getEntryBlock(), pFunc->getEntryBlock().begin());
780        Value* pAlloca = ALLOCA(pType);
781        if (saveIP.isSet())
782            IRB()->restoreIP(saveIP);
783        return pAlloca;
784    }
785
786    Value* Builder::CreateEntryAlloca(Function* pFunc, Type* pType, Value* pArraySize)
787    {
788        auto saveIP = IRB()->saveIP();
789        IRB()->SetInsertPoint(&pFunc->getEntryBlock(), pFunc->getEntryBlock().begin());
790        Value* pAlloca = ALLOCA(pType, pArraySize);
791        if (saveIP.isSet())
792            IRB()->restoreIP(saveIP);
793        return pAlloca;
794    }
795
796    Value* Builder::VABSPS(Value* a)
797    {
798        Value* asInt  = BITCAST(a, mSimdInt32Ty);
799        Value* result = BITCAST(AND(asInt, VIMMED1(0x7fffffff)), mSimdFP32Ty);
800        return result;
801    }
802
803    Value* Builder::ICLAMP(Value* src, Value* low, Value* high, const llvm::Twine& name)
804    {
805        Value* lowCmp = ICMP_SLT(src, low);
806        Value* ret    = SELECT(lowCmp, low, src);
807
808        Value* highCmp = ICMP_SGT(ret, high);
809        ret            = SELECT(highCmp, high, ret, name);
810
811        return ret;
812    }
813
814    Value* Builder::FCLAMP(Value* src, Value* low, Value* high)
815    {
816        Value* lowCmp = FCMP_OLT(src, low);
817        Value* ret    = SELECT(lowCmp, low, src);
818
819        Value* highCmp = FCMP_OGT(ret, high);
820        ret            = SELECT(highCmp, high, ret);
821
822        return ret;
823    }
824
825    Value* Builder::FCLAMP(Value* src, float low, float high)
826    {
827        Value* result = VMAXPS(src, VIMMED1(low));
828        result        = VMINPS(result, VIMMED1(high));
829
830        return result;
831    }
832
833    Value* Builder::FMADDPS(Value* a, Value* b, Value* c)
834    {
835        Value* vOut;
836        // This maps to LLVM fmuladd intrinsic
837        vOut = VFMADDPS(a, b, c);
838        return vOut;
839    }
840
841    //////////////////////////////////////////////////////////////////////////
842    /// @brief pop count on vector mask (e.g. <8 x i1>)
843    Value* Builder::VPOPCNT(Value* a) { return POPCNT(VMOVMSK(a)); }
844
845    //////////////////////////////////////////////////////////////////////////
846    /// @brief Float / Fixed-point conversions
847    //////////////////////////////////////////////////////////////////////////
848    Value* Builder::VCVT_F32_FIXED_SI(Value*             vFloat,
849                                      uint32_t           numIntBits,
850                                      uint32_t           numFracBits,
851                                      const llvm::Twine& name)
852    {
853        SWR_ASSERT((numIntBits + numFracBits) <= 32, "Can only handle 32-bit fixed-point values");
854        Value* fixed = nullptr;
855
856#if 0   // This doesn't work for negative numbers!!
857        {
858            fixed = FP_TO_SI(VROUND(FMUL(vFloat, VIMMED1(float(1 << numFracBits))),
859                                    C(_MM_FROUND_TO_NEAREST_INT)),
860                             mSimdInt32Ty);
861        }
862        else
863#endif
864        {
865            // Do round to nearest int on fractional bits first
866            // Not entirely perfect for negative numbers, but close enough
867            vFloat = VROUND(FMUL(vFloat, VIMMED1(float(1 << numFracBits))),
868                            C(_MM_FROUND_TO_NEAREST_INT));
869            vFloat = FMUL(vFloat, VIMMED1(1.0f / float(1 << numFracBits)));
870
871            // TODO: Handle INF, NAN, overflow / underflow, etc.
872
873            Value* vSgn      = FCMP_OLT(vFloat, VIMMED1(0.0f));
874            Value* vFloatInt = BITCAST(vFloat, mSimdInt32Ty);
875            Value* vFixed    = AND(vFloatInt, VIMMED1((1 << 23) - 1));
876            vFixed           = OR(vFixed, VIMMED1(1 << 23));
877            vFixed           = SELECT(vSgn, NEG(vFixed), vFixed);
878
879            Value* vExp = LSHR(SHL(vFloatInt, VIMMED1(1)), VIMMED1(24));
880            vExp        = SUB(vExp, VIMMED1(127));
881
882            Value* vExtraBits = SUB(VIMMED1(23 - numFracBits), vExp);
883
884            fixed = ASHR(vFixed, vExtraBits, name);
885        }
886
887        return fixed;
888    }
889
890    Value* Builder::VCVT_FIXED_SI_F32(Value*             vFixed,
891                                      uint32_t           numIntBits,
892                                      uint32_t           numFracBits,
893                                      const llvm::Twine& name)
894    {
895        SWR_ASSERT((numIntBits + numFracBits) <= 32, "Can only handle 32-bit fixed-point values");
896        uint32_t extraBits = 32 - numIntBits - numFracBits;
897        if (numIntBits && extraBits)
898        {
899            // Sign extend
900            Value* shftAmt = VIMMED1(extraBits);
901            vFixed         = ASHR(SHL(vFixed, shftAmt), shftAmt);
902        }
903
904        Value* fVal  = VIMMED1(0.0f);
905        Value* fFrac = VIMMED1(0.0f);
906        if (numIntBits)
907        {
908            fVal = SI_TO_FP(ASHR(vFixed, VIMMED1(numFracBits)), mSimdFP32Ty, name);
909        }
910
911        if (numFracBits)
912        {
913            fFrac = UI_TO_FP(AND(vFixed, VIMMED1((1 << numFracBits) - 1)), mSimdFP32Ty);
914            fFrac = FDIV(fFrac, VIMMED1(float(1 << numFracBits)), name);
915        }
916
917        return FADD(fVal, fFrac, name);
918    }
919
920    Value* Builder::VCVT_F32_FIXED_UI(Value*             vFloat,
921                                      uint32_t           numIntBits,
922                                      uint32_t           numFracBits,
923                                      const llvm::Twine& name)
924    {
925        SWR_ASSERT((numIntBits + numFracBits) <= 32, "Can only handle 32-bit fixed-point values");
926        Value* fixed = nullptr;
927#if 1   // KNOB_SIM_FAST_MATH?  Below works correctly from a precision
928        // standpoint...
929        {
930            fixed = FP_TO_UI(VROUND(FMUL(vFloat, VIMMED1(float(1 << numFracBits))),
931                                    C(_MM_FROUND_TO_NEAREST_INT)),
932                             mSimdInt32Ty);
933        }
934#else
935        {
936            // Do round to nearest int on fractional bits first
937            vFloat = VROUND(FMUL(vFloat, VIMMED1(float(1 << numFracBits))),
938                            C(_MM_FROUND_TO_NEAREST_INT));
939            vFloat = FMUL(vFloat, VIMMED1(1.0f / float(1 << numFracBits)));
940
941            // TODO: Handle INF, NAN, overflow / underflow, etc.
942
943            Value* vSgn      = FCMP_OLT(vFloat, VIMMED1(0.0f));
944            Value* vFloatInt = BITCAST(vFloat, mSimdInt32Ty);
945            Value* vFixed    = AND(vFloatInt, VIMMED1((1 << 23) - 1));
946            vFixed           = OR(vFixed, VIMMED1(1 << 23));
947
948            Value* vExp = LSHR(SHL(vFloatInt, VIMMED1(1)), VIMMED1(24));
949            vExp        = SUB(vExp, VIMMED1(127));
950
951            Value* vExtraBits = SUB(VIMMED1(23 - numFracBits), vExp);
952
953            fixed = LSHR(vFixed, vExtraBits, name);
954        }
955#endif
956        return fixed;
957    }
958
959    Value* Builder::VCVT_FIXED_UI_F32(Value*             vFixed,
960                                      uint32_t           numIntBits,
961                                      uint32_t           numFracBits,
962                                      const llvm::Twine& name)
963    {
964        SWR_ASSERT((numIntBits + numFracBits) <= 32, "Can only handle 32-bit fixed-point values");
965        uint32_t extraBits = 32 - numIntBits - numFracBits;
966        if (numIntBits && extraBits)
967        {
968            // Sign extend
969            Value* shftAmt = VIMMED1(extraBits);
970            vFixed         = ASHR(SHL(vFixed, shftAmt), shftAmt);
971        }
972
973        Value* fVal  = VIMMED1(0.0f);
974        Value* fFrac = VIMMED1(0.0f);
975        if (numIntBits)
976        {
977            fVal = UI_TO_FP(LSHR(vFixed, VIMMED1(numFracBits)), mSimdFP32Ty, name);
978        }
979
980        if (numFracBits)
981        {
982            fFrac = UI_TO_FP(AND(vFixed, VIMMED1((1 << numFracBits) - 1)), mSimdFP32Ty);
983            fFrac = FDIV(fFrac, VIMMED1(float(1 << numFracBits)), name);
984        }
985
986        return FADD(fVal, fFrac, name);
987    }
988
989    //////////////////////////////////////////////////////////////////////////
990    /// @brief C functions called by LLVM IR
991    //////////////////////////////////////////////////////////////////////////
992
993    Value* Builder::VEXTRACTI128(Value* a, Constant* imm8)
994    {
995        bool                      flag = !imm8->isZeroValue();
996        SmallVector<Constant*, 8> idx;
997        for (unsigned i = 0; i < mVWidth / 2; i++)
998        {
999            idx.push_back(C(flag ? i + mVWidth / 2 : i));
1000        }
1001        return VSHUFFLE(a, VUNDEF_I(), ConstantVector::get(idx));
1002    }
1003
1004    Value* Builder::VINSERTI128(Value* a, Value* b, Constant* imm8)
1005    {
1006        bool                      flag = !imm8->isZeroValue();
1007        SmallVector<Constant*, 8> idx;
1008        for (unsigned i = 0; i < mVWidth; i++)
1009        {
1010            idx.push_back(C(i));
1011        }
1012        Value* inter = VSHUFFLE(b, VUNDEF_I(), ConstantVector::get(idx));
1013
1014        SmallVector<Constant*, 8> idx2;
1015        for (unsigned i = 0; i < mVWidth / 2; i++)
1016        {
1017            idx2.push_back(C(flag ? i : i + mVWidth));
1018        }
1019        for (unsigned i = mVWidth / 2; i < mVWidth; i++)
1020        {
1021            idx2.push_back(C(flag ? i + mVWidth / 2 : i));
1022        }
1023        return VSHUFFLE(a, inter, ConstantVector::get(idx2));
1024    }
1025
1026    // rdtsc buckets macros
1027    void Builder::RDTSC_START(Value* pBucketMgr, Value* pId)
1028    {
1029        // @todo due to an issue with thread local storage propagation in llvm, we can only safely
1030        // call into buckets framework when single threaded
1031        if (KNOB_SINGLE_THREADED)
1032        {
1033            std::vector<Type*> args{
1034                PointerType::get(mInt32Ty, 0), // pBucketMgr
1035                mInt32Ty                       // id
1036            };
1037
1038            FunctionType* pFuncTy = FunctionType::get(Type::getVoidTy(JM()->mContext), args, false);
1039            Function*     pFunc   = cast<Function>(
1040#if LLVM_VERSION_MAJOR >= 9
1041                JM()->mpCurrentModule->getOrInsertFunction("BucketManager_StartBucket", pFuncTy).getCallee());
1042#else
1043                JM()->mpCurrentModule->getOrInsertFunction("BucketManager_StartBucket", pFuncTy));
1044#endif
1045            if (sys::DynamicLibrary::SearchForAddressOfSymbol("BucketManager_StartBucket") ==
1046                nullptr)
1047            {
1048                sys::DynamicLibrary::AddSymbol("BucketManager_StartBucket",
1049                                               (void*)&BucketManager_StartBucket);
1050            }
1051
1052            CALL(pFunc, {pBucketMgr, pId});
1053        }
1054    }
1055
1056    void Builder::RDTSC_STOP(Value* pBucketMgr, Value* pId)
1057    {
1058        // @todo due to an issue with thread local storage propagation in llvm, we can only safely
1059        // call into buckets framework when single threaded
1060        if (KNOB_SINGLE_THREADED)
1061        {
1062            std::vector<Type*> args{
1063                PointerType::get(mInt32Ty, 0), // pBucketMgr
1064                mInt32Ty                       // id
1065            };
1066
1067            FunctionType* pFuncTy = FunctionType::get(Type::getVoidTy(JM()->mContext), args, false);
1068            Function*     pFunc   = cast<Function>(
1069#if LLVM_VERSION_MAJOR >= 9
1070                JM()->mpCurrentModule->getOrInsertFunction("BucketManager_StopBucket", pFuncTy).getCallee());
1071#else
1072                JM()->mpCurrentModule->getOrInsertFunction("BucketManager_StopBucket", pFuncTy));
1073#endif
1074            if (sys::DynamicLibrary::SearchForAddressOfSymbol("BucketManager_StopBucket") ==
1075                nullptr)
1076            {
1077                sys::DynamicLibrary::AddSymbol("BucketManager_StopBucket",
1078                                               (void*)&BucketManager_StopBucket);
1079            }
1080
1081            CALL(pFunc, {pBucketMgr, pId});
1082        }
1083    }
1084
1085    uint32_t Builder::GetTypeSize(Type* pType)
1086    {
1087        if (pType->isStructTy())
1088        {
1089            uint32_t numElems = pType->getStructNumElements();
1090            Type*    pElemTy  = pType->getStructElementType(0);
1091            return numElems * GetTypeSize(pElemTy);
1092        }
1093
1094        if (pType->isArrayTy())
1095        {
1096            uint32_t numElems = pType->getArrayNumElements();
1097            Type*    pElemTy  = pType->getArrayElementType();
1098            return numElems * GetTypeSize(pElemTy);
1099        }
1100
1101        if (pType->isIntegerTy())
1102        {
1103            uint32_t bitSize = pType->getIntegerBitWidth();
1104            return bitSize / 8;
1105        }
1106
1107        if (pType->isFloatTy())
1108        {
1109            return 4;
1110        }
1111
1112        if (pType->isHalfTy())
1113        {
1114            return 2;
1115        }
1116
1117        if (pType->isDoubleTy())
1118        {
1119            return 8;
1120        }
1121
1122        SWR_ASSERT(false, "Unimplemented type.");
1123        return 0;
1124    }
1125} // namespace SwrJit
1126