1/* 2 * Copyright © Microsoft Corporation 3 * 4 * Permission is hereby granted, free of charge, to any person obtaining a 5 * copy of this software and associated documentation files (the "Software"), 6 * to deal in the Software without restriction, including without limitation 7 * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 * and/or sell copies of the Software, and to permit persons to whom the 9 * Software is furnished to do so, subject to the following conditions: 10 * 11 * The above copyright notice and this permission notice (including the next 12 * paragraph) shall be included in all copies or substantial portions of the 13 * Software. 14 * 15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 21 * IN THE SOFTWARE. 22 */ 23 24#include <stdio.h> 25#include <stdint.h> 26#include <stdexcept> 27 28#include <directx/d3d12.h> 29#include <dxgi1_4.h> 30#include <gtest/gtest.h> 31#include <wrl.h> 32 33#include "clc_compiler.h" 34 35using std::runtime_error; 36using Microsoft::WRL::ComPtr; 37 38inline D3D12_CPU_DESCRIPTOR_HANDLE 39offset_cpu_handle(D3D12_CPU_DESCRIPTOR_HANDLE handle, UINT offset) 40{ 41 handle.ptr += offset; 42 return handle; 43} 44 45inline size_t 46align(size_t value, unsigned alignment) 47{ 48 assert(alignment > 0); 49 return ((value + (alignment - 1)) / alignment) * alignment; 50} 51 52class ComputeTest : public ::testing::Test { 53protected: 54 struct Shader { 55 std::shared_ptr<struct clc_binary> obj; 56 std::shared_ptr<struct clc_parsed_spirv> metadata; 57 std::shared_ptr<struct clc_dxil_object> dxil; 58 }; 59 60 static void 61 enable_d3d12_debug_layer(); 62 63 static IDXGIFactory4 * 64 get_dxgi_factory(); 65 66 static IDXGIAdapter1 * 67 choose_adapter(IDXGIFactory4 *factory); 68 69 static ID3D12Device * 70 create_device(IDXGIAdapter1 *adapter); 71 72 struct Resources { 73 void add(ComPtr<ID3D12Resource> res, 74 D3D12_DESCRIPTOR_RANGE_TYPE type, 75 unsigned spaceid, 76 unsigned resid) 77 { 78 descs.push_back(res); 79 80 if(!ranges.empty() && 81 ranges.back().RangeType == type && 82 ranges.back().RegisterSpace == spaceid && 83 ranges.back().BaseShaderRegister + ranges.back().NumDescriptors == resid) { 84 ranges.back().NumDescriptors++; 85 return; 86 } 87 88 D3D12_DESCRIPTOR_RANGE1 range; 89 90 range.RangeType = type; 91 range.NumDescriptors = 1; 92 range.BaseShaderRegister = resid; 93 range.RegisterSpace = spaceid; 94 range.OffsetInDescriptorsFromTableStart = descs.size() - 1; 95 range.Flags = D3D12_DESCRIPTOR_RANGE_FLAG_DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS; 96 ranges.push_back(range); 97 } 98 99 std::vector<D3D12_DESCRIPTOR_RANGE1> ranges; 100 std::vector<ComPtr<ID3D12Resource>> descs; 101 }; 102 103 ComPtr<ID3D12RootSignature> 104 create_root_signature(const Resources &resources); 105 106 ComPtr<ID3D12PipelineState> 107 create_pipeline_state(ComPtr<ID3D12RootSignature> &root_sig, 108 const struct clc_dxil_object &dxil); 109 110 ComPtr<ID3D12Resource> 111 create_buffer(int size, D3D12_HEAP_TYPE heap_type); 112 113 ComPtr<ID3D12Resource> 114 create_upload_buffer_with_data(const void *data, size_t size); 115 116 ComPtr<ID3D12Resource> 117 create_sized_buffer_with_data(size_t buffer_size, const void *data, 118 size_t data_size); 119 120 ComPtr<ID3D12Resource> 121 create_buffer_with_data(const void *data, size_t size) 122 { 123 return create_sized_buffer_with_data(size, data, size); 124 } 125 126 void 127 get_buffer_data(ComPtr<ID3D12Resource> res, 128 void *buf, size_t size); 129 130 void 131 resource_barrier(ComPtr<ID3D12Resource> &res, 132 D3D12_RESOURCE_STATES state_before, 133 D3D12_RESOURCE_STATES state_after); 134 135 void 136 execute_cmdlist(); 137 138 void 139 create_uav_buffer(ComPtr<ID3D12Resource> res, 140 size_t width, size_t byte_stride, 141 D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle); 142 143 void create_cbv(ComPtr<ID3D12Resource> res, size_t size, 144 D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle); 145 146 ComPtr<ID3D12Resource> 147 add_uav_resource(Resources &resources, unsigned spaceid, unsigned resid, 148 const void *data = NULL, size_t num_elems = 0, 149 size_t elem_size = 0); 150 151 ComPtr<ID3D12Resource> 152 add_cbv_resource(Resources &resources, unsigned spaceid, unsigned resid, 153 const void *data, size_t size); 154 155 void 156 SetUp() override; 157 158 void 159 TearDown() override; 160 161 Shader 162 compile(const std::vector<const char *> &sources, 163 const std::vector<const char *> &compile_args = {}, 164 bool create_library = false); 165 166 Shader 167 link(const std::vector<Shader> &sources, 168 bool create_library = false); 169 170 Shader 171 assemble(const char *source); 172 173 void 174 configure(Shader &shader, 175 const struct clc_runtime_kernel_conf *conf); 176 177 void 178 validate(Shader &shader); 179 180 template <typename T> 181 Shader 182 specialize(Shader &shader, uint32_t id, T const& val) 183 { 184 Shader new_shader; 185 new_shader.obj = std::shared_ptr<clc_binary>(new clc_binary{}, [](clc_binary *spirv) 186 { 187 clc_free_spirv(spirv); 188 delete spirv; 189 }); 190 if (!shader.metadata) 191 configure(shader, NULL); 192 193 clc_spirv_specialization spec; 194 spec.id = id; 195 memcpy(&spec.value, &val, sizeof(val)); 196 clc_spirv_specialization_consts consts; 197 consts.specializations = &spec; 198 consts.num_specializations = 1; 199 if (!clc_specialize_spirv(shader.obj.get(), shader.metadata.get(), &consts, new_shader.obj.get())) 200 throw runtime_error("failed to specialize"); 201 202 configure(new_shader, NULL); 203 204 return new_shader; 205 } 206 207 enum ShaderArgDirection { 208 SHADER_ARG_INPUT = 1, 209 SHADER_ARG_OUTPUT = 2, 210 SHADER_ARG_INOUT = SHADER_ARG_INPUT | SHADER_ARG_OUTPUT, 211 }; 212 213 class RawShaderArg { 214 public: 215 RawShaderArg(enum ShaderArgDirection dir) : dir(dir) { } 216 virtual size_t get_elem_size() const = 0; 217 virtual size_t get_num_elems() const = 0; 218 virtual const void *get_data() const = 0; 219 virtual void *get_data() = 0; 220 enum ShaderArgDirection get_direction() { return dir; } 221 private: 222 enum ShaderArgDirection dir; 223 }; 224 225 class NullShaderArg : public RawShaderArg { 226 public: 227 NullShaderArg() : RawShaderArg(SHADER_ARG_INPUT) { } 228 size_t get_elem_size() const override { return 0; } 229 size_t get_num_elems() const override { return 0; } 230 const void *get_data() const override { return NULL; } 231 void *get_data() override { return NULL; } 232 }; 233 234 template <typename T> 235 class ShaderArg : public std::vector<T>, public RawShaderArg 236 { 237 public: 238 ShaderArg(const T &v, enum ShaderArgDirection dir = SHADER_ARG_INOUT) : 239 std::vector<T>({ v }), RawShaderArg(dir) { } 240 ShaderArg(const std::vector<T> &v, enum ShaderArgDirection dir = SHADER_ARG_INOUT) : 241 std::vector<T>(v), RawShaderArg(dir) { } 242 ShaderArg(const std::initializer_list<T> v, enum ShaderArgDirection dir = SHADER_ARG_INOUT) : 243 std::vector<T>(v), RawShaderArg(dir) { } 244 245 ShaderArg<T>& operator =(const T &v) 246 { 247 this->clear(); 248 this->push_back(v); 249 return *this; 250 } 251 252 operator T&() { return this->at(0); } 253 operator const T&() const { return this->at(0); } 254 255 ShaderArg<T>& operator =(const std::vector<T> &v) 256 { 257 *this = v; 258 return *this; 259 } 260 261 ShaderArg<T>& operator =(std::initializer_list<T> v) 262 { 263 *this = v; 264 return *this; 265 } 266 267 size_t get_elem_size() const override { return sizeof(T); } 268 size_t get_num_elems() const override { return this->size(); } 269 const void *get_data() const override { return this->data(); } 270 void *get_data() override { return this->data(); } 271 }; 272 273 struct CompileArgs 274 { 275 unsigned x, y, z; 276 std::vector<const char *> compiler_command_line; 277 clc_work_properties_data work_props; 278 }; 279 280private: 281 void gather_args(std::vector<RawShaderArg *> &args) { } 282 283 template <typename T, typename... Rest> 284 void gather_args(std::vector<RawShaderArg *> &args, T &arg, Rest&... rest) 285 { 286 args.push_back(&arg); 287 gather_args(args, rest...); 288 } 289 290 void run_shader_with_raw_args(Shader shader, 291 const CompileArgs &compile_args, 292 const std::vector<RawShaderArg *> &args); 293 294protected: 295 template <typename... Args> 296 void run_shader(Shader shader, 297 const CompileArgs &compile_args, 298 Args&... args) 299 { 300 std::vector<RawShaderArg *> raw_args; 301 gather_args(raw_args, args...); 302 run_shader_with_raw_args(shader, compile_args, raw_args); 303 } 304 305 template <typename... Args> 306 void run_shader(const std::vector<const char *> &sources, 307 unsigned x, unsigned y, unsigned z, 308 Args&... args) 309 { 310 std::vector<RawShaderArg *> raw_args; 311 gather_args(raw_args, args...); 312 CompileArgs compile_args = { x, y, z }; 313 run_shader_with_raw_args(compile(sources), compile_args, raw_args); 314 } 315 316 template <typename... Args> 317 void run_shader(const std::vector<const char *> &sources, 318 const CompileArgs &compile_args, 319 Args&... args) 320 { 321 std::vector<RawShaderArg *> raw_args; 322 gather_args(raw_args, args...); 323 run_shader_with_raw_args( 324 compile(sources, compile_args.compiler_command_line), 325 compile_args, raw_args); 326 } 327 328 template <typename... Args> 329 void run_shader(const char *source, 330 unsigned x, unsigned y, unsigned z, 331 Args&... args) 332 { 333 std::vector<RawShaderArg *> raw_args; 334 gather_args(raw_args, args...); 335 CompileArgs compile_args = { x, y, z }; 336 run_shader_with_raw_args(compile({ source }), compile_args, raw_args); 337 } 338 339 IDXGIFactory4 *factory; 340 IDXGIAdapter1 *adapter; 341 ID3D12Device *dev; 342 ID3D12Fence *cmdqueue_fence; 343 ID3D12CommandQueue *cmdqueue; 344 ID3D12CommandAllocator *cmdalloc; 345 ID3D12GraphicsCommandList *cmdlist; 346 ID3D12DescriptorHeap *uav_heap; 347 348 struct clc_libclc *compiler_ctx; 349 350 UINT uav_heap_incr; 351 int fence_value; 352 353 HANDLE event; 354 static PFN_D3D12_SERIALIZE_VERSIONED_ROOT_SIGNATURE D3D12SerializeVersionedRootSignature; 355}; 356