17ec681f3Smrg/*
27ec681f3Smrg * Copyright © Microsoft Corporation
37ec681f3Smrg *
47ec681f3Smrg * Permission is hereby granted, free of charge, to any person obtaining a
57ec681f3Smrg * copy of this software and associated documentation files (the "Software"),
67ec681f3Smrg * to deal in the Software without restriction, including without limitation
77ec681f3Smrg * the rights to use, copy, modify, merge, publish, distribute, sublicense,
87ec681f3Smrg * and/or sell copies of the Software, and to permit persons to whom the
97ec681f3Smrg * Software is furnished to do so, subject to the following conditions:
107ec681f3Smrg *
117ec681f3Smrg * The above copyright notice and this permission notice (including the next
127ec681f3Smrg * paragraph) shall be included in all copies or substantial portions of the
137ec681f3Smrg * Software.
147ec681f3Smrg *
157ec681f3Smrg * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
167ec681f3Smrg * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
177ec681f3Smrg * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
187ec681f3Smrg * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
197ec681f3Smrg * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
207ec681f3Smrg * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
217ec681f3Smrg * IN THE SOFTWARE.
227ec681f3Smrg */
237ec681f3Smrg
247ec681f3Smrg#include <stdio.h>
257ec681f3Smrg#include <stdint.h>
267ec681f3Smrg#include <stdexcept>
277ec681f3Smrg
287ec681f3Smrg#include <directx/d3d12.h>
297ec681f3Smrg#include <dxgi1_4.h>
307ec681f3Smrg#include <gtest/gtest.h>
317ec681f3Smrg#include <wrl.h>
327ec681f3Smrg
337ec681f3Smrg#include "util/u_debug.h"
347ec681f3Smrg#include "clc_compiler.h"
357ec681f3Smrg#include "compute_test.h"
367ec681f3Smrg#include "dxcapi.h"
377ec681f3Smrg
387ec681f3Smrg#include <spirv-tools/libspirv.hpp>
397ec681f3Smrg
407ec681f3Smrgusing std::runtime_error;
417ec681f3Smrgusing Microsoft::WRL::ComPtr;
427ec681f3Smrg
/*
 * Bit flags selectable through the COMPUTE_TEST_DEBUG environment variable.
 * Each enumerator occupies its own bit so the values can be OR-ed together.
 */
enum compute_test_debug_flags {
   /* Enable D3D12 experimental shader models when creating the device. */
   COMPUTE_DEBUG_EXPERIMENTAL_SHADERS = 0x1,
   /* Pick a hardware adapter instead of the WARP adapter. */
   COMPUTE_DEBUG_USE_HW_D3D           = 0x2,
   /* Request an optimized libclc when it is created. */
   COMPUTE_DEBUG_OPTIMIZE_LIBCLC      = 0x4,
   /* Round-trip libclc through serialize/deserialize after creation. */
   COMPUTE_DEBUG_SERIALIZE_LIBCLC     = 0x8,
};
497ec681f3Smrg
507ec681f3Smrgstatic const struct debug_named_value compute_debug_options[] = {
517ec681f3Smrg   { "experimental_shaders",  COMPUTE_DEBUG_EXPERIMENTAL_SHADERS, "Enable experimental shaders" },
527ec681f3Smrg   { "use_hw_d3d",            COMPUTE_DEBUG_USE_HW_D3D,           "Use a hardware D3D device"   },
537ec681f3Smrg   { "optimize_libclc",       COMPUTE_DEBUG_OPTIMIZE_LIBCLC,      "Optimize the clc_libclc before using it" },
547ec681f3Smrg   { "serialize_libclc",      COMPUTE_DEBUG_SERIALIZE_LIBCLC,     "Serialize and deserialize the clc_libclc" },
557ec681f3Smrg   DEBUG_NAMED_VALUE_END
567ec681f3Smrg};
577ec681f3Smrg
587ec681f3SmrgDEBUG_GET_ONCE_FLAGS_OPTION(debug_compute, "COMPUTE_TEST_DEBUG", compute_debug_options, 0)
597ec681f3Smrg
607ec681f3Smrgstatic void warning_callback(void *priv, const char *msg)
617ec681f3Smrg{
627ec681f3Smrg   fprintf(stderr, "WARNING: %s\n", msg);
637ec681f3Smrg}
647ec681f3Smrg
657ec681f3Smrgstatic void error_callback(void *priv, const char *msg)
667ec681f3Smrg{
677ec681f3Smrg   fprintf(stderr, "ERROR: %s\n", msg);
687ec681f3Smrg}
697ec681f3Smrg
707ec681f3Smrgstatic const struct clc_logger logger = {
717ec681f3Smrg   NULL,
727ec681f3Smrg   error_callback,
737ec681f3Smrg   warning_callback,
747ec681f3Smrg};
757ec681f3Smrg
767ec681f3Smrgvoid
777ec681f3SmrgComputeTest::enable_d3d12_debug_layer()
787ec681f3Smrg{
797ec681f3Smrg   HMODULE hD3D12Mod = LoadLibrary("D3D12.DLL");
807ec681f3Smrg   if (!hD3D12Mod) {
817ec681f3Smrg      fprintf(stderr, "D3D12: failed to load D3D12.DLL\n");
827ec681f3Smrg      return;
837ec681f3Smrg   }
847ec681f3Smrg
857ec681f3Smrg   typedef HRESULT(WINAPI * PFN_D3D12_GET_DEBUG_INTERFACE)(REFIID riid,
867ec681f3Smrg                                                           void **ppFactory);
877ec681f3Smrg   PFN_D3D12_GET_DEBUG_INTERFACE D3D12GetDebugInterface = (PFN_D3D12_GET_DEBUG_INTERFACE)GetProcAddress(hD3D12Mod, "D3D12GetDebugInterface");
887ec681f3Smrg   if (!D3D12GetDebugInterface) {
897ec681f3Smrg      fprintf(stderr, "D3D12: failed to load D3D12GetDebugInterface from D3D12.DLL\n");
907ec681f3Smrg      return;
917ec681f3Smrg   }
927ec681f3Smrg
937ec681f3Smrg   ID3D12Debug *debug;
947ec681f3Smrg   if (FAILED(D3D12GetDebugInterface(__uuidof(ID3D12Debug), (void **)& debug))) {
957ec681f3Smrg      fprintf(stderr, "D3D12: D3D12GetDebugInterface failed\n");
967ec681f3Smrg      return;
977ec681f3Smrg   }
987ec681f3Smrg
997ec681f3Smrg   debug->EnableDebugLayer();
1007ec681f3Smrg}
1017ec681f3Smrg
1027ec681f3SmrgIDXGIFactory4 *
1037ec681f3SmrgComputeTest::get_dxgi_factory()
1047ec681f3Smrg{
1057ec681f3Smrg   static const GUID IID_IDXGIFactory4 = {
1067ec681f3Smrg      0x1bc6ea02, 0xef36, 0x464f,
1077ec681f3Smrg      { 0xbf, 0x0c, 0x21, 0xca, 0x39, 0xe5, 0x16, 0x8a }
1087ec681f3Smrg   };
1097ec681f3Smrg
1107ec681f3Smrg   typedef HRESULT(WINAPI * PFN_CREATE_DXGI_FACTORY)(REFIID riid,
1117ec681f3Smrg                                                     void **ppFactory);
1127ec681f3Smrg   PFN_CREATE_DXGI_FACTORY CreateDXGIFactory;
1137ec681f3Smrg
1147ec681f3Smrg   HMODULE hDXGIMod = LoadLibrary("DXGI.DLL");
1157ec681f3Smrg   if (!hDXGIMod)
1167ec681f3Smrg      throw runtime_error("Failed to load DXGI.DLL");
1177ec681f3Smrg
1187ec681f3Smrg   CreateDXGIFactory = (PFN_CREATE_DXGI_FACTORY)GetProcAddress(hDXGIMod, "CreateDXGIFactory");
1197ec681f3Smrg   if (!CreateDXGIFactory)
1207ec681f3Smrg      throw runtime_error("Failed to load CreateDXGIFactory from DXGI.DLL");
1217ec681f3Smrg
1227ec681f3Smrg   IDXGIFactory4 *factory = NULL;
1237ec681f3Smrg   HRESULT hr = CreateDXGIFactory(IID_IDXGIFactory4, (void **)&factory);
1247ec681f3Smrg   if (FAILED(hr))
1257ec681f3Smrg      throw runtime_error("CreateDXGIFactory failed");
1267ec681f3Smrg
1277ec681f3Smrg   return factory;
1287ec681f3Smrg}
1297ec681f3Smrg
1307ec681f3SmrgIDXGIAdapter1 *
1317ec681f3SmrgComputeTest::choose_adapter(IDXGIFactory4 *factory)
1327ec681f3Smrg{
1337ec681f3Smrg   IDXGIAdapter1 *ret;
1347ec681f3Smrg
1357ec681f3Smrg   if (debug_get_option_debug_compute() & COMPUTE_DEBUG_USE_HW_D3D) {
1367ec681f3Smrg      for (unsigned i = 0; SUCCEEDED(factory->EnumAdapters1(i, &ret)); i++) {
1377ec681f3Smrg         DXGI_ADAPTER_DESC1 desc;
1387ec681f3Smrg         ret->GetDesc1(&desc);
1397ec681f3Smrg         if (!(desc.Flags & D3D_DRIVER_TYPE_SOFTWARE))
1407ec681f3Smrg            return ret;
1417ec681f3Smrg      }
1427ec681f3Smrg      throw runtime_error("Failed to enum hardware adapter");
1437ec681f3Smrg   } else {
1447ec681f3Smrg      if (FAILED(factory->EnumWarpAdapter(__uuidof(IDXGIAdapter1),
1457ec681f3Smrg         (void **)& ret)))
1467ec681f3Smrg         throw runtime_error("Failed to enum warp adapter");
1477ec681f3Smrg      return ret;
1487ec681f3Smrg   }
1497ec681f3Smrg}
1507ec681f3Smrg
1517ec681f3SmrgID3D12Device *
1527ec681f3SmrgComputeTest::create_device(IDXGIAdapter1 *adapter)
1537ec681f3Smrg{
1547ec681f3Smrg   typedef HRESULT(WINAPI *PFN_D3D12CREATEDEVICE)(IUnknown *, D3D_FEATURE_LEVEL, REFIID, void **);
1557ec681f3Smrg   PFN_D3D12CREATEDEVICE D3D12CreateDevice;
1567ec681f3Smrg
1577ec681f3Smrg   HMODULE hD3D12Mod = LoadLibrary("D3D12.DLL");
1587ec681f3Smrg   if (!hD3D12Mod)
1597ec681f3Smrg      throw runtime_error("failed to load D3D12.DLL");
1607ec681f3Smrg
1617ec681f3Smrg   if (debug_get_option_debug_compute() & COMPUTE_DEBUG_EXPERIMENTAL_SHADERS) {
1627ec681f3Smrg      typedef HRESULT(WINAPI *PFN_D3D12ENABLEEXPERIMENTALFEATURES)(UINT, const IID *, void *, UINT *);
1637ec681f3Smrg      PFN_D3D12ENABLEEXPERIMENTALFEATURES D3D12EnableExperimentalFeatures;
1647ec681f3Smrg      D3D12EnableExperimentalFeatures = (PFN_D3D12ENABLEEXPERIMENTALFEATURES)
1657ec681f3Smrg         GetProcAddress(hD3D12Mod, "D3D12EnableExperimentalFeatures");
1667ec681f3Smrg      if (FAILED(D3D12EnableExperimentalFeatures(1, &D3D12ExperimentalShaderModels, NULL, NULL)))
1677ec681f3Smrg         throw runtime_error("failed to enable experimental shader models");
1687ec681f3Smrg   }
1697ec681f3Smrg
1707ec681f3Smrg   D3D12CreateDevice = (PFN_D3D12CREATEDEVICE)GetProcAddress(hD3D12Mod, "D3D12CreateDevice");
1717ec681f3Smrg   if (!D3D12CreateDevice)
1727ec681f3Smrg      throw runtime_error("failed to load D3D12CreateDevice from D3D12.DLL");
1737ec681f3Smrg
1747ec681f3Smrg   ID3D12Device *dev;
1757ec681f3Smrg   if (FAILED(D3D12CreateDevice(adapter, D3D_FEATURE_LEVEL_12_0,
1767ec681f3Smrg       __uuidof(ID3D12Device), (void **)& dev)))
1777ec681f3Smrg      throw runtime_error("D3D12CreateDevice failed");
1787ec681f3Smrg
1797ec681f3Smrg   return dev;
1807ec681f3Smrg}
1817ec681f3Smrg
1827ec681f3SmrgComPtr<ID3D12RootSignature>
1837ec681f3SmrgComputeTest::create_root_signature(const ComputeTest::Resources &resources)
1847ec681f3Smrg{
1857ec681f3Smrg   D3D12_ROOT_PARAMETER1 root_param;
1867ec681f3Smrg   root_param.ParameterType = D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE;
1877ec681f3Smrg   root_param.DescriptorTable.NumDescriptorRanges = resources.ranges.size();
1887ec681f3Smrg   root_param.DescriptorTable.pDescriptorRanges = resources.ranges.data();
1897ec681f3Smrg   root_param.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
1907ec681f3Smrg
1917ec681f3Smrg   D3D12_ROOT_SIGNATURE_DESC1 root_sig_desc;
1927ec681f3Smrg   root_sig_desc.NumParameters = 1;
1937ec681f3Smrg   root_sig_desc.pParameters = &root_param;
1947ec681f3Smrg   root_sig_desc.NumStaticSamplers = 0;
1957ec681f3Smrg   root_sig_desc.pStaticSamplers = NULL;
1967ec681f3Smrg   root_sig_desc.Flags = D3D12_ROOT_SIGNATURE_FLAG_NONE;
1977ec681f3Smrg
1987ec681f3Smrg   D3D12_VERSIONED_ROOT_SIGNATURE_DESC versioned_desc;
1997ec681f3Smrg   versioned_desc.Version = D3D_ROOT_SIGNATURE_VERSION_1_1;
2007ec681f3Smrg   versioned_desc.Desc_1_1 = root_sig_desc;
2017ec681f3Smrg
2027ec681f3Smrg   ID3DBlob *sig, *error;
2037ec681f3Smrg   if (FAILED(D3D12SerializeVersionedRootSignature(&versioned_desc,
2047ec681f3Smrg       &sig, &error)))
2057ec681f3Smrg      throw runtime_error("D3D12SerializeVersionedRootSignature failed");
2067ec681f3Smrg
2077ec681f3Smrg   ComPtr<ID3D12RootSignature> ret;
2087ec681f3Smrg   if (FAILED(dev->CreateRootSignature(0,
2097ec681f3Smrg       sig->GetBufferPointer(),
2107ec681f3Smrg       sig->GetBufferSize(),
2117ec681f3Smrg       __uuidof(ret),
2127ec681f3Smrg       (void **)& ret)))
2137ec681f3Smrg      throw runtime_error("CreateRootSignature failed");
2147ec681f3Smrg
2157ec681f3Smrg   return ret;
2167ec681f3Smrg}
2177ec681f3Smrg
2187ec681f3SmrgComPtr<ID3D12PipelineState>
2197ec681f3SmrgComputeTest::create_pipeline_state(ComPtr<ID3D12RootSignature> &root_sig,
2207ec681f3Smrg                                   const struct clc_dxil_object &dxil)
2217ec681f3Smrg{
2227ec681f3Smrg   D3D12_COMPUTE_PIPELINE_STATE_DESC pipeline_desc = { root_sig.Get() };
2237ec681f3Smrg   pipeline_desc.CS.pShaderBytecode = dxil.binary.data;
2247ec681f3Smrg   pipeline_desc.CS.BytecodeLength = dxil.binary.size;
2257ec681f3Smrg
2267ec681f3Smrg   ComPtr<ID3D12PipelineState> pipeline_state;
2277ec681f3Smrg   if (FAILED(dev->CreateComputePipelineState(&pipeline_desc,
2287ec681f3Smrg                                              __uuidof(pipeline_state),
2297ec681f3Smrg                                              (void **)& pipeline_state)))
2307ec681f3Smrg      throw runtime_error("Failed to create pipeline state");
2317ec681f3Smrg   return pipeline_state;
2327ec681f3Smrg}
2337ec681f3Smrg
2347ec681f3SmrgComPtr<ID3D12Resource>
2357ec681f3SmrgComputeTest::create_buffer(int size, D3D12_HEAP_TYPE heap_type)
2367ec681f3Smrg{
2377ec681f3Smrg   D3D12_RESOURCE_DESC desc;
2387ec681f3Smrg   desc.Format = DXGI_FORMAT_UNKNOWN;
2397ec681f3Smrg   desc.Alignment = D3D12_DEFAULT_RESOURCE_PLACEMENT_ALIGNMENT;
2407ec681f3Smrg   desc.Width = size;
2417ec681f3Smrg   desc.Height = 1;
2427ec681f3Smrg   desc.DepthOrArraySize = 1;
2437ec681f3Smrg   desc.MipLevels = 1;
2447ec681f3Smrg   desc.SampleDesc.Count = 1;
2457ec681f3Smrg   desc.SampleDesc.Quality = 0;
2467ec681f3Smrg   desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
2477ec681f3Smrg   desc.Flags = heap_type == D3D12_HEAP_TYPE_DEFAULT ? D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS : D3D12_RESOURCE_FLAG_NONE;
2487ec681f3Smrg   desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
2497ec681f3Smrg
2507ec681f3Smrg   D3D12_HEAP_PROPERTIES heap_pris = dev->GetCustomHeapProperties(0, heap_type);
2517ec681f3Smrg
2527ec681f3Smrg   ComPtr<ID3D12Resource> res;
2537ec681f3Smrg   if (FAILED(dev->CreateCommittedResource(&heap_pris,
2547ec681f3Smrg       D3D12_HEAP_FLAG_NONE, &desc, D3D12_RESOURCE_STATE_COMMON,
2557ec681f3Smrg       NULL, __uuidof(ID3D12Resource), (void **)&res)))
2567ec681f3Smrg      throw runtime_error("CreateCommittedResource failed");
2577ec681f3Smrg
2587ec681f3Smrg   return res;
2597ec681f3Smrg}
2607ec681f3Smrg
2617ec681f3SmrgComPtr<ID3D12Resource>
2627ec681f3SmrgComputeTest::create_upload_buffer_with_data(const void *data, size_t size)
2637ec681f3Smrg{
2647ec681f3Smrg   auto upload_res = create_buffer(size, D3D12_HEAP_TYPE_UPLOAD);
2657ec681f3Smrg
2667ec681f3Smrg   void *ptr = NULL;
2677ec681f3Smrg   D3D12_RANGE res_range = { 0, (SIZE_T)size };
2687ec681f3Smrg   if (FAILED(upload_res->Map(0, &res_range, (void **)&ptr)))
2697ec681f3Smrg      throw runtime_error("Failed to map upload-buffer");
2707ec681f3Smrg   assert(ptr);
2717ec681f3Smrg   memcpy(ptr, data, size);
2727ec681f3Smrg   upload_res->Unmap(0, &res_range);
2737ec681f3Smrg   return upload_res;
2747ec681f3Smrg}
2757ec681f3Smrg
2767ec681f3SmrgComPtr<ID3D12Resource>
2777ec681f3SmrgComputeTest::create_sized_buffer_with_data(size_t buffer_size,
2787ec681f3Smrg                                           const void *data,
2797ec681f3Smrg                                           size_t data_size)
2807ec681f3Smrg{
2817ec681f3Smrg   auto upload_res = create_upload_buffer_with_data(data, data_size);
2827ec681f3Smrg
2837ec681f3Smrg   auto res = create_buffer(buffer_size, D3D12_HEAP_TYPE_DEFAULT);
2847ec681f3Smrg   resource_barrier(res, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_COPY_DEST);
2857ec681f3Smrg   cmdlist->CopyBufferRegion(res.Get(), 0, upload_res.Get(), 0, data_size);
2867ec681f3Smrg   resource_barrier(res, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_COMMON);
2877ec681f3Smrg   execute_cmdlist();
2887ec681f3Smrg
2897ec681f3Smrg   return res;
2907ec681f3Smrg}
2917ec681f3Smrg
2927ec681f3Smrgvoid
2937ec681f3SmrgComputeTest::get_buffer_data(ComPtr<ID3D12Resource> res,
2947ec681f3Smrg                             void *buf, size_t size)
2957ec681f3Smrg{
2967ec681f3Smrg   auto readback_res = create_buffer(align(size, 4), D3D12_HEAP_TYPE_READBACK);
2977ec681f3Smrg   resource_barrier(res, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_COPY_SOURCE);
2987ec681f3Smrg   cmdlist->CopyResource(readback_res.Get(), res.Get());
2997ec681f3Smrg   resource_barrier(res, D3D12_RESOURCE_STATE_COPY_SOURCE, D3D12_RESOURCE_STATE_COMMON);
3007ec681f3Smrg   execute_cmdlist();
3017ec681f3Smrg
3027ec681f3Smrg   void *ptr = NULL;
3037ec681f3Smrg   D3D12_RANGE res_range = { 0, size };
3047ec681f3Smrg   if (FAILED(readback_res->Map(0, &res_range, &ptr)))
3057ec681f3Smrg      throw runtime_error("Failed to map readback-buffer");
3067ec681f3Smrg
3077ec681f3Smrg   memcpy(buf, ptr, size);
3087ec681f3Smrg
3097ec681f3Smrg   D3D12_RANGE empty_range = { 0, 0 };
3107ec681f3Smrg   readback_res->Unmap(0, &empty_range);
3117ec681f3Smrg}
3127ec681f3Smrg
3137ec681f3Smrgvoid
3147ec681f3SmrgComputeTest::resource_barrier(ComPtr<ID3D12Resource> &res,
3157ec681f3Smrg                              D3D12_RESOURCE_STATES state_before,
3167ec681f3Smrg                              D3D12_RESOURCE_STATES state_after)
3177ec681f3Smrg{
3187ec681f3Smrg   D3D12_RESOURCE_BARRIER barrier;
3197ec681f3Smrg   barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_TRANSITION;
3207ec681f3Smrg   barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE;
3217ec681f3Smrg   barrier.Transition.pResource = res.Get();
3227ec681f3Smrg   barrier.Transition.Subresource = D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES;
3237ec681f3Smrg   barrier.Transition.StateBefore = state_before;
3247ec681f3Smrg   barrier.Transition.StateAfter = state_after;
3257ec681f3Smrg   cmdlist->ResourceBarrier(1, &barrier);
3267ec681f3Smrg}
3277ec681f3Smrg
3287ec681f3Smrgvoid
3297ec681f3SmrgComputeTest::execute_cmdlist()
3307ec681f3Smrg{
3317ec681f3Smrg   if (FAILED(cmdlist->Close()))
3327ec681f3Smrg      throw runtime_error("Closing ID3D12GraphicsCommandList failed");
3337ec681f3Smrg
3347ec681f3Smrg   ID3D12CommandList *cmdlists[] = { cmdlist };
3357ec681f3Smrg   cmdqueue->ExecuteCommandLists(1, cmdlists);
3367ec681f3Smrg   cmdqueue_fence->SetEventOnCompletion(fence_value, event);
3377ec681f3Smrg   cmdqueue->Signal(cmdqueue_fence, fence_value);
3387ec681f3Smrg   fence_value++;
3397ec681f3Smrg   WaitForSingleObject(event, INFINITE);
3407ec681f3Smrg
3417ec681f3Smrg   if (FAILED(cmdalloc->Reset()))
3427ec681f3Smrg      throw runtime_error("resetting ID3D12CommandAllocator failed");
3437ec681f3Smrg
3447ec681f3Smrg   if (FAILED(cmdlist->Reset(cmdalloc, NULL)))
3457ec681f3Smrg      throw runtime_error("resetting ID3D12GraphicsCommandList failed");
3467ec681f3Smrg}
3477ec681f3Smrg
3487ec681f3Smrgvoid
3497ec681f3SmrgComputeTest::create_uav_buffer(ComPtr<ID3D12Resource> res,
3507ec681f3Smrg                               size_t width, size_t byte_stride,
3517ec681f3Smrg                               D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle)
3527ec681f3Smrg{
3537ec681f3Smrg   D3D12_UNORDERED_ACCESS_VIEW_DESC uav_desc;
3547ec681f3Smrg   uav_desc.Format = DXGI_FORMAT_R32_TYPELESS;
3557ec681f3Smrg   uav_desc.ViewDimension = D3D12_UAV_DIMENSION_BUFFER;
3567ec681f3Smrg   uav_desc.Buffer.FirstElement = 0;
3577ec681f3Smrg   uav_desc.Buffer.NumElements = DIV_ROUND_UP(width * byte_stride, 4);
3587ec681f3Smrg   uav_desc.Buffer.StructureByteStride = 0;
3597ec681f3Smrg   uav_desc.Buffer.CounterOffsetInBytes = 0;
3607ec681f3Smrg   uav_desc.Buffer.Flags = D3D12_BUFFER_UAV_FLAG_RAW;
3617ec681f3Smrg
3627ec681f3Smrg   dev->CreateUnorderedAccessView(res.Get(), NULL, &uav_desc, cpu_handle);
3637ec681f3Smrg}
3647ec681f3Smrg
3657ec681f3Smrgvoid
3667ec681f3SmrgComputeTest::create_cbv(ComPtr<ID3D12Resource> res, size_t size,
3677ec681f3Smrg                        D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle)
3687ec681f3Smrg{
3697ec681f3Smrg   D3D12_CONSTANT_BUFFER_VIEW_DESC cbv_desc;
3707ec681f3Smrg   cbv_desc.BufferLocation = res ? res->GetGPUVirtualAddress() : 0;
3717ec681f3Smrg   cbv_desc.SizeInBytes = size;
3727ec681f3Smrg
3737ec681f3Smrg   dev->CreateConstantBufferView(&cbv_desc, cpu_handle);
3747ec681f3Smrg}
3757ec681f3Smrg
3767ec681f3SmrgComPtr<ID3D12Resource>
3777ec681f3SmrgComputeTest::add_uav_resource(ComputeTest::Resources &resources,
3787ec681f3Smrg                              unsigned spaceid, unsigned resid,
3797ec681f3Smrg                              const void *data, size_t num_elems,
3807ec681f3Smrg                              size_t elem_size)
3817ec681f3Smrg{
3827ec681f3Smrg   size_t size = align(elem_size * num_elems, 4);
3837ec681f3Smrg   D3D12_CPU_DESCRIPTOR_HANDLE handle;
3847ec681f3Smrg   ComPtr<ID3D12Resource> res;
3857ec681f3Smrg   handle = uav_heap->GetCPUDescriptorHandleForHeapStart();
3867ec681f3Smrg   handle = offset_cpu_handle(handle, resources.descs.size() * uav_heap_incr);
3877ec681f3Smrg
3887ec681f3Smrg   if (size) {
3897ec681f3Smrg      if (data)
3907ec681f3Smrg         res = create_buffer_with_data(data, size);
3917ec681f3Smrg      else
3927ec681f3Smrg         res = create_buffer(size, D3D12_HEAP_TYPE_DEFAULT);
3937ec681f3Smrg
3947ec681f3Smrg      resource_barrier(res, D3D12_RESOURCE_STATE_COMMON,
3957ec681f3Smrg                       D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
3967ec681f3Smrg   }
3977ec681f3Smrg   create_uav_buffer(res, num_elems, elem_size, handle);
3987ec681f3Smrg   resources.add(res, D3D12_DESCRIPTOR_RANGE_TYPE_UAV, spaceid, resid);
3997ec681f3Smrg   return res;
4007ec681f3Smrg}
4017ec681f3Smrg
4027ec681f3SmrgComPtr<ID3D12Resource>
4037ec681f3SmrgComputeTest::add_cbv_resource(ComputeTest::Resources &resources,
4047ec681f3Smrg                              unsigned spaceid, unsigned resid,
4057ec681f3Smrg                              const void *data, size_t size)
4067ec681f3Smrg{
4077ec681f3Smrg   unsigned aligned_size = align(size, 256);
4087ec681f3Smrg   D3D12_CPU_DESCRIPTOR_HANDLE handle;
4097ec681f3Smrg   ComPtr<ID3D12Resource> res;
4107ec681f3Smrg   handle = uav_heap->GetCPUDescriptorHandleForHeapStart();
4117ec681f3Smrg   handle = offset_cpu_handle(handle, resources.descs.size() * uav_heap_incr);
4127ec681f3Smrg
4137ec681f3Smrg   if (size) {
4147ec681f3Smrg     assert(data);
4157ec681f3Smrg     res = create_sized_buffer_with_data(aligned_size, data, size);
4167ec681f3Smrg   }
4177ec681f3Smrg   create_cbv(res, aligned_size, handle);
4187ec681f3Smrg   resources.add(res, D3D12_DESCRIPTOR_RANGE_TYPE_CBV, spaceid, resid);
4197ec681f3Smrg   return res;
4207ec681f3Smrg}
4217ec681f3Smrg
4227ec681f3Smrgvoid
4237ec681f3SmrgComputeTest::run_shader_with_raw_args(Shader shader,
4247ec681f3Smrg                                      const CompileArgs &compile_args,
4257ec681f3Smrg                                      const std::vector<RawShaderArg *> &args)
4267ec681f3Smrg{
4277ec681f3Smrg   if (args.size() < 1)
4287ec681f3Smrg      throw runtime_error("no inputs");
4297ec681f3Smrg
4307ec681f3Smrg   static HMODULE hD3D12Mod = LoadLibrary("D3D12.DLL");
4317ec681f3Smrg   if (!hD3D12Mod)
4327ec681f3Smrg      throw runtime_error("Failed to load D3D12.DLL");
4337ec681f3Smrg
4347ec681f3Smrg   D3D12SerializeVersionedRootSignature = (PFN_D3D12_SERIALIZE_VERSIONED_ROOT_SIGNATURE)GetProcAddress(hD3D12Mod, "D3D12SerializeVersionedRootSignature");
4357ec681f3Smrg
4367ec681f3Smrg   if (args.size() != shader.dxil->kernel->num_args)
4377ec681f3Smrg      throw runtime_error("incorrect number of inputs");
4387ec681f3Smrg
4397ec681f3Smrg   struct clc_runtime_kernel_conf conf = { 0 };
4407ec681f3Smrg
4417ec681f3Smrg   // Older WARP and some hardware doesn't support int64, so for these tests, unconditionally lower away int64
4427ec681f3Smrg   // A more complex runtime can be smarter about detecting when this needs to be done
4437ec681f3Smrg   conf.lower_bit_size = 64;
4447ec681f3Smrg
4457ec681f3Smrg   if (!shader.dxil->metadata.local_size[0])
4467ec681f3Smrg      conf.local_size[0] = compile_args.x;
4477ec681f3Smrg   else
4487ec681f3Smrg      conf.local_size[0] = shader.dxil->metadata.local_size[0];
4497ec681f3Smrg
4507ec681f3Smrg   if (!shader.dxil->metadata.local_size[1])
4517ec681f3Smrg      conf.local_size[1] = compile_args.y;
4527ec681f3Smrg   else
4537ec681f3Smrg      conf.local_size[1] = shader.dxil->metadata.local_size[1];
4547ec681f3Smrg
4557ec681f3Smrg   if (!shader.dxil->metadata.local_size[2])
4567ec681f3Smrg      conf.local_size[2] = compile_args.z;
4577ec681f3Smrg   else
4587ec681f3Smrg      conf.local_size[2] = shader.dxil->metadata.local_size[2];
4597ec681f3Smrg
4607ec681f3Smrg   if (compile_args.x % conf.local_size[0] ||
4617ec681f3Smrg       compile_args.y % conf.local_size[1] ||
4627ec681f3Smrg       compile_args.z % conf.local_size[2])
4637ec681f3Smrg      throw runtime_error("invalid global size must be a multiple of local size");
4647ec681f3Smrg
4657ec681f3Smrg   std::vector<struct clc_runtime_arg_info> argsinfo(args.size());
4667ec681f3Smrg
4677ec681f3Smrg   conf.args = argsinfo.data();
4687ec681f3Smrg   conf.support_global_work_id_offsets =
4697ec681f3Smrg      compile_args.work_props.global_offset_x != 0 ||
4707ec681f3Smrg      compile_args.work_props.global_offset_y != 0 ||
4717ec681f3Smrg      compile_args.work_props.global_offset_z != 0;
4727ec681f3Smrg   conf.support_workgroup_id_offsets =
4737ec681f3Smrg      compile_args.work_props.group_id_offset_x != 0 ||
4747ec681f3Smrg      compile_args.work_props.group_id_offset_y != 0 ||
4757ec681f3Smrg      compile_args.work_props.group_id_offset_z != 0;
4767ec681f3Smrg
4777ec681f3Smrg   for (unsigned i = 0; i < shader.dxil->kernel->num_args; ++i) {
4787ec681f3Smrg      RawShaderArg *arg = args[i];
4797ec681f3Smrg      size_t size = arg->get_elem_size() * arg->get_num_elems();
4807ec681f3Smrg
4817ec681f3Smrg      switch (shader.dxil->kernel->args[i].address_qualifier) {
4827ec681f3Smrg      case CLC_KERNEL_ARG_ADDRESS_LOCAL:
4837ec681f3Smrg         argsinfo[i].localptr.size = size;
4847ec681f3Smrg         break;
4857ec681f3Smrg      default:
4867ec681f3Smrg         break;
4877ec681f3Smrg      }
4887ec681f3Smrg   }
4897ec681f3Smrg
4907ec681f3Smrg   configure(shader, &conf);
4917ec681f3Smrg   validate(shader);
4927ec681f3Smrg
4937ec681f3Smrg   std::shared_ptr<struct clc_dxil_object> &dxil = shader.dxil;
4947ec681f3Smrg
4957ec681f3Smrg   std::vector<uint8_t> argsbuf(dxil->metadata.kernel_inputs_buf_size);
4967ec681f3Smrg   std::vector<ComPtr<ID3D12Resource>> argres(shader.dxil->kernel->num_args);
4977ec681f3Smrg   clc_work_properties_data work_props = compile_args.work_props;
4987ec681f3Smrg   if (!conf.support_workgroup_id_offsets) {
4997ec681f3Smrg      work_props.group_count_total_x = compile_args.x / conf.local_size[0];
5007ec681f3Smrg      work_props.group_count_total_y = compile_args.y / conf.local_size[1];
5017ec681f3Smrg      work_props.group_count_total_z = compile_args.z / conf.local_size[2];
5027ec681f3Smrg   }
5037ec681f3Smrg   if (work_props.work_dim == 0)
5047ec681f3Smrg      work_props.work_dim = 3;
5057ec681f3Smrg   Resources resources;
5067ec681f3Smrg
5077ec681f3Smrg   for (unsigned i = 0; i < dxil->kernel->num_args; ++i) {
5087ec681f3Smrg      RawShaderArg *arg = args[i];
5097ec681f3Smrg      size_t size = arg->get_elem_size() * arg->get_num_elems();
5107ec681f3Smrg      void *slot = argsbuf.data() + dxil->metadata.args[i].offset;
5117ec681f3Smrg
5127ec681f3Smrg      switch (dxil->kernel->args[i].address_qualifier) {
5137ec681f3Smrg      case CLC_KERNEL_ARG_ADDRESS_CONSTANT:
5147ec681f3Smrg      case CLC_KERNEL_ARG_ADDRESS_GLOBAL: {
5157ec681f3Smrg         assert(dxil->metadata.args[i].size == sizeof(uint64_t));
5167ec681f3Smrg         uint64_t *ptr_slot = (uint64_t *)slot;
5177ec681f3Smrg         if (arg->get_data())
5187ec681f3Smrg            *ptr_slot = (uint64_t)dxil->metadata.args[i].globconstptr.buf_id << 32;
5197ec681f3Smrg         else
5207ec681f3Smrg            *ptr_slot = ~0ull;
5217ec681f3Smrg         break;
5227ec681f3Smrg      }
5237ec681f3Smrg      case CLC_KERNEL_ARG_ADDRESS_LOCAL: {
5247ec681f3Smrg         assert(dxil->metadata.args[i].size == sizeof(uint64_t));
5257ec681f3Smrg         uint64_t *ptr_slot = (uint64_t *)slot;
5267ec681f3Smrg         *ptr_slot = dxil->metadata.args[i].localptr.sharedmem_offset;
5277ec681f3Smrg         break;
5287ec681f3Smrg      }
5297ec681f3Smrg      case CLC_KERNEL_ARG_ADDRESS_PRIVATE: {
5307ec681f3Smrg         assert(size == dxil->metadata.args[i].size);
5317ec681f3Smrg         memcpy(slot, arg->get_data(), size);
5327ec681f3Smrg         break;
5337ec681f3Smrg      }
5347ec681f3Smrg      default:
5357ec681f3Smrg         assert(0);
5367ec681f3Smrg      }
5377ec681f3Smrg   }
5387ec681f3Smrg
5397ec681f3Smrg   for (unsigned i = 0; i < dxil->kernel->num_args; ++i) {
5407ec681f3Smrg      RawShaderArg *arg = args[i];
5417ec681f3Smrg
5427ec681f3Smrg      if (dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_GLOBAL ||
5437ec681f3Smrg          dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_CONSTANT) {
5447ec681f3Smrg         argres[i] = add_uav_resource(resources, 0,
5457ec681f3Smrg                                      dxil->metadata.args[i].globconstptr.buf_id,
5467ec681f3Smrg                                      arg->get_data(), arg->get_num_elems(),
5477ec681f3Smrg                                      arg->get_elem_size());
5487ec681f3Smrg      }
5497ec681f3Smrg   }
5507ec681f3Smrg
5517ec681f3Smrg   if (dxil->metadata.printf.uav_id > 0)
5527ec681f3Smrg      add_uav_resource(resources, 0, dxil->metadata.printf.uav_id, NULL, 1024 * 1024 / 4, 4);
5537ec681f3Smrg
5547ec681f3Smrg   for (unsigned i = 0; i < dxil->metadata.num_consts; ++i)
5557ec681f3Smrg      add_uav_resource(resources, 0, dxil->metadata.consts[i].uav_id,
5567ec681f3Smrg                       dxil->metadata.consts[i].data,
5577ec681f3Smrg                       dxil->metadata.consts[i].size / 4, 4);
5587ec681f3Smrg
5597ec681f3Smrg   if (argsbuf.size())
5607ec681f3Smrg      add_cbv_resource(resources, 0, dxil->metadata.kernel_inputs_cbv_id,
5617ec681f3Smrg                       argsbuf.data(), argsbuf.size());
5627ec681f3Smrg
5637ec681f3Smrg   add_cbv_resource(resources, 0, dxil->metadata.work_properties_cbv_id,
5647ec681f3Smrg                    &work_props, sizeof(work_props));
5657ec681f3Smrg
5667ec681f3Smrg   auto root_sig = create_root_signature(resources);
5677ec681f3Smrg   auto pipeline_state = create_pipeline_state(root_sig, *dxil);
5687ec681f3Smrg
5697ec681f3Smrg   cmdlist->SetDescriptorHeaps(1, &uav_heap);
5707ec681f3Smrg   cmdlist->SetComputeRootSignature(root_sig.Get());
5717ec681f3Smrg   cmdlist->SetComputeRootDescriptorTable(0, uav_heap->GetGPUDescriptorHandleForHeapStart());
5727ec681f3Smrg   cmdlist->SetPipelineState(pipeline_state.Get());
5737ec681f3Smrg
5747ec681f3Smrg   cmdlist->Dispatch(compile_args.x / conf.local_size[0],
5757ec681f3Smrg                     compile_args.y / conf.local_size[1],
5767ec681f3Smrg                     compile_args.z / conf.local_size[2]);
5777ec681f3Smrg
5787ec681f3Smrg   for (auto &range : resources.ranges) {
5797ec681f3Smrg      if (range.RangeType == D3D12_DESCRIPTOR_RANGE_TYPE_UAV) {
5807ec681f3Smrg         for (unsigned i = range.OffsetInDescriptorsFromTableStart;
5817ec681f3Smrg              i < range.NumDescriptors; i++) {
5827ec681f3Smrg            if (!resources.descs[i].Get())
5837ec681f3Smrg               continue;
5847ec681f3Smrg
5857ec681f3Smrg            resource_barrier(resources.descs[i],
5867ec681f3Smrg                             D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
5877ec681f3Smrg                             D3D12_RESOURCE_STATE_COMMON);
5887ec681f3Smrg         }
5897ec681f3Smrg      }
5907ec681f3Smrg   }
5917ec681f3Smrg
5927ec681f3Smrg   execute_cmdlist();
5937ec681f3Smrg
5947ec681f3Smrg   for (unsigned i = 0; i < args.size(); i++) {
5957ec681f3Smrg      if (!(args[i]->get_direction() & SHADER_ARG_OUTPUT))
5967ec681f3Smrg         continue;
5977ec681f3Smrg
5987ec681f3Smrg      assert(dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_GLOBAL);
5997ec681f3Smrg      get_buffer_data(argres[i], args[i]->get_data(),
6007ec681f3Smrg                      args[i]->get_elem_size() * args[i]->get_num_elems());
6017ec681f3Smrg   }
6027ec681f3Smrg
6037ec681f3Smrg   ComPtr<ID3D12InfoQueue> info_queue;
6047ec681f3Smrg   dev->QueryInterface(info_queue.ReleaseAndGetAddressOf());
6057ec681f3Smrg   if (info_queue)
6067ec681f3Smrg   {
6077ec681f3Smrg      EXPECT_EQ(0, info_queue->GetNumStoredMessages());
6087ec681f3Smrg      for (unsigned i = 0; i < info_queue->GetNumStoredMessages(); ++i) {
6097ec681f3Smrg         SIZE_T message_size = 0;
6107ec681f3Smrg         info_queue->GetMessageA(i, nullptr, &message_size);
6117ec681f3Smrg         D3D12_MESSAGE* message = (D3D12_MESSAGE*)malloc(message_size);
6127ec681f3Smrg         info_queue->GetMessageA(i, message, &message_size);
6137ec681f3Smrg         FAIL() << message->pDescription;
6147ec681f3Smrg         free(message);
6157ec681f3Smrg      }
6167ec681f3Smrg   }
6177ec681f3Smrg}
6187ec681f3Smrg
6197ec681f3Smrgvoid
6207ec681f3SmrgComputeTest::SetUp()
6217ec681f3Smrg{
6227ec681f3Smrg   static struct clc_libclc *compiler_ctx_g = nullptr;
6237ec681f3Smrg
6247ec681f3Smrg   if (!compiler_ctx_g) {
6257ec681f3Smrg      clc_libclc_dxil_options options = { };
6267ec681f3Smrg      options.optimize = (debug_get_option_debug_compute() & COMPUTE_DEBUG_OPTIMIZE_LIBCLC) != 0;
6277ec681f3Smrg
6287ec681f3Smrg      compiler_ctx_g = clc_libclc_new_dxil(&logger, &options);
6297ec681f3Smrg      if (!compiler_ctx_g)
6307ec681f3Smrg         throw runtime_error("failed to create CLC compiler context");
6317ec681f3Smrg
6327ec681f3Smrg      if (debug_get_option_debug_compute() & COMPUTE_DEBUG_SERIALIZE_LIBCLC) {
6337ec681f3Smrg         void *serialized = nullptr;
6347ec681f3Smrg         size_t serialized_size = 0;
6357ec681f3Smrg         clc_libclc_serialize(compiler_ctx_g, &serialized, &serialized_size);
6367ec681f3Smrg         if (!serialized)
6377ec681f3Smrg            throw runtime_error("failed to serialize CLC compiler context");
6387ec681f3Smrg
6397ec681f3Smrg         clc_free_libclc(compiler_ctx_g);
6407ec681f3Smrg         compiler_ctx_g = nullptr;
6417ec681f3Smrg
6427ec681f3Smrg         compiler_ctx_g = clc_libclc_deserialize(serialized, serialized_size);
6437ec681f3Smrg         if (!compiler_ctx_g)
6447ec681f3Smrg            throw runtime_error("failed to deserialize CLC compiler context");
6457ec681f3Smrg
6467ec681f3Smrg         clc_libclc_free_serialized(serialized);
6477ec681f3Smrg      }
6487ec681f3Smrg   }
6497ec681f3Smrg   compiler_ctx = compiler_ctx_g;
6507ec681f3Smrg
6517ec681f3Smrg   enable_d3d12_debug_layer();
6527ec681f3Smrg
6537ec681f3Smrg   factory = get_dxgi_factory();
6547ec681f3Smrg   if (!factory)
6557ec681f3Smrg      throw runtime_error("failed to create DXGI factory");
6567ec681f3Smrg
6577ec681f3Smrg   adapter = choose_adapter(factory);
6587ec681f3Smrg   if (!adapter)
6597ec681f3Smrg      throw runtime_error("failed to choose adapter");
6607ec681f3Smrg
6617ec681f3Smrg   dev = create_device(adapter);
6627ec681f3Smrg   if (!dev)
6637ec681f3Smrg      throw runtime_error("failed to create device");
6647ec681f3Smrg
6657ec681f3Smrg   if (FAILED(dev->CreateFence(0, D3D12_FENCE_FLAG_NONE,
6667ec681f3Smrg                               __uuidof(cmdqueue_fence),
6677ec681f3Smrg                               (void **)&cmdqueue_fence)))
6687ec681f3Smrg      throw runtime_error("failed to create fence\n");
6697ec681f3Smrg
6707ec681f3Smrg   D3D12_COMMAND_QUEUE_DESC queue_desc;
6717ec681f3Smrg   queue_desc.Type = D3D12_COMMAND_LIST_TYPE_COMPUTE;
6727ec681f3Smrg   queue_desc.Priority = D3D12_COMMAND_QUEUE_PRIORITY_NORMAL;
6737ec681f3Smrg   queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
6747ec681f3Smrg   queue_desc.NodeMask = 0;
6757ec681f3Smrg   if (FAILED(dev->CreateCommandQueue(&queue_desc,
6767ec681f3Smrg                                      __uuidof(cmdqueue),
6777ec681f3Smrg                                      (void **)&cmdqueue)))
6787ec681f3Smrg      throw runtime_error("failed to create command queue");
6797ec681f3Smrg
6807ec681f3Smrg   if (FAILED(dev->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_COMPUTE,
6817ec681f3Smrg             __uuidof(cmdalloc), (void **)&cmdalloc)))
6827ec681f3Smrg      throw runtime_error("failed to create command allocator");
6837ec681f3Smrg
6847ec681f3Smrg   if (FAILED(dev->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_COMPUTE,
6857ec681f3Smrg             cmdalloc, NULL, __uuidof(cmdlist), (void **)&cmdlist)))
6867ec681f3Smrg      throw runtime_error("failed to create command list");
6877ec681f3Smrg
6887ec681f3Smrg   D3D12_DESCRIPTOR_HEAP_DESC heap_desc;
6897ec681f3Smrg   heap_desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
6907ec681f3Smrg   heap_desc.NumDescriptors = 1000;
6917ec681f3Smrg   heap_desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE;
6927ec681f3Smrg   heap_desc.NodeMask = 0;
6937ec681f3Smrg   if (FAILED(dev->CreateDescriptorHeap(&heap_desc,
6947ec681f3Smrg       __uuidof(uav_heap), (void **)&uav_heap)))
6957ec681f3Smrg      throw runtime_error("failed to create descriptor heap");
6967ec681f3Smrg
6977ec681f3Smrg   uav_heap_incr = dev->GetDescriptorHandleIncrementSize(D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
6987ec681f3Smrg
6997ec681f3Smrg   event = CreateEvent(NULL, FALSE, FALSE, NULL);
7007ec681f3Smrg   if (!event)
7017ec681f3Smrg      throw runtime_error("Failed to create event");
7027ec681f3Smrg   fence_value = 1;
7037ec681f3Smrg}
7047ec681f3Smrg
7057ec681f3Smrgvoid
7067ec681f3SmrgComputeTest::TearDown()
7077ec681f3Smrg{
7087ec681f3Smrg   CloseHandle(event);
7097ec681f3Smrg
7107ec681f3Smrg   uav_heap->Release();
7117ec681f3Smrg   cmdlist->Release();
7127ec681f3Smrg   cmdalloc->Release();
7137ec681f3Smrg   cmdqueue->Release();
7147ec681f3Smrg   cmdqueue_fence->Release();
7157ec681f3Smrg   dev->Release();
7167ec681f3Smrg   adapter->Release();
7177ec681f3Smrg   factory->Release();
7187ec681f3Smrg}
7197ec681f3Smrg
7207ec681f3SmrgPFN_D3D12_SERIALIZE_VERSIONED_ROOT_SIGNATURE ComputeTest::D3D12SerializeVersionedRootSignature;
7217ec681f3Smrg
7227ec681f3Smrgbool
7237ec681f3Smrgvalidate_module(const struct clc_dxil_object &dxil)
7247ec681f3Smrg{
7257ec681f3Smrg   static HMODULE hmod = LoadLibrary("DXIL.DLL");
7267ec681f3Smrg   if (!hmod) {
7277ec681f3Smrg      /* Enabling experimental shaders allows us to run unsigned shader code,
7287ec681f3Smrg       * such as when under the debugger where we can't run the validator. */
7297ec681f3Smrg      if (debug_get_option_debug_compute() & COMPUTE_DEBUG_EXPERIMENTAL_SHADERS)
7307ec681f3Smrg         return true;
7317ec681f3Smrg      else
7327ec681f3Smrg         throw runtime_error("failed to load DXIL.DLL");
7337ec681f3Smrg   }
7347ec681f3Smrg
7357ec681f3Smrg   DxcCreateInstanceProc pfnDxcCreateInstance =
7367ec681f3Smrg      (DxcCreateInstanceProc)GetProcAddress(hmod, "DxcCreateInstance");
7377ec681f3Smrg   if (!pfnDxcCreateInstance)
7387ec681f3Smrg      throw runtime_error("failed to load DxcCreateInstance");
7397ec681f3Smrg
7407ec681f3Smrg   struct shader_blob : public IDxcBlob {
7417ec681f3Smrg      shader_blob(void *data, size_t size) : data(data), size(size) {}
7427ec681f3Smrg      LPVOID STDMETHODCALLTYPE GetBufferPointer() override { return data; }
7437ec681f3Smrg      SIZE_T STDMETHODCALLTYPE GetBufferSize() override { return size; }
7447ec681f3Smrg      HRESULT STDMETHODCALLTYPE QueryInterface(REFIID, void **) override { return E_NOINTERFACE; }
7457ec681f3Smrg      ULONG STDMETHODCALLTYPE AddRef() override { return 1; }
7467ec681f3Smrg      ULONG STDMETHODCALLTYPE Release() override { return 0; }
7477ec681f3Smrg      void *data;
7487ec681f3Smrg      size_t size;
7497ec681f3Smrg   } blob(dxil.binary.data, dxil.binary.size);
7507ec681f3Smrg
7517ec681f3Smrg   IDxcValidator *validator;
7527ec681f3Smrg   if (FAILED(pfnDxcCreateInstance(CLSID_DxcValidator, __uuidof(IDxcValidator),
7537ec681f3Smrg                                   (void **)&validator)))
7547ec681f3Smrg      throw runtime_error("failed to create IDxcValidator");
7557ec681f3Smrg
7567ec681f3Smrg   IDxcOperationResult *result;
7577ec681f3Smrg   if (FAILED(validator->Validate(&blob, DxcValidatorFlags_InPlaceEdit,
7587ec681f3Smrg                                  &result)))
7597ec681f3Smrg      throw runtime_error("Validate failed");
7607ec681f3Smrg
7617ec681f3Smrg   HRESULT hr;
7627ec681f3Smrg   if (FAILED(result->GetStatus(&hr)) ||
7637ec681f3Smrg       FAILED(hr)) {
7647ec681f3Smrg      IDxcBlobEncoding *message;
7657ec681f3Smrg      result->GetErrorBuffer(&message);
7667ec681f3Smrg      fprintf(stderr, "D3D12: validation failed: %*s\n",
7677ec681f3Smrg                   (int)message->GetBufferSize(),
7687ec681f3Smrg                   (char *)message->GetBufferPointer());
7697ec681f3Smrg      message->Release();
7707ec681f3Smrg      validator->Release();
7717ec681f3Smrg      result->Release();
7727ec681f3Smrg      return false;
7737ec681f3Smrg   }
7747ec681f3Smrg
7757ec681f3Smrg   validator->Release();
7767ec681f3Smrg   result->Release();
7777ec681f3Smrg   return true;
7787ec681f3Smrg}
7797ec681f3Smrg
7807ec681f3Smrgstatic void
7817ec681f3Smrgdump_blob(const char *path, const struct clc_dxil_object &dxil)
7827ec681f3Smrg{
7837ec681f3Smrg   FILE *fp = fopen(path, "wb");
7847ec681f3Smrg   if (fp) {
7857ec681f3Smrg      fwrite(dxil.binary.data, 1, dxil.binary.size, fp);
7867ec681f3Smrg      fclose(fp);
7877ec681f3Smrg      printf("D3D12: wrote '%s'...\n", path);
7887ec681f3Smrg   }
7897ec681f3Smrg}
7907ec681f3Smrg
7917ec681f3SmrgComputeTest::Shader
7927ec681f3SmrgComputeTest::compile(const std::vector<const char *> &sources,
7937ec681f3Smrg                     const std::vector<const char *> &compile_args,
7947ec681f3Smrg                     bool create_library)
7957ec681f3Smrg{
7967ec681f3Smrg   struct clc_compile_args args = {
7977ec681f3Smrg   };
7987ec681f3Smrg   args.args = compile_args.data();
7997ec681f3Smrg   args.num_args = (unsigned)compile_args.size();
8007ec681f3Smrg   ComputeTest::Shader shader;
8017ec681f3Smrg
8027ec681f3Smrg   std::vector<Shader> shaders;
8037ec681f3Smrg
8047ec681f3Smrg   args.source.name = "obj.cl";
8057ec681f3Smrg
8067ec681f3Smrg   for (unsigned i = 0; i < sources.size(); i++) {
8077ec681f3Smrg      args.source.value = sources[i];
8087ec681f3Smrg
8097ec681f3Smrg      clc_binary spirv{};
8107ec681f3Smrg      if (!clc_compile_c_to_spirv(&args, &logger, &spirv))
8117ec681f3Smrg         throw runtime_error("failed to compile object!");
8127ec681f3Smrg
8137ec681f3Smrg      Shader shader;
8147ec681f3Smrg      shader.obj = std::shared_ptr<clc_binary>(new clc_binary(spirv), [](clc_binary *spirv)
8157ec681f3Smrg         {
8167ec681f3Smrg            clc_free_spirv(spirv);
8177ec681f3Smrg            delete spirv;
8187ec681f3Smrg         });
8197ec681f3Smrg      shaders.push_back(shader);
8207ec681f3Smrg   }
8217ec681f3Smrg
8227ec681f3Smrg   if (shaders.size() == 1 && create_library)
8237ec681f3Smrg      return shaders[0];
8247ec681f3Smrg
8257ec681f3Smrg   return link(shaders, create_library);
8267ec681f3Smrg}
8277ec681f3Smrg
8287ec681f3SmrgComputeTest::Shader
8297ec681f3SmrgComputeTest::link(const std::vector<Shader> &sources,
8307ec681f3Smrg                  bool create_library)
8317ec681f3Smrg{
8327ec681f3Smrg   std::vector<const clc_binary*> objs;
8337ec681f3Smrg   for (auto& source : sources)
8347ec681f3Smrg      objs.push_back(&*source.obj);
8357ec681f3Smrg
8367ec681f3Smrg   struct clc_linker_args link_args = {};
8377ec681f3Smrg   link_args.in_objs = objs.data();
8387ec681f3Smrg   link_args.num_in_objs = (unsigned)objs.size();
8397ec681f3Smrg   link_args.create_library = create_library;
8407ec681f3Smrg   clc_binary spirv{};
8417ec681f3Smrg   if (!clc_link_spirv(&link_args, &logger, &spirv))
8427ec681f3Smrg      throw runtime_error("failed to link objects!");
8437ec681f3Smrg
8447ec681f3Smrg   ComputeTest::Shader shader;
8457ec681f3Smrg   shader.obj = std::shared_ptr<clc_binary>(new clc_binary(spirv), [](clc_binary *spirv)
8467ec681f3Smrg      {
8477ec681f3Smrg         clc_free_spirv(spirv);
8487ec681f3Smrg         delete spirv;
8497ec681f3Smrg      });
8507ec681f3Smrg   if (!link_args.create_library)
8517ec681f3Smrg      configure(shader, NULL);
8527ec681f3Smrg
8537ec681f3Smrg   return shader;
8547ec681f3Smrg}
8557ec681f3Smrg
8567ec681f3SmrgComputeTest::Shader
8577ec681f3SmrgComputeTest::assemble(const char *source)
8587ec681f3Smrg{
8597ec681f3Smrg   spvtools::SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
8607ec681f3Smrg   std::vector<uint32_t> binary;
8617ec681f3Smrg   if (!tools.Assemble(source, strlen(source), &binary))
8627ec681f3Smrg      throw runtime_error("failed to assemble");
8637ec681f3Smrg
8647ec681f3Smrg   ComputeTest::Shader shader;
8657ec681f3Smrg   shader.obj = std::shared_ptr<clc_binary>(new clc_binary{}, [](clc_binary *spirv)
8667ec681f3Smrg      {
8677ec681f3Smrg         free(spirv->data);
8687ec681f3Smrg         delete spirv;
8697ec681f3Smrg      });
8707ec681f3Smrg   shader.obj->size = binary.size() * 4;
8717ec681f3Smrg   shader.obj->data = malloc(shader.obj->size);
8727ec681f3Smrg   memcpy(shader.obj->data, binary.data(), shader.obj->size);
8737ec681f3Smrg
8747ec681f3Smrg   configure(shader, NULL);
8757ec681f3Smrg
8767ec681f3Smrg   return shader;
8777ec681f3Smrg}
8787ec681f3Smrg
8797ec681f3Smrgvoid
8807ec681f3SmrgComputeTest::configure(Shader &shader,
8817ec681f3Smrg                       const struct clc_runtime_kernel_conf *conf)
8827ec681f3Smrg{
8837ec681f3Smrg   if (!shader.metadata) {
8847ec681f3Smrg      shader.metadata = std::shared_ptr<clc_parsed_spirv>(new clc_parsed_spirv{}, [](clc_parsed_spirv *metadata)
8857ec681f3Smrg         {
8867ec681f3Smrg            clc_free_parsed_spirv(metadata);
8877ec681f3Smrg            delete metadata;
8887ec681f3Smrg         });
8897ec681f3Smrg      if (!clc_parse_spirv(shader.obj.get(), NULL, shader.metadata.get()))
8907ec681f3Smrg         throw runtime_error("failed to parse spirv!");
8917ec681f3Smrg   }
8927ec681f3Smrg
8937ec681f3Smrg   shader.dxil = std::shared_ptr<clc_dxil_object>(new clc_dxil_object{}, [](clc_dxil_object *dxil)
8947ec681f3Smrg      {
8957ec681f3Smrg         clc_free_dxil_object(dxil);
8967ec681f3Smrg         delete dxil;
8977ec681f3Smrg      });
8987ec681f3Smrg   if (!clc_spirv_to_dxil(compiler_ctx, shader.obj.get(), shader.metadata.get(), "main_test", conf, nullptr, &logger, shader.dxil.get()))
8997ec681f3Smrg      throw runtime_error("failed to compile kernel!");
9007ec681f3Smrg}
9017ec681f3Smrg
9027ec681f3Smrgvoid
9037ec681f3SmrgComputeTest::validate(ComputeTest::Shader &shader)
9047ec681f3Smrg{
9057ec681f3Smrg   dump_blob("unsigned.cso", *shader.dxil);
9067ec681f3Smrg   if (!validate_module(*shader.dxil))
9077ec681f3Smrg      throw runtime_error("failed to validate module!");
9087ec681f3Smrg
9097ec681f3Smrg   dump_blob("signed.cso", *shader.dxil);
9107ec681f3Smrg}
911