1/****************************************************************************
2 * Copyright (C) 2014-2015 Intel Corporation.   All Rights Reserved.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * @file streamout_jit.cpp
24 *
25 * @brief Implementation of the streamout jitter
26 *
27 * Notes:
28 *
29 ******************************************************************************/
30#include "jit_pch.hpp"
31#include "builder_gfx_mem.h"
32#include "jit_api.h"
33#include "streamout_jit.h"
34#include "gen_state_llvm.h"
35#include "functionpasses/passes.h"
36
37using namespace llvm;
38using namespace SwrJit;
39
40//////////////////////////////////////////////////////////////////////////
/// Interface to JIT-compiling a streamout shader
42//////////////////////////////////////////////////////////////////////////
43struct StreamOutJit : public BuilderGfxMem
44{
45    StreamOutJit(JitManager* pJitMgr) : BuilderGfxMem(pJitMgr){};
46
47    // returns pointer to SWR_STREAMOUT_BUFFER
48    Value* getSOBuffer(Value* pSoCtx, uint32_t buffer)
49    {
50        return LOAD(pSoCtx, {0, SWR_STREAMOUT_CONTEXT_pBuffer, buffer});
51    }
52
53    //////////////////////////////////////////////////////////////////////////
54    // @brief checks if streamout buffer is oob
55    // @return <i1> true/false
56    Value* oob(const STREAMOUT_COMPILE_STATE& state, Value* pSoCtx, uint32_t buffer)
57    {
58        Value* returnMask = C(false);
59
60        Value* pBuf = getSOBuffer(pSoCtx, buffer);
61
62        // load enable
63        // @todo bool data types should generate <i1> llvm type
64        Value* enabled = TRUNC(LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_enable}), IRB()->getInt1Ty());
65
66        // load buffer size
67        Value* bufferSize = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_bufferSize});
68
69        // load current streamOffset
70        Value* streamOffset = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_streamOffset});
71
72        // load buffer pitch
73        Value* pitch = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_pitch});
74
75        // buffer is considered oob if in use in a decl but not enabled
76        returnMask = OR(returnMask, NOT(enabled));
77
78        // buffer is oob if cannot fit a prims worth of verts
79        Value* newOffset = ADD(streamOffset, MUL(pitch, C(state.numVertsPerPrim)));
80        returnMask       = OR(returnMask, ICMP_SGT(newOffset, bufferSize));
81
82        return returnMask;
83    }
84
85    //////////////////////////////////////////////////////////////////////////
86    // @brief converts scalar bitmask to <4 x i32> suitable for shuffle vector,
87    //        packing the active mask bits
88    //        ex. bitmask 0011 -> (0, 1, 0, 0)
89    //            bitmask 1000 -> (3, 0, 0, 0)
90    //            bitmask 1100 -> (2, 3, 0, 0)
91    Value* PackMask(uint32_t bitmask)
92    {
93        std::vector<Constant*> indices(4, C(0));
94        DWORD                  index;
95        uint32_t               elem = 0;
96        while (_BitScanForward(&index, bitmask))
97        {
98            indices[elem++] = C((int)index);
99            bitmask &= ~(1 << index);
100        }
101
102        return ConstantVector::get(indices);
103    }
104
105    //////////////////////////////////////////////////////////////////////////
106    // @brief convert scalar bitmask to <4xfloat> bitmask
107    Value* ToMask(uint32_t bitmask)
108    {
109        std::vector<Constant*> indices;
110        for (uint32_t i = 0; i < 4; ++i)
111        {
112            if (bitmask & (1 << i))
113            {
114                indices.push_back(C(true));
115            }
116            else
117            {
118                indices.push_back(C(false));
119            }
120        }
121        return ConstantVector::get(indices);
122    }
123
124    //////////////////////////////////////////////////////////////////////////
125    // @brief processes a single decl from the streamout stream. Reads 4 components from the input
126    //        stream and writes N components to the output buffer given the componentMask or if
127    //        a hole, just increments the buffer pointer
128    // @param pStream - pointer to current attribute
129    // @param pOutBuffers - pointers to the current location of each output buffer
130    // @param decl - input decl
131    void buildDecl(Value* pStream, Value* pOutBuffers[4], const STREAMOUT_DECL& decl)
132    {
133        uint32_t numComponents = _mm_popcnt_u32(decl.componentMask);
134        uint32_t packedMask    = (1 << numComponents) - 1;
135        if (!decl.hole)
136        {
137            // increment stream pointer to correct slot
138            Value* pAttrib = GEP(pStream, C(4 * decl.attribSlot));
139
140            // load 4 components from stream
141            Type* simd4Ty    = VectorType::get(IRB()->getFloatTy(), 4);
142            Type* simd4PtrTy = PointerType::get(simd4Ty, 0);
143            pAttrib          = BITCAST(pAttrib, simd4PtrTy);
144            Value* vattrib   = LOAD(pAttrib);
145
146            // shuffle/pack enabled components
147            Value* vpackedAttrib = VSHUFFLE(vattrib, vattrib, PackMask(decl.componentMask));
148
149            // store to output buffer
150            // cast SO buffer to i8*, needed by maskstore
151            Value* pOut = BITCAST(pOutBuffers[decl.bufferIndex], PointerType::get(simd4Ty, 0));
152
153            // cast input to <4xfloat>
154            Value* src = BITCAST(vpackedAttrib, simd4Ty);
155
156            // cast mask to <4xi1>
157            Value* mask = ToMask(packedMask);
158            MASKED_STORE(src, pOut, 4, mask, PointerType::get(simd4Ty, 0), JIT_MEM_CLIENT::GFX_MEM_CLIENT_STREAMOUT);
159        }
160
161        // increment SO buffer
162        pOutBuffers[decl.bufferIndex] = GEP(pOutBuffers[decl.bufferIndex], C(numComponents));
163    }
164
165    //////////////////////////////////////////////////////////////////////////
166    // @brief builds a single vertex worth of data for the given stream
167    // @param streamState - state for this stream
168    // @param pCurVertex - pointer to src stream vertex data
169    // @param pOutBuffer - pointers to up to 4 SO buffers
170    void buildVertex(const STREAMOUT_STREAM& streamState, Value* pCurVertex, Value* pOutBuffer[4])
171    {
172        for (uint32_t d = 0; d < streamState.numDecls; ++d)
173        {
174            const STREAMOUT_DECL& decl = streamState.decl[d];
175            buildDecl(pCurVertex, pOutBuffer, decl);
176        }
177    }
178
179    void buildStream(const STREAMOUT_COMPILE_STATE& state,
180                     const STREAMOUT_STREAM&        streamState,
181                     Value*                         pSoCtx,
182                     BasicBlock*                    returnBB,
183                     Function*                      soFunc)
184    {
185        // get list of active SO buffers
186        std::unordered_set<uint32_t> activeSOBuffers;
187        for (uint32_t d = 0; d < streamState.numDecls; ++d)
188        {
189            const STREAMOUT_DECL& decl = streamState.decl[d];
190            activeSOBuffers.insert(decl.bufferIndex);
191        }
192
193        // always increment numPrimStorageNeeded
194        Value* numPrimStorageNeeded = LOAD(pSoCtx, {0, SWR_STREAMOUT_CONTEXT_numPrimStorageNeeded});
195        numPrimStorageNeeded        = ADD(numPrimStorageNeeded, C(1));
196        STORE(numPrimStorageNeeded, pSoCtx, {0, SWR_STREAMOUT_CONTEXT_numPrimStorageNeeded});
197
198        // check OOB on active SO buffers.  If any buffer is out of bound, don't write
199        // the primitive to any buffer
200        Value* oobMask = C(false);
201        for (uint32_t buffer : activeSOBuffers)
202        {
203            oobMask = OR(oobMask, oob(state, pSoCtx, buffer));
204        }
205
206        BasicBlock* validBB = BasicBlock::Create(JM()->mContext, "valid", soFunc);
207
208        // early out if OOB
209        COND_BR(oobMask, returnBB, validBB);
210
211        IRB()->SetInsertPoint(validBB);
212
213        Value* numPrimsWritten = LOAD(pSoCtx, {0, SWR_STREAMOUT_CONTEXT_numPrimsWritten});
214        numPrimsWritten        = ADD(numPrimsWritten, C(1));
215        STORE(numPrimsWritten, pSoCtx, {0, SWR_STREAMOUT_CONTEXT_numPrimsWritten});
216
217        // compute start pointer for each output buffer
218        Value* pOutBuffer[4];
219        Value* pOutBufferStartVertex[4];
220        Value* outBufferPitch[4];
221        for (uint32_t b : activeSOBuffers)
222        {
223            Value* pBuf              = getSOBuffer(pSoCtx, b);
224            Value* pData             = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_pBuffer});
225            Value* streamOffset      = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_streamOffset});
226            pOutBuffer[b] = GEP(pData, streamOffset, PointerType::get(IRB()->getInt32Ty(), 0));
227            pOutBufferStartVertex[b] = pOutBuffer[b];
228
229            outBufferPitch[b] = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_pitch});
230        }
231
232        // loop over the vertices of the prim
233        Value* pStreamData = LOAD(pSoCtx, {0, SWR_STREAMOUT_CONTEXT_pPrimData});
234        for (uint32_t v = 0; v < state.numVertsPerPrim; ++v)
235        {
236            buildVertex(streamState, pStreamData, pOutBuffer);
237
238            // increment stream and output buffer pointers
239            // stream verts are always 32*4 dwords apart
240            pStreamData = GEP(pStreamData, C(SWR_VTX_NUM_SLOTS * 4));
241
242            // output buffers offset using pitch in buffer state
243            for (uint32_t b : activeSOBuffers)
244            {
245                pOutBufferStartVertex[b] = GEP(pOutBufferStartVertex[b], outBufferPitch[b]);
246                pOutBuffer[b]            = pOutBufferStartVertex[b];
247            }
248        }
249
250        // update each active buffer's streamOffset
251        for (uint32_t b : activeSOBuffers)
252        {
253            Value* pBuf         = getSOBuffer(pSoCtx, b);
254            Value* streamOffset = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_streamOffset});
255            streamOffset = ADD(streamOffset, MUL(C(state.numVertsPerPrim), outBufferPitch[b]));
256            STORE(streamOffset, pBuf, {0, SWR_STREAMOUT_BUFFER_streamOffset});
257        }
258    }
259
260    Function* Create(const STREAMOUT_COMPILE_STATE& state)
261    {
262        std::stringstream fnName("SO_",
263                                 std::ios_base::in | std::ios_base::out | std::ios_base::ate);
264        fnName << ComputeCRC(0, &state, sizeof(state));
265
266        // SO function signature
267        // typedef void(__cdecl *PFN_SO_FUNC)(SimDrawContext, SWR_STREAMOUT_CONTEXT*)
268
269        Type* typeParam0;
270        typeParam0 = mInt8PtrTy;
271
272        std::vector<Type*> args{
273            typeParam0,
274            PointerType::get(Gen_SWR_STREAMOUT_CONTEXT(JM()), 0), // SWR_STREAMOUT_CONTEXT*
275        };
276
277        FunctionType* fTy    = FunctionType::get(IRB()->getVoidTy(), args, false);
278        Function*     soFunc = Function::Create(
279            fTy, GlobalValue::ExternalLinkage, fnName.str(), JM()->mpCurrentModule);
280
281        soFunc->getParent()->setModuleIdentifier(soFunc->getName());
282
283        // create return basic block
284        BasicBlock* entry    = BasicBlock::Create(JM()->mContext, "entry", soFunc);
285        BasicBlock* returnBB = BasicBlock::Create(JM()->mContext, "return", soFunc);
286
287        IRB()->SetInsertPoint(entry);
288
289        // arguments
290        auto   argitr = soFunc->arg_begin();
291
292        Value* privateContext = &*argitr++;
293        privateContext->setName("privateContext");
294        SetPrivateContext(privateContext);
295
296        Value* pSoCtx = &*argitr++;
297        pSoCtx->setName("pSoCtx");
298
299        const STREAMOUT_STREAM& streamState = state.stream;
300        buildStream(state, streamState, pSoCtx, returnBB, soFunc);
301
302        BR(returnBB);
303
304        IRB()->SetInsertPoint(returnBB);
305        RET_VOID();
306
307        JitManager::DumpToFile(soFunc, "SoFunc");
308
309        ::FunctionPassManager passes(JM()->mpCurrentModule);
310
311        passes.add(createBreakCriticalEdgesPass());
312        passes.add(createCFGSimplificationPass());
313        passes.add(createEarlyCSEPass());
314        passes.add(createPromoteMemoryToRegisterPass());
315        passes.add(createCFGSimplificationPass());
316        passes.add(createEarlyCSEPass());
317        passes.add(createInstructionCombiningPass());
318        passes.add(createConstantPropagationPass());
319        passes.add(createSCCPPass());
320        passes.add(createAggressiveDCEPass());
321
322        passes.add(createLowerX86Pass(this));
323
324        passes.run(*soFunc);
325
326        JitManager::DumpToFile(soFunc, "SoFunc_optimized");
327
328
329        return soFunc;
330    }
331};
332
333//////////////////////////////////////////////////////////////////////////
334/// @brief JITs from streamout shader IR
335/// @param hJitMgr - JitManager handle
/// @param hFunc   - handle to the LLVM function IR (an llvm::Function*)
/// @return PFN_SO_FUNC - pointer to the JIT-compiled streamout function
338PFN_SO_FUNC JitStreamoutFunc(HANDLE hJitMgr, const HANDLE hFunc)
339{
340    llvm::Function* func    = (llvm::Function*)hFunc;
341    JitManager*     pJitMgr = reinterpret_cast<JitManager*>(hJitMgr);
342    PFN_SO_FUNC     pfnStreamOut;
343    pfnStreamOut = (PFN_SO_FUNC)(pJitMgr->mpExec->getFunctionAddress(func->getName().str()));
344    // MCJIT finalizes modules the first time you JIT code from them. After finalized, you cannot
345    // add new IR to the module
346    pJitMgr->mIsModuleFinalized = true;
347
348    pJitMgr->DumpAsm(func, "SoFunc_optimized");
349
350
351    return pfnStreamOut;
352}
353
354//////////////////////////////////////////////////////////////////////////
355/// @brief JIT compiles streamout shader
356/// @param hJitMgr - JitManager handle
/// @param state   - SO state to build function from
/// @return PFN_SO_FUNC - pointer to the JIT-compiled streamout function
358extern "C" PFN_SO_FUNC JITCALL JitCompileStreamout(HANDLE                         hJitMgr,
359                                                   const STREAMOUT_COMPILE_STATE& state)
360{
361    JitManager* pJitMgr = reinterpret_cast<JitManager*>(hJitMgr);
362
363    STREAMOUT_COMPILE_STATE soState = state;
364    if (soState.offsetAttribs)
365    {
366        for (uint32_t i = 0; i < soState.stream.numDecls; ++i)
367        {
368            soState.stream.decl[i].attribSlot -= soState.offsetAttribs;
369        }
370    }
371
372    pJitMgr->SetupNewModule();
373
374    StreamOutJit theJit(pJitMgr);
375    HANDLE       hFunc = theJit.Create(soState);
376
377    return JitStreamoutFunc(hJitMgr, hFunc);
378}
379