1/*
2 * Copyright © Microsoft Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24#include <stdio.h>
25#include <stdint.h>
26#include <stdexcept>
27
28#include <directx/d3d12.h>
29#include <dxgi1_4.h>
30#include <gtest/gtest.h>
31#include <wrl.h>
32
33#include "util/u_debug.h"
34#include "clc_compiler.h"
35#include "compute_test.h"
36#include "dxcapi.h"
37
38#include <spirv-tools/libspirv.hpp>
39
40using std::runtime_error;
41using Microsoft::WRL::ComPtr;
42
/* Bit flags selected through the COMPUTE_TEST_DEBUG environment variable
 * (parsed via compute_debug_options below). */
enum compute_test_debug_flags {
   /* create_device(): enable experimental shader models; validate_module():
    * tolerate a missing DXIL.DLL (unsigned shaders). */
   COMPUTE_DEBUG_EXPERIMENTAL_SHADERS = 1 << 0,
   /* choose_adapter(): use a hardware adapter instead of WARP. */
   COMPUTE_DEBUG_USE_HW_D3D           = 1 << 1,
   /* SetUp(): sets clc_libclc_dxil_options::optimize. */
   COMPUTE_DEBUG_OPTIMIZE_LIBCLC      = 1 << 2,
   /* SetUp(): round-trip the libclc context through serialize/deserialize. */
   COMPUTE_DEBUG_SERIALIZE_LIBCLC     = 1 << 3,
};

/* Name -> flag table (with help text) used to parse COMPUTE_TEST_DEBUG. */
static const struct debug_named_value compute_debug_options[] = {
   { "experimental_shaders",  COMPUTE_DEBUG_EXPERIMENTAL_SHADERS, "Enable experimental shaders" },
   { "use_hw_d3d",            COMPUTE_DEBUG_USE_HW_D3D,           "Use a hardware D3D device"   },
   { "optimize_libclc",       COMPUTE_DEBUG_OPTIMIZE_LIBCLC,      "Optimize the clc_libclc before using it" },
   { "serialize_libclc",      COMPUTE_DEBUG_SERIALIZE_LIBCLC,     "Serialize and deserialize the clc_libclc" },
   DEBUG_NAMED_VALUE_END
};

/* Provides debug_get_option_debug_compute(), the accessor used throughout
 * this file; reads COMPUTE_TEST_DEBUG with a default of 0 (no flags). */
DEBUG_GET_ONCE_FLAGS_OPTION(debug_compute, "COMPUTE_TEST_DEBUG", compute_debug_options, 0)
59
60static void warning_callback(void *priv, const char *msg)
61{
62   fprintf(stderr, "WARNING: %s\n", msg);
63}
64
65static void error_callback(void *priv, const char *msg)
66{
67   fprintf(stderr, "ERROR: %s\n", msg);
68}
69
70static const struct clc_logger logger = {
71   NULL,
72   error_callback,
73   warning_callback,
74};
75
76void
77ComputeTest::enable_d3d12_debug_layer()
78{
79   HMODULE hD3D12Mod = LoadLibrary("D3D12.DLL");
80   if (!hD3D12Mod) {
81      fprintf(stderr, "D3D12: failed to load D3D12.DLL\n");
82      return;
83   }
84
85   typedef HRESULT(WINAPI * PFN_D3D12_GET_DEBUG_INTERFACE)(REFIID riid,
86                                                           void **ppFactory);
87   PFN_D3D12_GET_DEBUG_INTERFACE D3D12GetDebugInterface = (PFN_D3D12_GET_DEBUG_INTERFACE)GetProcAddress(hD3D12Mod, "D3D12GetDebugInterface");
88   if (!D3D12GetDebugInterface) {
89      fprintf(stderr, "D3D12: failed to load D3D12GetDebugInterface from D3D12.DLL\n");
90      return;
91   }
92
93   ID3D12Debug *debug;
94   if (FAILED(D3D12GetDebugInterface(__uuidof(ID3D12Debug), (void **)& debug))) {
95      fprintf(stderr, "D3D12: D3D12GetDebugInterface failed\n");
96      return;
97   }
98
99   debug->EnableDebugLayer();
100}
101
102IDXGIFactory4 *
103ComputeTest::get_dxgi_factory()
104{
105   static const GUID IID_IDXGIFactory4 = {
106      0x1bc6ea02, 0xef36, 0x464f,
107      { 0xbf, 0x0c, 0x21, 0xca, 0x39, 0xe5, 0x16, 0x8a }
108   };
109
110   typedef HRESULT(WINAPI * PFN_CREATE_DXGI_FACTORY)(REFIID riid,
111                                                     void **ppFactory);
112   PFN_CREATE_DXGI_FACTORY CreateDXGIFactory;
113
114   HMODULE hDXGIMod = LoadLibrary("DXGI.DLL");
115   if (!hDXGIMod)
116      throw runtime_error("Failed to load DXGI.DLL");
117
118   CreateDXGIFactory = (PFN_CREATE_DXGI_FACTORY)GetProcAddress(hDXGIMod, "CreateDXGIFactory");
119   if (!CreateDXGIFactory)
120      throw runtime_error("Failed to load CreateDXGIFactory from DXGI.DLL");
121
122   IDXGIFactory4 *factory = NULL;
123   HRESULT hr = CreateDXGIFactory(IID_IDXGIFactory4, (void **)&factory);
124   if (FAILED(hr))
125      throw runtime_error("CreateDXGIFactory failed");
126
127   return factory;
128}
129
130IDXGIAdapter1 *
131ComputeTest::choose_adapter(IDXGIFactory4 *factory)
132{
133   IDXGIAdapter1 *ret;
134
135   if (debug_get_option_debug_compute() & COMPUTE_DEBUG_USE_HW_D3D) {
136      for (unsigned i = 0; SUCCEEDED(factory->EnumAdapters1(i, &ret)); i++) {
137         DXGI_ADAPTER_DESC1 desc;
138         ret->GetDesc1(&desc);
139         if (!(desc.Flags & D3D_DRIVER_TYPE_SOFTWARE))
140            return ret;
141      }
142      throw runtime_error("Failed to enum hardware adapter");
143   } else {
144      if (FAILED(factory->EnumWarpAdapter(__uuidof(IDXGIAdapter1),
145         (void **)& ret)))
146         throw runtime_error("Failed to enum warp adapter");
147      return ret;
148   }
149}
150
151ID3D12Device *
152ComputeTest::create_device(IDXGIAdapter1 *adapter)
153{
154   typedef HRESULT(WINAPI *PFN_D3D12CREATEDEVICE)(IUnknown *, D3D_FEATURE_LEVEL, REFIID, void **);
155   PFN_D3D12CREATEDEVICE D3D12CreateDevice;
156
157   HMODULE hD3D12Mod = LoadLibrary("D3D12.DLL");
158   if (!hD3D12Mod)
159      throw runtime_error("failed to load D3D12.DLL");
160
161   if (debug_get_option_debug_compute() & COMPUTE_DEBUG_EXPERIMENTAL_SHADERS) {
162      typedef HRESULT(WINAPI *PFN_D3D12ENABLEEXPERIMENTALFEATURES)(UINT, const IID *, void *, UINT *);
163      PFN_D3D12ENABLEEXPERIMENTALFEATURES D3D12EnableExperimentalFeatures;
164      D3D12EnableExperimentalFeatures = (PFN_D3D12ENABLEEXPERIMENTALFEATURES)
165         GetProcAddress(hD3D12Mod, "D3D12EnableExperimentalFeatures");
166      if (FAILED(D3D12EnableExperimentalFeatures(1, &D3D12ExperimentalShaderModels, NULL, NULL)))
167         throw runtime_error("failed to enable experimental shader models");
168   }
169
170   D3D12CreateDevice = (PFN_D3D12CREATEDEVICE)GetProcAddress(hD3D12Mod, "D3D12CreateDevice");
171   if (!D3D12CreateDevice)
172      throw runtime_error("failed to load D3D12CreateDevice from D3D12.DLL");
173
174   ID3D12Device *dev;
175   if (FAILED(D3D12CreateDevice(adapter, D3D_FEATURE_LEVEL_12_0,
176       __uuidof(ID3D12Device), (void **)& dev)))
177      throw runtime_error("D3D12CreateDevice failed");
178
179   return dev;
180}
181
182ComPtr<ID3D12RootSignature>
183ComputeTest::create_root_signature(const ComputeTest::Resources &resources)
184{
185   D3D12_ROOT_PARAMETER1 root_param;
186   root_param.ParameterType = D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE;
187   root_param.DescriptorTable.NumDescriptorRanges = resources.ranges.size();
188   root_param.DescriptorTable.pDescriptorRanges = resources.ranges.data();
189   root_param.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
190
191   D3D12_ROOT_SIGNATURE_DESC1 root_sig_desc;
192   root_sig_desc.NumParameters = 1;
193   root_sig_desc.pParameters = &root_param;
194   root_sig_desc.NumStaticSamplers = 0;
195   root_sig_desc.pStaticSamplers = NULL;
196   root_sig_desc.Flags = D3D12_ROOT_SIGNATURE_FLAG_NONE;
197
198   D3D12_VERSIONED_ROOT_SIGNATURE_DESC versioned_desc;
199   versioned_desc.Version = D3D_ROOT_SIGNATURE_VERSION_1_1;
200   versioned_desc.Desc_1_1 = root_sig_desc;
201
202   ID3DBlob *sig, *error;
203   if (FAILED(D3D12SerializeVersionedRootSignature(&versioned_desc,
204       &sig, &error)))
205      throw runtime_error("D3D12SerializeVersionedRootSignature failed");
206
207   ComPtr<ID3D12RootSignature> ret;
208   if (FAILED(dev->CreateRootSignature(0,
209       sig->GetBufferPointer(),
210       sig->GetBufferSize(),
211       __uuidof(ret),
212       (void **)& ret)))
213      throw runtime_error("CreateRootSignature failed");
214
215   return ret;
216}
217
218ComPtr<ID3D12PipelineState>
219ComputeTest::create_pipeline_state(ComPtr<ID3D12RootSignature> &root_sig,
220                                   const struct clc_dxil_object &dxil)
221{
222   D3D12_COMPUTE_PIPELINE_STATE_DESC pipeline_desc = { root_sig.Get() };
223   pipeline_desc.CS.pShaderBytecode = dxil.binary.data;
224   pipeline_desc.CS.BytecodeLength = dxil.binary.size;
225
226   ComPtr<ID3D12PipelineState> pipeline_state;
227   if (FAILED(dev->CreateComputePipelineState(&pipeline_desc,
228                                              __uuidof(pipeline_state),
229                                              (void **)& pipeline_state)))
230      throw runtime_error("Failed to create pipeline state");
231   return pipeline_state;
232}
233
234ComPtr<ID3D12Resource>
235ComputeTest::create_buffer(int size, D3D12_HEAP_TYPE heap_type)
236{
237   D3D12_RESOURCE_DESC desc;
238   desc.Format = DXGI_FORMAT_UNKNOWN;
239   desc.Alignment = D3D12_DEFAULT_RESOURCE_PLACEMENT_ALIGNMENT;
240   desc.Width = size;
241   desc.Height = 1;
242   desc.DepthOrArraySize = 1;
243   desc.MipLevels = 1;
244   desc.SampleDesc.Count = 1;
245   desc.SampleDesc.Quality = 0;
246   desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
247   desc.Flags = heap_type == D3D12_HEAP_TYPE_DEFAULT ? D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS : D3D12_RESOURCE_FLAG_NONE;
248   desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
249
250   D3D12_HEAP_PROPERTIES heap_pris = dev->GetCustomHeapProperties(0, heap_type);
251
252   ComPtr<ID3D12Resource> res;
253   if (FAILED(dev->CreateCommittedResource(&heap_pris,
254       D3D12_HEAP_FLAG_NONE, &desc, D3D12_RESOURCE_STATE_COMMON,
255       NULL, __uuidof(ID3D12Resource), (void **)&res)))
256      throw runtime_error("CreateCommittedResource failed");
257
258   return res;
259}
260
261ComPtr<ID3D12Resource>
262ComputeTest::create_upload_buffer_with_data(const void *data, size_t size)
263{
264   auto upload_res = create_buffer(size, D3D12_HEAP_TYPE_UPLOAD);
265
266   void *ptr = NULL;
267   D3D12_RANGE res_range = { 0, (SIZE_T)size };
268   if (FAILED(upload_res->Map(0, &res_range, (void **)&ptr)))
269      throw runtime_error("Failed to map upload-buffer");
270   assert(ptr);
271   memcpy(ptr, data, size);
272   upload_res->Unmap(0, &res_range);
273   return upload_res;
274}
275
276ComPtr<ID3D12Resource>
277ComputeTest::create_sized_buffer_with_data(size_t buffer_size,
278                                           const void *data,
279                                           size_t data_size)
280{
281   auto upload_res = create_upload_buffer_with_data(data, data_size);
282
283   auto res = create_buffer(buffer_size, D3D12_HEAP_TYPE_DEFAULT);
284   resource_barrier(res, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_COPY_DEST);
285   cmdlist->CopyBufferRegion(res.Get(), 0, upload_res.Get(), 0, data_size);
286   resource_barrier(res, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_COMMON);
287   execute_cmdlist();
288
289   return res;
290}
291
292void
293ComputeTest::get_buffer_data(ComPtr<ID3D12Resource> res,
294                             void *buf, size_t size)
295{
296   auto readback_res = create_buffer(align(size, 4), D3D12_HEAP_TYPE_READBACK);
297   resource_barrier(res, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_COPY_SOURCE);
298   cmdlist->CopyResource(readback_res.Get(), res.Get());
299   resource_barrier(res, D3D12_RESOURCE_STATE_COPY_SOURCE, D3D12_RESOURCE_STATE_COMMON);
300   execute_cmdlist();
301
302   void *ptr = NULL;
303   D3D12_RANGE res_range = { 0, size };
304   if (FAILED(readback_res->Map(0, &res_range, &ptr)))
305      throw runtime_error("Failed to map readback-buffer");
306
307   memcpy(buf, ptr, size);
308
309   D3D12_RANGE empty_range = { 0, 0 };
310   readback_res->Unmap(0, &empty_range);
311}
312
313void
314ComputeTest::resource_barrier(ComPtr<ID3D12Resource> &res,
315                              D3D12_RESOURCE_STATES state_before,
316                              D3D12_RESOURCE_STATES state_after)
317{
318   D3D12_RESOURCE_BARRIER barrier;
319   barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_TRANSITION;
320   barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE;
321   barrier.Transition.pResource = res.Get();
322   barrier.Transition.Subresource = D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES;
323   barrier.Transition.StateBefore = state_before;
324   barrier.Transition.StateAfter = state_after;
325   cmdlist->ResourceBarrier(1, &barrier);
326}
327
328void
329ComputeTest::execute_cmdlist()
330{
331   if (FAILED(cmdlist->Close()))
332      throw runtime_error("Closing ID3D12GraphicsCommandList failed");
333
334   ID3D12CommandList *cmdlists[] = { cmdlist };
335   cmdqueue->ExecuteCommandLists(1, cmdlists);
336   cmdqueue_fence->SetEventOnCompletion(fence_value, event);
337   cmdqueue->Signal(cmdqueue_fence, fence_value);
338   fence_value++;
339   WaitForSingleObject(event, INFINITE);
340
341   if (FAILED(cmdalloc->Reset()))
342      throw runtime_error("resetting ID3D12CommandAllocator failed");
343
344   if (FAILED(cmdlist->Reset(cmdalloc, NULL)))
345      throw runtime_error("resetting ID3D12GraphicsCommandList failed");
346}
347
348void
349ComputeTest::create_uav_buffer(ComPtr<ID3D12Resource> res,
350                               size_t width, size_t byte_stride,
351                               D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle)
352{
353   D3D12_UNORDERED_ACCESS_VIEW_DESC uav_desc;
354   uav_desc.Format = DXGI_FORMAT_R32_TYPELESS;
355   uav_desc.ViewDimension = D3D12_UAV_DIMENSION_BUFFER;
356   uav_desc.Buffer.FirstElement = 0;
357   uav_desc.Buffer.NumElements = DIV_ROUND_UP(width * byte_stride, 4);
358   uav_desc.Buffer.StructureByteStride = 0;
359   uav_desc.Buffer.CounterOffsetInBytes = 0;
360   uav_desc.Buffer.Flags = D3D12_BUFFER_UAV_FLAG_RAW;
361
362   dev->CreateUnorderedAccessView(res.Get(), NULL, &uav_desc, cpu_handle);
363}
364
365void
366ComputeTest::create_cbv(ComPtr<ID3D12Resource> res, size_t size,
367                        D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle)
368{
369   D3D12_CONSTANT_BUFFER_VIEW_DESC cbv_desc;
370   cbv_desc.BufferLocation = res ? res->GetGPUVirtualAddress() : 0;
371   cbv_desc.SizeInBytes = size;
372
373   dev->CreateConstantBufferView(&cbv_desc, cpu_handle);
374}
375
376ComPtr<ID3D12Resource>
377ComputeTest::add_uav_resource(ComputeTest::Resources &resources,
378                              unsigned spaceid, unsigned resid,
379                              const void *data, size_t num_elems,
380                              size_t elem_size)
381{
382   size_t size = align(elem_size * num_elems, 4);
383   D3D12_CPU_DESCRIPTOR_HANDLE handle;
384   ComPtr<ID3D12Resource> res;
385   handle = uav_heap->GetCPUDescriptorHandleForHeapStart();
386   handle = offset_cpu_handle(handle, resources.descs.size() * uav_heap_incr);
387
388   if (size) {
389      if (data)
390         res = create_buffer_with_data(data, size);
391      else
392         res = create_buffer(size, D3D12_HEAP_TYPE_DEFAULT);
393
394      resource_barrier(res, D3D12_RESOURCE_STATE_COMMON,
395                       D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
396   }
397   create_uav_buffer(res, num_elems, elem_size, handle);
398   resources.add(res, D3D12_DESCRIPTOR_RANGE_TYPE_UAV, spaceid, resid);
399   return res;
400}
401
402ComPtr<ID3D12Resource>
403ComputeTest::add_cbv_resource(ComputeTest::Resources &resources,
404                              unsigned spaceid, unsigned resid,
405                              const void *data, size_t size)
406{
407   unsigned aligned_size = align(size, 256);
408   D3D12_CPU_DESCRIPTOR_HANDLE handle;
409   ComPtr<ID3D12Resource> res;
410   handle = uav_heap->GetCPUDescriptorHandleForHeapStart();
411   handle = offset_cpu_handle(handle, resources.descs.size() * uav_heap_incr);
412
413   if (size) {
414     assert(data);
415     res = create_sized_buffer_with_data(aligned_size, data, size);
416   }
417   create_cbv(res, aligned_size, handle);
418   resources.add(res, D3D12_DESCRIPTOR_RANGE_TYPE_CBV, spaceid, resid);
419   return res;
420}
421
422void
423ComputeTest::run_shader_with_raw_args(Shader shader,
424                                      const CompileArgs &compile_args,
425                                      const std::vector<RawShaderArg *> &args)
426{
427   if (args.size() < 1)
428      throw runtime_error("no inputs");
429
430   static HMODULE hD3D12Mod = LoadLibrary("D3D12.DLL");
431   if (!hD3D12Mod)
432      throw runtime_error("Failed to load D3D12.DLL");
433
434   D3D12SerializeVersionedRootSignature = (PFN_D3D12_SERIALIZE_VERSIONED_ROOT_SIGNATURE)GetProcAddress(hD3D12Mod, "D3D12SerializeVersionedRootSignature");
435
436   if (args.size() != shader.dxil->kernel->num_args)
437      throw runtime_error("incorrect number of inputs");
438
439   struct clc_runtime_kernel_conf conf = { 0 };
440
441   // Older WARP and some hardware doesn't support int64, so for these tests, unconditionally lower away int64
442   // A more complex runtime can be smarter about detecting when this needs to be done
443   conf.lower_bit_size = 64;
444
445   if (!shader.dxil->metadata.local_size[0])
446      conf.local_size[0] = compile_args.x;
447   else
448      conf.local_size[0] = shader.dxil->metadata.local_size[0];
449
450   if (!shader.dxil->metadata.local_size[1])
451      conf.local_size[1] = compile_args.y;
452   else
453      conf.local_size[1] = shader.dxil->metadata.local_size[1];
454
455   if (!shader.dxil->metadata.local_size[2])
456      conf.local_size[2] = compile_args.z;
457   else
458      conf.local_size[2] = shader.dxil->metadata.local_size[2];
459
460   if (compile_args.x % conf.local_size[0] ||
461       compile_args.y % conf.local_size[1] ||
462       compile_args.z % conf.local_size[2])
463      throw runtime_error("invalid global size must be a multiple of local size");
464
465   std::vector<struct clc_runtime_arg_info> argsinfo(args.size());
466
467   conf.args = argsinfo.data();
468   conf.support_global_work_id_offsets =
469      compile_args.work_props.global_offset_x != 0 ||
470      compile_args.work_props.global_offset_y != 0 ||
471      compile_args.work_props.global_offset_z != 0;
472   conf.support_workgroup_id_offsets =
473      compile_args.work_props.group_id_offset_x != 0 ||
474      compile_args.work_props.group_id_offset_y != 0 ||
475      compile_args.work_props.group_id_offset_z != 0;
476
477   for (unsigned i = 0; i < shader.dxil->kernel->num_args; ++i) {
478      RawShaderArg *arg = args[i];
479      size_t size = arg->get_elem_size() * arg->get_num_elems();
480
481      switch (shader.dxil->kernel->args[i].address_qualifier) {
482      case CLC_KERNEL_ARG_ADDRESS_LOCAL:
483         argsinfo[i].localptr.size = size;
484         break;
485      default:
486         break;
487      }
488   }
489
490   configure(shader, &conf);
491   validate(shader);
492
493   std::shared_ptr<struct clc_dxil_object> &dxil = shader.dxil;
494
495   std::vector<uint8_t> argsbuf(dxil->metadata.kernel_inputs_buf_size);
496   std::vector<ComPtr<ID3D12Resource>> argres(shader.dxil->kernel->num_args);
497   clc_work_properties_data work_props = compile_args.work_props;
498   if (!conf.support_workgroup_id_offsets) {
499      work_props.group_count_total_x = compile_args.x / conf.local_size[0];
500      work_props.group_count_total_y = compile_args.y / conf.local_size[1];
501      work_props.group_count_total_z = compile_args.z / conf.local_size[2];
502   }
503   if (work_props.work_dim == 0)
504      work_props.work_dim = 3;
505   Resources resources;
506
507   for (unsigned i = 0; i < dxil->kernel->num_args; ++i) {
508      RawShaderArg *arg = args[i];
509      size_t size = arg->get_elem_size() * arg->get_num_elems();
510      void *slot = argsbuf.data() + dxil->metadata.args[i].offset;
511
512      switch (dxil->kernel->args[i].address_qualifier) {
513      case CLC_KERNEL_ARG_ADDRESS_CONSTANT:
514      case CLC_KERNEL_ARG_ADDRESS_GLOBAL: {
515         assert(dxil->metadata.args[i].size == sizeof(uint64_t));
516         uint64_t *ptr_slot = (uint64_t *)slot;
517         if (arg->get_data())
518            *ptr_slot = (uint64_t)dxil->metadata.args[i].globconstptr.buf_id << 32;
519         else
520            *ptr_slot = ~0ull;
521         break;
522      }
523      case CLC_KERNEL_ARG_ADDRESS_LOCAL: {
524         assert(dxil->metadata.args[i].size == sizeof(uint64_t));
525         uint64_t *ptr_slot = (uint64_t *)slot;
526         *ptr_slot = dxil->metadata.args[i].localptr.sharedmem_offset;
527         break;
528      }
529      case CLC_KERNEL_ARG_ADDRESS_PRIVATE: {
530         assert(size == dxil->metadata.args[i].size);
531         memcpy(slot, arg->get_data(), size);
532         break;
533      }
534      default:
535         assert(0);
536      }
537   }
538
539   for (unsigned i = 0; i < dxil->kernel->num_args; ++i) {
540      RawShaderArg *arg = args[i];
541
542      if (dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_GLOBAL ||
543          dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_CONSTANT) {
544         argres[i] = add_uav_resource(resources, 0,
545                                      dxil->metadata.args[i].globconstptr.buf_id,
546                                      arg->get_data(), arg->get_num_elems(),
547                                      arg->get_elem_size());
548      }
549   }
550
551   if (dxil->metadata.printf.uav_id > 0)
552      add_uav_resource(resources, 0, dxil->metadata.printf.uav_id, NULL, 1024 * 1024 / 4, 4);
553
554   for (unsigned i = 0; i < dxil->metadata.num_consts; ++i)
555      add_uav_resource(resources, 0, dxil->metadata.consts[i].uav_id,
556                       dxil->metadata.consts[i].data,
557                       dxil->metadata.consts[i].size / 4, 4);
558
559   if (argsbuf.size())
560      add_cbv_resource(resources, 0, dxil->metadata.kernel_inputs_cbv_id,
561                       argsbuf.data(), argsbuf.size());
562
563   add_cbv_resource(resources, 0, dxil->metadata.work_properties_cbv_id,
564                    &work_props, sizeof(work_props));
565
566   auto root_sig = create_root_signature(resources);
567   auto pipeline_state = create_pipeline_state(root_sig, *dxil);
568
569   cmdlist->SetDescriptorHeaps(1, &uav_heap);
570   cmdlist->SetComputeRootSignature(root_sig.Get());
571   cmdlist->SetComputeRootDescriptorTable(0, uav_heap->GetGPUDescriptorHandleForHeapStart());
572   cmdlist->SetPipelineState(pipeline_state.Get());
573
574   cmdlist->Dispatch(compile_args.x / conf.local_size[0],
575                     compile_args.y / conf.local_size[1],
576                     compile_args.z / conf.local_size[2]);
577
578   for (auto &range : resources.ranges) {
579      if (range.RangeType == D3D12_DESCRIPTOR_RANGE_TYPE_UAV) {
580         for (unsigned i = range.OffsetInDescriptorsFromTableStart;
581              i < range.NumDescriptors; i++) {
582            if (!resources.descs[i].Get())
583               continue;
584
585            resource_barrier(resources.descs[i],
586                             D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
587                             D3D12_RESOURCE_STATE_COMMON);
588         }
589      }
590   }
591
592   execute_cmdlist();
593
594   for (unsigned i = 0; i < args.size(); i++) {
595      if (!(args[i]->get_direction() & SHADER_ARG_OUTPUT))
596         continue;
597
598      assert(dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_GLOBAL);
599      get_buffer_data(argres[i], args[i]->get_data(),
600                      args[i]->get_elem_size() * args[i]->get_num_elems());
601   }
602
603   ComPtr<ID3D12InfoQueue> info_queue;
604   dev->QueryInterface(info_queue.ReleaseAndGetAddressOf());
605   if (info_queue)
606   {
607      EXPECT_EQ(0, info_queue->GetNumStoredMessages());
608      for (unsigned i = 0; i < info_queue->GetNumStoredMessages(); ++i) {
609         SIZE_T message_size = 0;
610         info_queue->GetMessageA(i, nullptr, &message_size);
611         D3D12_MESSAGE* message = (D3D12_MESSAGE*)malloc(message_size);
612         info_queue->GetMessageA(i, message, &message_size);
613         FAIL() << message->pDescription;
614         free(message);
615      }
616   }
617}
618
619void
620ComputeTest::SetUp()
621{
622   static struct clc_libclc *compiler_ctx_g = nullptr;
623
624   if (!compiler_ctx_g) {
625      clc_libclc_dxil_options options = { };
626      options.optimize = (debug_get_option_debug_compute() & COMPUTE_DEBUG_OPTIMIZE_LIBCLC) != 0;
627
628      compiler_ctx_g = clc_libclc_new_dxil(&logger, &options);
629      if (!compiler_ctx_g)
630         throw runtime_error("failed to create CLC compiler context");
631
632      if (debug_get_option_debug_compute() & COMPUTE_DEBUG_SERIALIZE_LIBCLC) {
633         void *serialized = nullptr;
634         size_t serialized_size = 0;
635         clc_libclc_serialize(compiler_ctx_g, &serialized, &serialized_size);
636         if (!serialized)
637            throw runtime_error("failed to serialize CLC compiler context");
638
639         clc_free_libclc(compiler_ctx_g);
640         compiler_ctx_g = nullptr;
641
642         compiler_ctx_g = clc_libclc_deserialize(serialized, serialized_size);
643         if (!compiler_ctx_g)
644            throw runtime_error("failed to deserialize CLC compiler context");
645
646         clc_libclc_free_serialized(serialized);
647      }
648   }
649   compiler_ctx = compiler_ctx_g;
650
651   enable_d3d12_debug_layer();
652
653   factory = get_dxgi_factory();
654   if (!factory)
655      throw runtime_error("failed to create DXGI factory");
656
657   adapter = choose_adapter(factory);
658   if (!adapter)
659      throw runtime_error("failed to choose adapter");
660
661   dev = create_device(adapter);
662   if (!dev)
663      throw runtime_error("failed to create device");
664
665   if (FAILED(dev->CreateFence(0, D3D12_FENCE_FLAG_NONE,
666                               __uuidof(cmdqueue_fence),
667                               (void **)&cmdqueue_fence)))
668      throw runtime_error("failed to create fence\n");
669
670   D3D12_COMMAND_QUEUE_DESC queue_desc;
671   queue_desc.Type = D3D12_COMMAND_LIST_TYPE_COMPUTE;
672   queue_desc.Priority = D3D12_COMMAND_QUEUE_PRIORITY_NORMAL;
673   queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
674   queue_desc.NodeMask = 0;
675   if (FAILED(dev->CreateCommandQueue(&queue_desc,
676                                      __uuidof(cmdqueue),
677                                      (void **)&cmdqueue)))
678      throw runtime_error("failed to create command queue");
679
680   if (FAILED(dev->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_COMPUTE,
681             __uuidof(cmdalloc), (void **)&cmdalloc)))
682      throw runtime_error("failed to create command allocator");
683
684   if (FAILED(dev->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_COMPUTE,
685             cmdalloc, NULL, __uuidof(cmdlist), (void **)&cmdlist)))
686      throw runtime_error("failed to create command list");
687
688   D3D12_DESCRIPTOR_HEAP_DESC heap_desc;
689   heap_desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
690   heap_desc.NumDescriptors = 1000;
691   heap_desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE;
692   heap_desc.NodeMask = 0;
693   if (FAILED(dev->CreateDescriptorHeap(&heap_desc,
694       __uuidof(uav_heap), (void **)&uav_heap)))
695      throw runtime_error("failed to create descriptor heap");
696
697   uav_heap_incr = dev->GetDescriptorHandleIncrementSize(D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
698
699   event = CreateEvent(NULL, FALSE, FALSE, NULL);
700   if (!event)
701      throw runtime_error("Failed to create event");
702   fence_value = 1;
703}
704
705void
706ComputeTest::TearDown()
707{
708   CloseHandle(event);
709
710   uav_heap->Release();
711   cmdlist->Release();
712   cmdalloc->Release();
713   cmdqueue->Release();
714   cmdqueue_fence->Release();
715   dev->Release();
716   adapter->Release();
717   factory->Release();
718}
719
/* Storage for the static entry point that run_shader_with_raw_args() resolves
 * from D3D12.DLL and create_root_signature() calls. */
PFN_D3D12_SERIALIZE_VERSIONED_ROOT_SIGNATURE ComputeTest::D3D12SerializeVersionedRootSignature;
721
722bool
723validate_module(const struct clc_dxil_object &dxil)
724{
725   static HMODULE hmod = LoadLibrary("DXIL.DLL");
726   if (!hmod) {
727      /* Enabling experimental shaders allows us to run unsigned shader code,
728       * such as when under the debugger where we can't run the validator. */
729      if (debug_get_option_debug_compute() & COMPUTE_DEBUG_EXPERIMENTAL_SHADERS)
730         return true;
731      else
732         throw runtime_error("failed to load DXIL.DLL");
733   }
734
735   DxcCreateInstanceProc pfnDxcCreateInstance =
736      (DxcCreateInstanceProc)GetProcAddress(hmod, "DxcCreateInstance");
737   if (!pfnDxcCreateInstance)
738      throw runtime_error("failed to load DxcCreateInstance");
739
740   struct shader_blob : public IDxcBlob {
741      shader_blob(void *data, size_t size) : data(data), size(size) {}
742      LPVOID STDMETHODCALLTYPE GetBufferPointer() override { return data; }
743      SIZE_T STDMETHODCALLTYPE GetBufferSize() override { return size; }
744      HRESULT STDMETHODCALLTYPE QueryInterface(REFIID, void **) override { return E_NOINTERFACE; }
745      ULONG STDMETHODCALLTYPE AddRef() override { return 1; }
746      ULONG STDMETHODCALLTYPE Release() override { return 0; }
747      void *data;
748      size_t size;
749   } blob(dxil.binary.data, dxil.binary.size);
750
751   IDxcValidator *validator;
752   if (FAILED(pfnDxcCreateInstance(CLSID_DxcValidator, __uuidof(IDxcValidator),
753                                   (void **)&validator)))
754      throw runtime_error("failed to create IDxcValidator");
755
756   IDxcOperationResult *result;
757   if (FAILED(validator->Validate(&blob, DxcValidatorFlags_InPlaceEdit,
758                                  &result)))
759      throw runtime_error("Validate failed");
760
761   HRESULT hr;
762   if (FAILED(result->GetStatus(&hr)) ||
763       FAILED(hr)) {
764      IDxcBlobEncoding *message;
765      result->GetErrorBuffer(&message);
766      fprintf(stderr, "D3D12: validation failed: %*s\n",
767                   (int)message->GetBufferSize(),
768                   (char *)message->GetBufferPointer());
769      message->Release();
770      validator->Release();
771      result->Release();
772      return false;
773   }
774
775   validator->Release();
776   result->Release();
777   return true;
778}
779
780static void
781dump_blob(const char *path, const struct clc_dxil_object &dxil)
782{
783   FILE *fp = fopen(path, "wb");
784   if (fp) {
785      fwrite(dxil.binary.data, 1, dxil.binary.size, fp);
786      fclose(fp);
787      printf("D3D12: wrote '%s'...\n", path);
788   }
789}
790
791ComputeTest::Shader
792ComputeTest::compile(const std::vector<const char *> &sources,
793                     const std::vector<const char *> &compile_args,
794                     bool create_library)
795{
796   struct clc_compile_args args = {
797   };
798   args.args = compile_args.data();
799   args.num_args = (unsigned)compile_args.size();
800   ComputeTest::Shader shader;
801
802   std::vector<Shader> shaders;
803
804   args.source.name = "obj.cl";
805
806   for (unsigned i = 0; i < sources.size(); i++) {
807      args.source.value = sources[i];
808
809      clc_binary spirv{};
810      if (!clc_compile_c_to_spirv(&args, &logger, &spirv))
811         throw runtime_error("failed to compile object!");
812
813      Shader shader;
814      shader.obj = std::shared_ptr<clc_binary>(new clc_binary(spirv), [](clc_binary *spirv)
815         {
816            clc_free_spirv(spirv);
817            delete spirv;
818         });
819      shaders.push_back(shader);
820   }
821
822   if (shaders.size() == 1 && create_library)
823      return shaders[0];
824
825   return link(shaders, create_library);
826}
827
828ComputeTest::Shader
829ComputeTest::link(const std::vector<Shader> &sources,
830                  bool create_library)
831{
832   std::vector<const clc_binary*> objs;
833   for (auto& source : sources)
834      objs.push_back(&*source.obj);
835
836   struct clc_linker_args link_args = {};
837   link_args.in_objs = objs.data();
838   link_args.num_in_objs = (unsigned)objs.size();
839   link_args.create_library = create_library;
840   clc_binary spirv{};
841   if (!clc_link_spirv(&link_args, &logger, &spirv))
842      throw runtime_error("failed to link objects!");
843
844   ComputeTest::Shader shader;
845   shader.obj = std::shared_ptr<clc_binary>(new clc_binary(spirv), [](clc_binary *spirv)
846      {
847         clc_free_spirv(spirv);
848         delete spirv;
849      });
850   if (!link_args.create_library)
851      configure(shader, NULL);
852
853   return shader;
854}
855
856ComputeTest::Shader
857ComputeTest::assemble(const char *source)
858{
859   spvtools::SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
860   std::vector<uint32_t> binary;
861   if (!tools.Assemble(source, strlen(source), &binary))
862      throw runtime_error("failed to assemble");
863
864   ComputeTest::Shader shader;
865   shader.obj = std::shared_ptr<clc_binary>(new clc_binary{}, [](clc_binary *spirv)
866      {
867         free(spirv->data);
868         delete spirv;
869      });
870   shader.obj->size = binary.size() * 4;
871   shader.obj->data = malloc(shader.obj->size);
872   memcpy(shader.obj->data, binary.data(), shader.obj->size);
873
874   configure(shader, NULL);
875
876   return shader;
877}
878
879void
880ComputeTest::configure(Shader &shader,
881                       const struct clc_runtime_kernel_conf *conf)
882{
883   if (!shader.metadata) {
884      shader.metadata = std::shared_ptr<clc_parsed_spirv>(new clc_parsed_spirv{}, [](clc_parsed_spirv *metadata)
885         {
886            clc_free_parsed_spirv(metadata);
887            delete metadata;
888         });
889      if (!clc_parse_spirv(shader.obj.get(), NULL, shader.metadata.get()))
890         throw runtime_error("failed to parse spirv!");
891   }
892
893   shader.dxil = std::shared_ptr<clc_dxil_object>(new clc_dxil_object{}, [](clc_dxil_object *dxil)
894      {
895         clc_free_dxil_object(dxil);
896         delete dxil;
897      });
898   if (!clc_spirv_to_dxil(compiler_ctx, shader.obj.get(), shader.metadata.get(), "main_test", conf, nullptr, &logger, shader.dxil.get()))
899      throw runtime_error("failed to compile kernel!");
900}
901
902void
903ComputeTest::validate(ComputeTest::Shader &shader)
904{
905   dump_blob("unsigned.cso", *shader.dxil);
906   if (!validate_module(*shader.dxil))
907      throw runtime_error("failed to validate module!");
908
909   dump_blob("signed.cso", *shader.dxil);
910}
911