brw_nir_rt.c revision 7ec681f3
1/*
2 * Copyright © 2020 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24#include "brw_nir_rt.h"
25#include "brw_nir_rt_builder.h"
26
27static bool
28resize_deref(nir_builder *b, nir_deref_instr *deref,
29             unsigned num_components, unsigned bit_size)
30{
31   assert(deref->dest.is_ssa);
32   if (deref->dest.ssa.num_components == num_components &&
33       deref->dest.ssa.bit_size == bit_size)
34      return false;
35
36   /* NIR requires array indices have to match the deref bit size */
37   if (deref->dest.ssa.bit_size != bit_size &&
38       (deref->deref_type == nir_deref_type_array ||
39        deref->deref_type == nir_deref_type_ptr_as_array)) {
40      b->cursor = nir_before_instr(&deref->instr);
41      assert(deref->arr.index.is_ssa);
42      nir_ssa_def *idx;
43      if (nir_src_is_const(deref->arr.index)) {
44         idx = nir_imm_intN_t(b, nir_src_as_int(deref->arr.index), bit_size);
45      } else {
46         idx = nir_i2i(b, deref->arr.index.ssa, bit_size);
47      }
48      nir_instr_rewrite_src(&deref->instr, &deref->arr.index,
49                            nir_src_for_ssa(idx));
50   }
51
52   deref->dest.ssa.num_components = num_components;
53   deref->dest.ssa.bit_size = bit_size;
54
55   return true;
56}
57
/* Rewrites shader_call_data and ray_hit_attrib variable access to
 * pointer-based access.
 *
 * Derefs of nir_var_shader_call_data variables are replaced with casts from
 * the 64-bit call-data pointer loaded out of this shader's scratch space.
 * Derefs of nir_var_ray_hit_attrib variables are replaced with casts from a
 * pointer that aliases either the MemHit barycentrics or the procedural
 * hit-attribute storage, chosen at runtime.  Finally, every function_temp
 * deref (including the ones just retargeted) is resized to 1x64-bit in
 * preparation for nir_lower_explicit_io with 64-bit global addresses.
 *
 * Returns true if anything was changed.
 */
static bool
lower_rt_io_derefs(nir_shader *shader)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   bool progress = false;

   unsigned num_shader_call_vars = 0;
   nir_foreach_variable_with_modes(var, shader, nir_var_shader_call_data)
      num_shader_call_vars++;

   unsigned num_ray_hit_attrib_vars = 0;
   nir_foreach_variable_with_modes(var, shader, nir_var_ray_hit_attrib)
      num_ray_hit_attrib_vars++;

   /* At most one payload is allowed because it's an input.  Technically, this
    * is also true for hit attribute variables.  However, after we inline an
    * any-hit shader into an intersection shader, we can end up with multiple
    * hit attribute variables.  They'll end up mapping to a cast from the same
    * base pointer so this is fine.
    */
   assert(num_shader_call_vars <= 1);

   nir_builder b;
   nir_builder_init(&b, impl);

   /* Emit the pointer loads once, at the very top of the entrypoint, so the
    * casts below can all reference them.
    */
   b.cursor = nir_before_cf_list(&impl->body);
   nir_ssa_def *call_data_addr = NULL;
   if (num_shader_call_vars > 0) {
      /* The calling shader stashed a 64-bit pointer to its payload storage
       * at a fixed offset in our scratch space.
       */
      assert(shader->scratch_size >= BRW_BTD_STACK_CALLEE_DATA_SIZE);
      call_data_addr =
         brw_nir_rt_load_scratch(&b, BRW_BTD_STACK_CALL_DATA_PTR_OFFSET, 8,
                                 1, 64);
      progress = true;
   }

   gl_shader_stage stage = shader->info.stage;
   nir_ssa_def *hit_attrib_addr = NULL;
   if (num_ray_hit_attrib_vars > 0) {
      /* Only hit-group stages can see hit attributes */
      assert(stage == MESA_SHADER_ANY_HIT ||
             stage == MESA_SHADER_CLOSEST_HIT ||
             stage == MESA_SHADER_INTERSECTION);
      nir_ssa_def *hit_addr =
         brw_nir_rt_mem_hit_addr(&b, stage == MESA_SHADER_CLOSEST_HIT);
      /* The vec2 barycentrics are in 2nd and 3rd dwords of MemHit */
      nir_ssa_def *bary_addr = nir_iadd_imm(&b, hit_addr, 4);
      /* Procedural (AABB) hits read the attribute storage written by the
       * intersection shader; triangle hits read the barycentrics straight
       * out of MemHit.
       */
      hit_attrib_addr = nir_bcsel(&b, nir_load_leaf_procedural_intel(&b),
                                      brw_nir_rt_hit_attrib_data_addr(&b),
                                      bary_addr);
      progress = true;
   }

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_deref)
            continue;

         nir_deref_instr *deref = nir_instr_as_deref(instr);
         if (nir_deref_mode_is(deref, nir_var_shader_call_data)) {
            /* Retag every deref in the chain; only the root var deref gets
             * replaced by a cast from the loaded pointer.
             */
            deref->modes = nir_var_function_temp;
            if (deref->deref_type == nir_deref_type_var) {
               b.cursor = nir_before_instr(&deref->instr);
               nir_deref_instr *cast =
                  nir_build_deref_cast(&b, call_data_addr,
                                       nir_var_function_temp,
                                       deref->var->type, 0);
               nir_ssa_def_rewrite_uses(&deref->dest.ssa,
                                        &cast->dest.ssa);
               nir_instr_remove(&deref->instr);
               progress = true;
            }
         } else if (nir_deref_mode_is(deref, nir_var_ray_hit_attrib)) {
            deref->modes = nir_var_function_temp;
            if (deref->deref_type == nir_deref_type_var) {
               b.cursor = nir_before_instr(&deref->instr);
               nir_deref_instr *cast =
                  nir_build_deref_cast(&b, hit_attrib_addr,
                                       nir_var_function_temp,
                                       deref->type, 0);
               nir_ssa_def_rewrite_uses(&deref->dest.ssa,
                                        &cast->dest.ssa);
               nir_instr_remove(&deref->instr);
               progress = true;
            }
         }

         /* We're going to lower all function_temp memory to scratch using
          * 64-bit addresses.  We need to resize all our derefs first or else
          * nir_lower_explicit_io will have a fit.
          */
         if (nir_deref_mode_is(deref, nir_var_function_temp) &&
             resize_deref(&b, deref, 1, 64))
            progress = true;
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return progress;
}
163
/** Lowers ray-tracing shader I/O and scratch access
 *
 * SPV_KHR_ray_tracing adds three new types of I/O, each of which need their
 * own bit of special care:
 *
 *  - Shader payload data:  This is represented by the IncomingCallableData
 *    and IncomingRayPayload storage classes which are both represented by
 *    nir_var_call_data in NIR.  There is at most one of these per-shader and
 *    they contain payload data passed down the stack from the parent shader
 *    when it calls executeCallable() or traceRay().  In our implementation,
 *    the actual storage lives in the calling shader's scratch space and we're
 *    passed a pointer to it.
 *
 *  - Hit attribute data:  This is represented by the HitAttribute storage
 *    class in SPIR-V and nir_var_ray_hit_attrib in NIR.  For triangle
 *    geometry, it's supposed to contain two floats which are the barycentric
 *    coordinates.  For AABB/procedural geometry, it contains the hit data
 *    written out by the intersection shader.  In our implementation, it's a
 *    64-bit pointer which points either to the u/v area of the relevant
 *    MemHit data structure or the space right after the HW ray stack entry.
 *
 *  - Shader record buffer data:  This allows read-only access to the data
 *    stored in the SBT right after the bindless shader handles.  It's
 *    effectively a UBO with a magic address.  Coming out of spirv_to_nir,
 *    we get a nir_intrinsic_load_shader_record_ptr which is cast to a
 *    nir_var_mem_global deref and all access happens through that.  The
 *    shader_record_ptr system value is handled in brw_nir_lower_rt_intrinsics
 *    and we assume nir_lower_explicit_io is called elsewhere thanks to
 *    VK_KHR_buffer_device_address so there's really nothing to do here.
 *
 * We also handle lowering any remaining function_temp variables to scratch at
 * this point.  This gets rid of any remaining arrays and also takes care of
 * the sending side of ray payloads where we pass pointers to a function_temp
 * variable down the call stack.
 */
static void
lower_rt_io_and_scratch(nir_shader *nir)
{
   /* First, we need to ensure all the I/O variables have explicit types.
    * Because these are shader-internal and don't come in from outside, they
    * don't have an explicit memory layout and we have to assign them one.
    */
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
              nir_var_function_temp |
              nir_var_shader_call_data |
              nir_var_ray_hit_attrib,
              glsl_get_natural_size_align_bytes);

   /* Now patch any derefs to I/O vars */
   NIR_PASS_V(nir, lower_rt_io_derefs);

   /* Finally, lower any remaining function_temp, mem_constant, or
    * ray_hit_attrib access to 64-bit global memory access.
    */
   NIR_PASS_V(nir, nir_lower_explicit_io,
              nir_var_function_temp |
              nir_var_mem_constant |
              nir_var_ray_hit_attrib,
              nir_address_format_64bit_global);
}
224
/* Emits code that terminates the ray walk, committing the current hit.
 *
 * If the ray flags say the closest-hit shader is to be skipped, we simply
 * return up the call stack; otherwise we commit the hit and spawn the
 * closest-hit shader directly.  Both paths end in a halt, so anything the
 * caller emits after this is unreachable.
 */
static void
build_terminate_ray(nir_builder *b)
{
   nir_ssa_def *skip_closest_hit =
      nir_i2b(b, nir_iand_imm(b, nir_load_ray_flags(b),
                              BRW_RT_RAY_FLAG_SKIP_CLOSEST_HIT_SHADER));
   nir_push_if(b, skip_closest_hit);
   {
      /* The shader that calls traceRay() is unable to access any ray hit
       * information except for that which is explicitly written into the ray
       * payload by shaders invoked during the trace.  If there's no closest-
       * hit shader, then accepting the hit has no observable effect; it's
       * just extra memory traffic for no reason.
       */
      brw_nir_btd_return(b);
      nir_jump(b, nir_jump_halt);
   }
   nir_push_else(b, NULL);
   {
      /* The closest hit shader is in the same shader group as the any-hit
       * shader that we're currently in.  We can get the address for its SBT
       * handle by looking at the shader record pointer and subtracting the
       * size of a SBT handle.  The BINDLESS_SHADER_RECORD for a closest hit
       * shader is the first one in the SBT handle.
       */
      nir_ssa_def *closest_hit =
         nir_iadd_imm(b, nir_load_shader_record_ptr(b),
                        -BRW_RT_SBT_HANDLE_SIZE);

      brw_nir_rt_commit_hit(b);
      brw_nir_btd_spawn(b, closest_hit);
      nir_jump(b, nir_jump_halt);
   }
   nir_pop_if(b, NULL);
}
260
/** Lowers away ray walk intrinsics
 *
 * This lowers terminate_ray, ignore_ray_intersection, and the NIR-specific
 * accept_ray_intersection intrinsics to the appropriate Intel-specific
 * intrinsics.
 *
 * Returns true if any intrinsic was lowered.  (devinfo is currently unused
 * in this pass body; presumably kept for signature consistency with the
 * other lowering entry points — confirm before removing.)
 */
static bool
lower_ray_walk_intrinsics(nir_shader *shader,
                          const struct intel_device_info *devinfo)
{
   /* These intrinsics only exist in the ray-walk stages */
   assert(shader->info.stage == MESA_SHADER_ANY_HIT ||
          shader->info.stage == MESA_SHADER_INTERSECTION);

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   nir_builder b;
   nir_builder_init(&b, impl);

   bool progress = false;
   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

         switch (intrin->intrinsic) {
         case nir_intrinsic_ignore_ray_intersection: {
            /* nir_instr_remove returns a cursor at the removed instruction's
             * position, so the replacement code lands where it stood.
             */
            b.cursor = nir_instr_remove(&intrin->instr);

            /* We put the newly emitted code inside a dummy if because it's
             * going to contain a jump instruction and we don't want to deal
             * with that mess here.  It'll get dealt with by our control-flow
             * optimization passes.
             */
            nir_push_if(&b, nir_imm_true(&b));
            nir_trace_ray_continue_intel(&b);
            nir_jump(&b, nir_jump_halt);
            nir_pop_if(&b, NULL);
            progress = true;
            break;
         }

         case nir_intrinsic_accept_ray_intersection: {
            b.cursor = nir_instr_remove(&intrin->instr);

            /* With terminate-on-first-hit, accepting a hit also ends the
             * walk; otherwise we just commit the hit and keep traversing.
             */
            nir_ssa_def *terminate =
               nir_i2b(&b, nir_iand_imm(&b, nir_load_ray_flags(&b),
                                        BRW_RT_RAY_FLAG_TERMINATE_ON_FIRST_HIT));
            nir_push_if(&b, terminate);
            {
               build_terminate_ray(&b);
            }
            nir_push_else(&b, NULL);
            {
               nir_trace_ray_commit_intel(&b);
               nir_jump(&b, nir_jump_halt);
            }
            nir_pop_if(&b, NULL);
            progress = true;
            break;
         }

         case nir_intrinsic_terminate_ray: {
            b.cursor = nir_instr_remove(&intrin->instr);
            build_terminate_ray(&b);
            progress = true;
            break;
         }

         default:
            break;
         }
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_none);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return progress;
}
345
/* Lowers a ray-gen shader: lower shader returns, then RT I/O and scratch.
 * No ray-walk intrinsic lowering is run for this stage.
 */
void
brw_nir_lower_raygen(nir_shader *nir)
{
   assert(nir->info.stage == MESA_SHADER_RAYGEN);
   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
   lower_rt_io_and_scratch(nir);
}
353
/* Lowers an any-hit shader: shader returns, then the ray-walk intrinsics
 * (ignore/accept/terminate), then RT I/O and scratch access.
 */
void
brw_nir_lower_any_hit(nir_shader *nir, const struct intel_device_info *devinfo)
{
   assert(nir->info.stage == MESA_SHADER_ANY_HIT);
   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
   NIR_PASS_V(nir, lower_ray_walk_intrinsics, devinfo);
   lower_rt_io_and_scratch(nir);
}
362
/* Lowers a closest-hit shader: shader returns, then RT I/O and scratch.
 * Ray-walk intrinsics don't occur in this stage, so that pass is skipped.
 */
void
brw_nir_lower_closest_hit(nir_shader *nir)
{
   assert(nir->info.stage == MESA_SHADER_CLOSEST_HIT);
   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
   lower_rt_io_and_scratch(nir);
}
370
/* Lowers a miss shader: shader returns, then RT I/O and scratch access. */
void
brw_nir_lower_miss(nir_shader *nir)
{
   assert(nir->info.stage == MESA_SHADER_MISS);
   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
   lower_rt_io_and_scratch(nir);
}
378
/* Lowers a callable shader: shader returns, then RT I/O and scratch access. */
void
brw_nir_lower_callable(nir_shader *nir)
{
   assert(nir->info.stage == MESA_SHADER_CALLABLE);
   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
   lower_rt_io_and_scratch(nir);
}
386
/* Lowers an intersection shader, inlining the hit group's any-hit shader
 * into it (if any), then lowers the resulting ray-walk intrinsics and the
 * usual RT I/O and scratch access.
 */
void
brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection,
                                            const nir_shader *any_hit,
                                            const struct intel_device_info *devinfo)
{
   assert(intersection->info.stage == MESA_SHADER_INTERSECTION);
   assert(any_hit == NULL || any_hit->info.stage == MESA_SHADER_ANY_HIT);
   NIR_PASS_V(intersection, brw_nir_lower_shader_returns);
   NIR_PASS_V(intersection, brw_nir_lower_intersection_shader,
              any_hit, devinfo);
   NIR_PASS_V(intersection, lower_ray_walk_intrinsics, devinfo);
   lower_rt_io_and_scratch(intersection);
}
400
401static nir_ssa_def *
402build_load_uniform(nir_builder *b, unsigned offset,
403                   unsigned num_components, unsigned bit_size)
404{
405   return nir_load_uniform(b, num_components, bit_size, nir_imm_int(b, 0),
406                           .base = offset,
407                           .range = num_components * bit_size / 8);
408}
409
/* Loads field `name` of struct brw_rt_raygen_trampoline_params from the
 * push constants; the field's offset within the struct is used directly as
 * the uniform byte offset (i.e. the struct is assumed to start at offset 0).
 */
#define load_trampoline_param(b, name, num_components, bit_size) \
   build_load_uniform((b), offsetof(struct brw_rt_raygen_trampoline_params, name), \
                      (num_components), (bit_size))
413
/** Builds the compute shader used to trampoline into the ray-gen shader
 *
 * Each invocation reconstructs its 3D launch ID from the workgroup ID and
 * the log2 local group sizes passed in via push constants, initializes the
 * software hotzone for in-bounds invocations, and BTD-spawns the actual
 * ray-gen BINDLESS_SHADER_RECORD.  Out-of-bounds invocations just retire
 * their hardware stack IDs.  The returned shader is allocated from mem_ctx.
 */
nir_shader *
brw_nir_create_raygen_trampoline(const struct brw_compiler *compiler,
                                 void *mem_ctx)
{
   const struct intel_device_info *devinfo = compiler->devinfo;
   const nir_shader_compiler_options *nir_options =
      compiler->glsl_compiler_options[MESA_SHADER_COMPUTE].NirOptions;

   /* The trampoline params must exactly fill one 32B push-constant block */
   STATIC_ASSERT(sizeof(struct brw_rt_raygen_trampoline_params) == 32);

   nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_COMPUTE,
                                                  nir_options,
                                                  "RT Ray-Gen Trampoline");
   ralloc_steal(mem_ctx, b.shader);

   b.shader->info.workgroup_size_variable = true;

   /* The RT global data and raygen BINDLESS_SHADER_RECORD addresses are
    * passed in as push constants in the first register.  We deal with the
    * raygen BSR address here; the global data we'll deal with later.
    */
   b.shader->num_uniforms = 32;
   nir_ssa_def *raygen_bsr_addr =
      load_trampoline_param(&b, raygen_bsr_addr, 1, 64);
   nir_ssa_def *local_shift =
      nir_u2u32(&b, load_trampoline_param(&b, local_group_size_log2, 3, 8));

   /* Unpack the flat subgroup invocation index into 3D local coordinates,
    * using the log2 group sizes as bit-field offsets/widths.
    */
   nir_ssa_def *global_id = nir_load_workgroup_id(&b, 32);
   nir_ssa_def *simd_channel = nir_load_subgroup_invocation(&b);
   nir_ssa_def *local_x =
      nir_ubfe(&b, simd_channel, nir_imm_int(&b, 0),
                  nir_channel(&b, local_shift, 0));
   nir_ssa_def *local_y =
      nir_ubfe(&b, simd_channel, nir_channel(&b, local_shift, 0),
                  nir_channel(&b, local_shift, 1));
   nir_ssa_def *local_z =
      nir_ubfe(&b, simd_channel,
                  nir_iadd(&b, nir_channel(&b, local_shift, 0),
                              nir_channel(&b, local_shift, 1)),
                  nir_channel(&b, local_shift, 2));
   nir_ssa_def *launch_id =
      nir_iadd(&b, nir_ishl(&b, global_id, local_shift),
                  nir_vec3(&b, local_x, local_y, local_z));

   /* Workgroups may overhang the launch grid; only in-bounds invocations
    * actually spawn the ray-gen shader.
    */
   nir_ssa_def *launch_size = nir_load_ray_launch_size(&b);
   nir_push_if(&b, nir_ball(&b, nir_ult(&b, launch_id, launch_size)));
   {
      nir_store_global(&b, brw_nir_rt_sw_hotzone_addr(&b, devinfo), 16,
                       nir_vec4(&b, nir_imm_int(&b, 0), /* Stack ptr */
                                    nir_channel(&b, launch_id, 0),
                                    nir_channel(&b, launch_id, 1),
                                    nir_channel(&b, launch_id, 2)),
                       0xf /* write mask */);

      brw_nir_btd_spawn(&b, raygen_bsr_addr);
   }
   nir_push_else(&b, NULL);
   {
      /* Even though these invocations aren't being used for anything, the
       * hardware allocated stack IDs for them.  They need to retire them.
       */
      brw_nir_btd_retire(&b);
   }
   nir_pop_if(&b, NULL);

   nir_shader *nir = b.shader;
   nir->info.name = ralloc_strdup(nir, "RT: TraceRay trampoline");
   nir_validate_shader(nir, "in brw_nir_create_raygen_trampoline");
   brw_preprocess_nir(compiler, nir, NULL);

   NIR_PASS_V(nir, brw_nir_lower_rt_intrinsics, devinfo);

   /* brw_nir_lower_rt_intrinsics will leave us with a btd_global_arg_addr
    * intrinsic which doesn't exist in compute shaders.  We also created one
    * above when we generated the BTD spawn intrinsic.  Now we go through and
    * replace them with a uniform load.
    */
   nir_foreach_block(block, b.impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         if (intrin->intrinsic != nir_intrinsic_load_btd_global_arg_addr_intel)
            continue;

         b.cursor = nir_before_instr(&intrin->instr);
         nir_ssa_def *global_arg_addr =
            load_trampoline_param(&b, rt_disp_globals_addr, 1, 64);
         assert(intrin->dest.is_ssa);
         nir_ssa_def_rewrite_uses(&intrin->dest.ssa,
                                  global_arg_addr);
         nir_instr_remove(instr);
      }
   }

   NIR_PASS_V(nir, brw_nir_lower_cs_intrinsics);

   brw_nir_optimize(nir, compiler, true, false);

   return nir;
}
516