1/*
2 * Copyright © Microsoft Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
#include <assert.h>
#include <stdio.h>
#include <stdint.h>
#include <string.h>

#include <memory>
#include <stdexcept>
#include <vector>

#include <directx/d3d12.h>
#include <dxgi1_4.h>
#include <gtest/gtest.h>
#include <wrl.h>
32
33#include "clc_compiler.h"
34
35using std::runtime_error;
36using Microsoft::WRL::ComPtr;
37
38inline D3D12_CPU_DESCRIPTOR_HANDLE
39offset_cpu_handle(D3D12_CPU_DESCRIPTOR_HANDLE handle, UINT offset)
40{
41   handle.ptr += offset;
42   return handle;
43}
44
/*
 * Round @value up to the next multiple of @alignment (which must be
 * non-zero).  Values that are already a multiple are returned unchanged.
 */
inline size_t
align(size_t value, unsigned alignment)
{
   assert(alignment > 0);
   const size_t bias = alignment - 1;
   const size_t whole_units = (value + bias) / alignment;
   return whole_units * alignment;
}
51
52class ComputeTest : public ::testing::Test {
53protected:
54   struct Shader {
55      std::shared_ptr<struct clc_binary> obj;
56      std::shared_ptr<struct clc_parsed_spirv> metadata;
57      std::shared_ptr<struct clc_dxil_object> dxil;
58   };
59
60   static void
61   enable_d3d12_debug_layer();
62
63   static IDXGIFactory4 *
64   get_dxgi_factory();
65
66   static IDXGIAdapter1 *
67   choose_adapter(IDXGIFactory4 *factory);
68
69   static ID3D12Device *
70   create_device(IDXGIAdapter1 *adapter);
71
72   struct Resources {
73      void add(ComPtr<ID3D12Resource> res,
74               D3D12_DESCRIPTOR_RANGE_TYPE type,
75               unsigned spaceid,
76               unsigned resid)
77      {
78         descs.push_back(res);
79
80         if(!ranges.empty() &&
81            ranges.back().RangeType == type &&
82            ranges.back().RegisterSpace == spaceid &&
83            ranges.back().BaseShaderRegister + ranges.back().NumDescriptors == resid) {
84            ranges.back().NumDescriptors++;
85	    return;
86         }
87
88         D3D12_DESCRIPTOR_RANGE1 range;
89
90         range.RangeType = type;
91         range.NumDescriptors = 1;
92         range.BaseShaderRegister = resid;
93         range.RegisterSpace = spaceid;
94         range.OffsetInDescriptorsFromTableStart = descs.size() - 1;
95         range.Flags = D3D12_DESCRIPTOR_RANGE_FLAG_DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS;
96         ranges.push_back(range);
97      }
98
99      std::vector<D3D12_DESCRIPTOR_RANGE1> ranges;
100      std::vector<ComPtr<ID3D12Resource>> descs;
101   };
102
103   ComPtr<ID3D12RootSignature>
104   create_root_signature(const Resources &resources);
105
106   ComPtr<ID3D12PipelineState>
107   create_pipeline_state(ComPtr<ID3D12RootSignature> &root_sig,
108                         const struct clc_dxil_object &dxil);
109
110   ComPtr<ID3D12Resource>
111   create_buffer(int size, D3D12_HEAP_TYPE heap_type);
112
113   ComPtr<ID3D12Resource>
114   create_upload_buffer_with_data(const void *data, size_t size);
115
116   ComPtr<ID3D12Resource>
117   create_sized_buffer_with_data(size_t buffer_size, const void *data,
118                                 size_t data_size);
119
120   ComPtr<ID3D12Resource>
121   create_buffer_with_data(const void *data, size_t size)
122   {
123      return create_sized_buffer_with_data(size, data, size);
124   }
125
126   void
127   get_buffer_data(ComPtr<ID3D12Resource> res,
128                   void *buf, size_t size);
129
130   void
131   resource_barrier(ComPtr<ID3D12Resource> &res,
132                    D3D12_RESOURCE_STATES state_before,
133                    D3D12_RESOURCE_STATES state_after);
134
135   void
136   execute_cmdlist();
137
138   void
139   create_uav_buffer(ComPtr<ID3D12Resource> res,
140                     size_t width, size_t byte_stride,
141                     D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle);
142
143   void create_cbv(ComPtr<ID3D12Resource> res, size_t size,
144                   D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle);
145
146   ComPtr<ID3D12Resource>
147   add_uav_resource(Resources &resources, unsigned spaceid, unsigned resid,
148                    const void *data = NULL, size_t num_elems = 0,
149                    size_t elem_size = 0);
150
151   ComPtr<ID3D12Resource>
152   add_cbv_resource(Resources &resources, unsigned spaceid, unsigned resid,
153                    const void *data, size_t size);
154
155   void
156   SetUp() override;
157
158   void
159   TearDown() override;
160
161   Shader
162   compile(const std::vector<const char *> &sources,
163           const std::vector<const char *> &compile_args = {},
164           bool create_library = false);
165
166   Shader
167   link(const std::vector<Shader> &sources,
168        bool create_library = false);
169
170   Shader
171   assemble(const char *source);
172
173   void
174   configure(Shader &shader,
175             const struct clc_runtime_kernel_conf *conf);
176
177   void
178   validate(Shader &shader);
179
180   template <typename T>
181   Shader
182   specialize(Shader &shader, uint32_t id, T const& val)
183   {
184      Shader new_shader;
185      new_shader.obj = std::shared_ptr<clc_binary>(new clc_binary{}, [](clc_binary *spirv)
186         {
187            clc_free_spirv(spirv);
188            delete spirv;
189         });
190      if (!shader.metadata)
191         configure(shader, NULL);
192
193      clc_spirv_specialization spec;
194      spec.id = id;
195      memcpy(&spec.value, &val, sizeof(val));
196      clc_spirv_specialization_consts consts;
197      consts.specializations = &spec;
198      consts.num_specializations = 1;
199      if (!clc_specialize_spirv(shader.obj.get(), shader.metadata.get(), &consts, new_shader.obj.get()))
200         throw runtime_error("failed to specialize");
201
202      configure(new_shader, NULL);
203
204      return new_shader;
205   }
206
207   enum ShaderArgDirection {
208      SHADER_ARG_INPUT = 1,
209      SHADER_ARG_OUTPUT = 2,
210      SHADER_ARG_INOUT = SHADER_ARG_INPUT | SHADER_ARG_OUTPUT,
211   };
212
213   class RawShaderArg {
214   public:
215      RawShaderArg(enum ShaderArgDirection dir) : dir(dir) { }
216      virtual size_t get_elem_size() const = 0;
217      virtual size_t get_num_elems() const = 0;
218      virtual const void *get_data() const = 0;
219      virtual void *get_data() = 0;
220      enum ShaderArgDirection get_direction() { return dir; }
221   private:
222      enum ShaderArgDirection dir;
223   };
224
225   class NullShaderArg : public RawShaderArg {
226   public:
227      NullShaderArg() : RawShaderArg(SHADER_ARG_INPUT) { }
228      size_t get_elem_size() const override { return 0; }
229      size_t get_num_elems() const override { return 0; }
230      const void *get_data() const override { return NULL; }
231      void *get_data() override { return NULL; }
232   };
233
234   template <typename T>
235   class ShaderArg : public std::vector<T>, public RawShaderArg
236   {
237   public:
238      ShaderArg(const T &v, enum ShaderArgDirection dir = SHADER_ARG_INOUT) :
239         std::vector<T>({ v }), RawShaderArg(dir) { }
240      ShaderArg(const std::vector<T> &v, enum ShaderArgDirection dir = SHADER_ARG_INOUT) :
241         std::vector<T>(v), RawShaderArg(dir) { }
242      ShaderArg(const std::initializer_list<T> v, enum ShaderArgDirection dir = SHADER_ARG_INOUT) :
243         std::vector<T>(v), RawShaderArg(dir) { }
244
245      ShaderArg<T>& operator =(const T &v)
246      {
247         this->clear();
248	 this->push_back(v);
249         return *this;
250      }
251
252      operator T&() { return this->at(0); }
253      operator const T&() const { return this->at(0); }
254
255      ShaderArg<T>& operator =(const std::vector<T> &v)
256      {
257	 *this = v;
258         return *this;
259      }
260
261      ShaderArg<T>& operator =(std::initializer_list<T> v)
262      {
263	 *this = v;
264         return *this;
265      }
266
267      size_t get_elem_size() const override { return sizeof(T); }
268      size_t get_num_elems() const override { return this->size(); }
269      const void *get_data() const override { return this->data(); }
270      void *get_data() override { return this->data(); }
271   };
272
273   struct CompileArgs
274   {
275      unsigned x, y, z;
276      std::vector<const char *> compiler_command_line;
277      clc_work_properties_data work_props;
278   };
279
280private:
281   void gather_args(std::vector<RawShaderArg *> &args) { }
282
283   template <typename T, typename... Rest>
284   void gather_args(std::vector<RawShaderArg *> &args, T &arg, Rest&... rest)
285   {
286      args.push_back(&arg);
287      gather_args(args, rest...);
288   }
289
290   void run_shader_with_raw_args(Shader shader,
291                                 const CompileArgs &compile_args,
292                                 const std::vector<RawShaderArg *> &args);
293
294protected:
295   template <typename... Args>
296   void run_shader(Shader shader,
297                   const CompileArgs &compile_args,
298                   Args&... args)
299   {
300      std::vector<RawShaderArg *> raw_args;
301      gather_args(raw_args, args...);
302      run_shader_with_raw_args(shader, compile_args, raw_args);
303   }
304
305   template <typename... Args>
306   void run_shader(const std::vector<const char *> &sources,
307                   unsigned x, unsigned y, unsigned z,
308                   Args&... args)
309   {
310      std::vector<RawShaderArg *> raw_args;
311      gather_args(raw_args, args...);
312      CompileArgs compile_args = { x, y, z };
313      run_shader_with_raw_args(compile(sources), compile_args, raw_args);
314   }
315
316   template <typename... Args>
317   void run_shader(const std::vector<const char *> &sources,
318                   const CompileArgs &compile_args,
319                   Args&... args)
320   {
321      std::vector<RawShaderArg *> raw_args;
322      gather_args(raw_args, args...);
323      run_shader_with_raw_args(
324         compile(sources, compile_args.compiler_command_line),
325         compile_args, raw_args);
326   }
327
328   template <typename... Args>
329   void run_shader(const char *source,
330                   unsigned x, unsigned y, unsigned z,
331                   Args&... args)
332   {
333      std::vector<RawShaderArg *> raw_args;
334      gather_args(raw_args, args...);
335      CompileArgs compile_args = { x, y, z };
336      run_shader_with_raw_args(compile({ source }), compile_args, raw_args);
337   }
338
339   IDXGIFactory4 *factory;
340   IDXGIAdapter1 *adapter;
341   ID3D12Device *dev;
342   ID3D12Fence *cmdqueue_fence;
343   ID3D12CommandQueue *cmdqueue;
344   ID3D12CommandAllocator *cmdalloc;
345   ID3D12GraphicsCommandList *cmdlist;
346   ID3D12DescriptorHeap *uav_heap;
347
348   struct clc_libclc *compiler_ctx;
349
350   UINT uav_heap_incr;
351   int fence_value;
352
353   HANDLE event;
354   static PFN_D3D12_SERIALIZE_VERSIONED_ROOT_SIGNATURE D3D12SerializeVersionedRootSignature;
355};
356