1/*
2 * Copyright © Microsoft Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
#include "spirv_to_dxil.h"

#include <stdlib.h>
#include <string.h>

#include "nir_to_dxil.h"
#include "dxil_nir.h"
#include "shader_enums.h"
#include "spirv/nir_spirv.h"
#include "util/blob.h"

#include "git_sha1.h"
#include "vulkan/vulkan.h"
33
/* Size/alignment callback for nir_lower_vars_to_explicit_types on shared
 * memory: a scalar/vector occupies component-size * component-count bytes
 * and is aligned to the component size.  Booleans are lowered to 32-bit. */
static void
shared_var_info(const struct glsl_type* type, unsigned* size, unsigned* align)
{
   assert(glsl_type_is_vector_or_scalar(type));

   unsigned bytes_per_comp =
      glsl_type_is_boolean(type) ? 4 : glsl_get_bit_size(type) / 8;

   *align = bytes_per_comp;
   *size = bytes_per_comp * glsl_get_vector_elements(type);
}
44
45static nir_variable *
46add_runtime_data_var(nir_shader *nir, unsigned desc_set, unsigned binding)
47{
48   unsigned runtime_data_size =
49      nir->info.stage == MESA_SHADER_COMPUTE
50         ? sizeof(struct dxil_spirv_compute_runtime_data)
51         : sizeof(struct dxil_spirv_vertex_runtime_data);
52
53   const struct glsl_type *array_type =
54      glsl_array_type(glsl_uint_type(), runtime_data_size / sizeof(unsigned),
55                      sizeof(unsigned));
56   const struct glsl_struct_field field = {array_type, "arr"};
57   nir_variable *var = nir_variable_create(
58      nir, nir_var_mem_ubo,
59      glsl_struct_type(&field, 1, "runtime_data", false), "runtime_data");
60   var->data.descriptor_set = desc_set;
61   // Check that desc_set fits on descriptor_set
62   assert(var->data.descriptor_set == desc_set);
63   var->data.binding = binding;
64   var->data.how_declared = nir_var_hidden;
65   return var;
66}
67
/* Callback state threaded through lower_shader_system_values(). */
struct lower_system_values_data {
   nir_address_format ubo_format; /* address format used for UBO access */
   unsigned desc_set;             /* descriptor set of the runtime-data CBV */
   unsigned binding;              /* binding index of the runtime-data CBV */
};
73
/* Instruction callback: rewrites loads of a few system values (compute
 * workgroup count; vertex first_vertex / base_instance / is_indexed_draw)
 * into UBO loads from the runtime-data constant buffer bound at
 * (cb_data->desc_set, cb_data->binding).  Returns true when the
 * instruction was replaced. */
static bool
lower_shader_system_values(struct nir_builder *builder, nir_instr *instr,
                           void *cb_data)
{
   if (instr->type != nir_instr_type_intrinsic) {
      return false;
   }

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   /* All the intrinsics we care about are loads */
   if (!nir_intrinsic_infos[intrin->intrinsic].has_dest)
      return false;

   assert(intrin->dest.is_ssa);

   /* Byte offset of the replacement value within the runtime-data struct
    * (compute or vertex flavor depending on the intrinsic). */
   int offset = 0;
   switch (intrin->intrinsic) {
   case nir_intrinsic_load_num_workgroups:
      offset =
         offsetof(struct dxil_spirv_compute_runtime_data, group_count_x);
      break;
   case nir_intrinsic_load_first_vertex:
      offset = offsetof(struct dxil_spirv_vertex_runtime_data, first_vertex);
      break;
   case nir_intrinsic_load_is_indexed_draw:
      offset =
         offsetof(struct dxil_spirv_vertex_runtime_data, is_indexed_draw);
      break;
   case nir_intrinsic_load_base_instance:
      offset = offsetof(struct dxil_spirv_vertex_runtime_data, base_instance);
      break;
   default:
      return false;
   }

   struct lower_system_values_data *data =
      (struct lower_system_values_data *)cb_data;

   builder->cursor = nir_after_instr(instr);
   nir_address_format ubo_format = data->ubo_format;

   /* Resolve the runtime-data CBV: vulkan_resource_index -> descriptor. */
   nir_ssa_def *index = nir_vulkan_resource_index(
      builder, nir_address_format_num_components(ubo_format),
      nir_address_format_bit_size(ubo_format),
      nir_imm_int(builder, 0),
      .desc_set = data->desc_set, .binding = data->binding,
      .desc_type = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER);

   nir_ssa_def *load_desc = nir_load_vulkan_descriptor(
      builder, nir_address_format_num_components(ubo_format),
      nir_address_format_bit_size(ubo_format),
      index, .desc_type = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER);

   /* Load the value at `offset` with the same component count / bit size
    * as the intrinsic being replaced. */
   nir_ssa_def *load_data = build_load_ubo_dxil(
      builder, nir_channel(builder, load_desc, 0),
      nir_imm_int(builder, offset),
      nir_dest_num_components(intrin->dest), nir_dest_bit_size(intrin->dest));

   nir_ssa_def_rewrite_uses(&intrin->dest.ssa, load_data);
   nir_instr_remove(instr);
   return true;
}
137
138static bool
139dxil_spirv_nir_lower_shader_system_values(nir_shader *shader,
140                                          nir_address_format ubo_format,
141                                          unsigned desc_set, unsigned binding)
142{
143   struct lower_system_values_data data = {
144      .ubo_format = ubo_format,
145      .desc_set = desc_set,
146      .binding = binding,
147   };
148   return nir_shader_instructions_pass(shader, lower_shader_system_values,
149                                       nir_metadata_block_index |
150                                          nir_metadata_dominance |
151                                          nir_metadata_loop_analysis,
152                                       &data);
153}
154
155bool
156spirv_to_dxil(const uint32_t *words, size_t word_count,
157              struct dxil_spirv_specialization *specializations,
158              unsigned int num_specializations, dxil_spirv_shader_stage stage,
159              const char *entry_point_name,
160              const struct dxil_spirv_runtime_conf *conf,
161              struct dxil_spirv_object *out_dxil)
162{
163   if (stage == MESA_SHADER_NONE || stage == MESA_SHADER_KERNEL)
164      return false;
165
166   struct spirv_to_nir_options spirv_opts = {
167      .ubo_addr_format = nir_address_format_32bit_index_offset,
168      .ssbo_addr_format = nir_address_format_32bit_index_offset,
169      .shared_addr_format = nir_address_format_32bit_offset_as_64bit,
170
171      // use_deref_buffer_array_length + nir_lower_explicit_io force
172      //  get_ssbo_size to take in the return from load_vulkan_descriptor
173      //  instead of vulkan_resource_index. This makes it much easier to
174      //  get the DXIL handle for the SSBO.
175      .use_deref_buffer_array_length = true
176   };
177
178   glsl_type_singleton_init_or_ref();
179
180   struct nir_shader_compiler_options nir_options = *dxil_get_nir_compiler_options();
181   // We will manually handle base_vertex when vertex_id and instance_id have
182   // have been already converted to zero-base.
183   nir_options.lower_base_vertex = !conf->zero_based_vertex_instance_id;
184
185   nir_shader *nir = spirv_to_nir(
186      words, word_count, (struct nir_spirv_specialization *)specializations,
187      num_specializations, (gl_shader_stage)stage, entry_point_name,
188      &spirv_opts, &nir_options);
189   if (!nir) {
190      glsl_type_singleton_decref();
191      return false;
192   }
193
194   nir_validate_shader(nir,
195                       "Validate before feeding NIR to the DXIL compiler");
196
197   const struct nir_lower_sysvals_to_varyings_options sysvals_to_varyings = {
198      .frag_coord = true,
199      .point_coord = true,
200   };
201   NIR_PASS_V(nir, nir_lower_sysvals_to_varyings, &sysvals_to_varyings);
202
203   NIR_PASS_V(nir, nir_lower_system_values);
204
205   if (conf->zero_based_vertex_instance_id) {
206      // vertex_id and instance_id should have already been transformed to
207      // base zero before spirv_to_dxil was called. Therefore, we can zero out
208      // base/firstVertex/Instance.
209      gl_system_value system_values[] = {SYSTEM_VALUE_FIRST_VERTEX,
210                                         SYSTEM_VALUE_BASE_VERTEX,
211                                         SYSTEM_VALUE_BASE_INSTANCE};
212      NIR_PASS_V(nir, dxil_nir_lower_system_values_to_zero, system_values,
213                 ARRAY_SIZE(system_values));
214   }
215
216   bool requires_runtime_data = false;
217   NIR_PASS(requires_runtime_data, nir,
218            dxil_spirv_nir_lower_shader_system_values,
219            spirv_opts.ubo_addr_format,
220            conf->runtime_data_cbv.register_space,
221            conf->runtime_data_cbv.base_shader_register);
222   if (requires_runtime_data) {
223      add_runtime_data_var(nir, conf->runtime_data_cbv.register_space,
224                           conf->runtime_data_cbv.base_shader_register);
225   }
226
227   NIR_PASS_V(nir, nir_split_per_member_structs);
228
229   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_ubo | nir_var_mem_ssbo,
230              nir_address_format_32bit_index_offset);
231
232   if (!nir->info.shared_memory_explicit_layout) {
233      NIR_PASS_V(nir, nir_lower_vars_to_explicit_types, nir_var_mem_shared,
234                 shared_var_info);
235   }
236   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_shared,
237      nir_address_format_32bit_offset_as_64bit);
238
239   nir_variable_mode nir_var_function_temp =
240      nir_var_shader_in | nir_var_shader_out;
241   NIR_PASS_V(nir, nir_lower_variable_initializers,
242              nir_var_function_temp);
243   NIR_PASS_V(nir, nir_opt_deref);
244   NIR_PASS_V(nir, nir_lower_returns);
245   NIR_PASS_V(nir, nir_inline_functions);
246   NIR_PASS_V(nir, nir_lower_variable_initializers,
247              ~nir_var_function_temp);
248
249   // Pick off the single entrypoint that we want.
250   nir_function *entrypoint;
251   foreach_list_typed_safe(nir_function, func, node, &nir->functions) {
252      if (func->is_entrypoint)
253         entrypoint = func;
254      else
255         exec_node_remove(&func->node);
256   }
257   assert(exec_list_length(&nir->functions) == 1);
258
259   NIR_PASS_V(nir, nir_lower_clip_cull_distance_arrays);
260   NIR_PASS_V(nir, nir_lower_io_to_temporaries, entrypoint->impl, true, true);
261   NIR_PASS_V(nir, nir_lower_global_vars_to_local);
262   NIR_PASS_V(nir, nir_split_var_copies);
263   NIR_PASS_V(nir, nir_lower_var_copies);
264   NIR_PASS_V(nir, nir_lower_io_arrays_to_elements_no_indirects, false);
265
266   NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL);
267   NIR_PASS_V(nir, nir_opt_dce);
268   NIR_PASS_V(nir, dxil_nir_lower_double_math);
269
270   {
271      bool progress;
272      do
273      {
274         progress = false;
275         NIR_PASS(progress, nir, nir_copy_prop);
276         NIR_PASS(progress, nir, nir_opt_copy_prop_vars);
277         NIR_PASS(progress, nir, nir_opt_deref);
278         NIR_PASS(progress, nir, nir_opt_dce);
279         NIR_PASS(progress, nir, nir_opt_undef);
280         NIR_PASS(progress, nir, nir_opt_constant_folding);
281         NIR_PASS(progress, nir, nir_opt_cse);
282         if (nir_opt_trivial_continues(nir)) {
283            progress = true;
284            NIR_PASS(progress, nir, nir_copy_prop);
285            NIR_PASS(progress, nir, nir_opt_dce);
286         }
287         NIR_PASS(progress, nir, nir_lower_vars_to_ssa);
288         NIR_PASS(progress, nir, nir_opt_algebraic);
289      } while (progress);
290   }
291
292   NIR_PASS_V(nir, nir_lower_readonly_images_to_tex, true);
293   nir_lower_tex_options lower_tex_options = {0};
294   NIR_PASS_V(nir, nir_lower_tex, &lower_tex_options);
295
296   NIR_PASS_V(nir, dxil_nir_split_clip_cull_distance);
297   NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil);
298   NIR_PASS_V(nir, dxil_nir_create_bare_samplers);
299   NIR_PASS_V(nir, dxil_nir_lower_bool_input);
300
301   nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
302
303   nir->info.inputs_read =
304      dxil_reassign_driver_locations(nir, nir_var_shader_in, 0);
305
306   if (stage != MESA_SHADER_FRAGMENT) {
307      nir->info.outputs_written =
308         dxil_reassign_driver_locations(nir, nir_var_shader_out, 0);
309   } else {
310      dxil_sort_ps_outputs(nir);
311   }
312
313   struct nir_to_dxil_options opts = {.vulkan_environment = true};
314
315   struct blob dxil_blob;
316   if (!nir_to_dxil(nir, &opts, &dxil_blob)) {
317      if (dxil_blob.allocated)
318         blob_finish(&dxil_blob);
319      glsl_type_singleton_decref();
320      return false;
321   }
322
323   out_dxil->metadata.requires_runtime_data = requires_runtime_data;
324   blob_finish_get_buffer(&dxil_blob, &out_dxil->binary.buffer,
325                          &out_dxil->binary.size);
326
327   glsl_type_singleton_decref();
328   return true;
329}
330
331void
332spirv_to_dxil_free(struct dxil_spirv_object *dxil)
333{
334   free(dxil->binary.buffer);
335}
336
337uint64_t
338spirv_to_dxil_get_version()
339{
340   const char sha1[] = MESA_GIT_SHA1;
341   const char* dash = strchr(sha1, '-');
342   if (dash) {
343      return strtoull(dash + 1, NULL, 16);
344   }
345   return 0;
346}
347