17ec681f3Smrg/*
27ec681f3Smrg * Copyright © 2020 Intel Corporation
37ec681f3Smrg *
47ec681f3Smrg * Permission is hereby granted, free of charge, to any person obtaining a
57ec681f3Smrg * copy of this software and associated documentation files (the "Software"),
67ec681f3Smrg * to deal in the Software without restriction, including without limitation
77ec681f3Smrg * the rights to use, copy, modify, merge, publish, distribute, sublicense,
87ec681f3Smrg * and/or sell copies of the Software, and to permit persons to whom the
97ec681f3Smrg * Software is furnished to do so, subject to the following conditions:
107ec681f3Smrg *
117ec681f3Smrg * The above copyright notice and this permission notice (including the next
127ec681f3Smrg * paragraph) shall be included in all copies or substantial portions of the
137ec681f3Smrg * Software.
147ec681f3Smrg *
157ec681f3Smrg * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
167ec681f3Smrg * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
177ec681f3Smrg * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
187ec681f3Smrg * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
197ec681f3Smrg * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
207ec681f3Smrg * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
217ec681f3Smrg * IN THE SOFTWARE.
227ec681f3Smrg */
237ec681f3Smrg
247ec681f3Smrg#include "brw_nir_rt.h"
257ec681f3Smrg#include "brw_nir_rt_builder.h"
267ec681f3Smrg
277ec681f3Smrgstatic bool
287ec681f3Smrgresize_deref(nir_builder *b, nir_deref_instr *deref,
297ec681f3Smrg             unsigned num_components, unsigned bit_size)
307ec681f3Smrg{
317ec681f3Smrg   assert(deref->dest.is_ssa);
327ec681f3Smrg   if (deref->dest.ssa.num_components == num_components &&
337ec681f3Smrg       deref->dest.ssa.bit_size == bit_size)
347ec681f3Smrg      return false;
357ec681f3Smrg
367ec681f3Smrg   /* NIR requires array indices have to match the deref bit size */
377ec681f3Smrg   if (deref->dest.ssa.bit_size != bit_size &&
387ec681f3Smrg       (deref->deref_type == nir_deref_type_array ||
397ec681f3Smrg        deref->deref_type == nir_deref_type_ptr_as_array)) {
407ec681f3Smrg      b->cursor = nir_before_instr(&deref->instr);
417ec681f3Smrg      assert(deref->arr.index.is_ssa);
427ec681f3Smrg      nir_ssa_def *idx;
437ec681f3Smrg      if (nir_src_is_const(deref->arr.index)) {
447ec681f3Smrg         idx = nir_imm_intN_t(b, nir_src_as_int(deref->arr.index), bit_size);
457ec681f3Smrg      } else {
467ec681f3Smrg         idx = nir_i2i(b, deref->arr.index.ssa, bit_size);
477ec681f3Smrg      }
487ec681f3Smrg      nir_instr_rewrite_src(&deref->instr, &deref->arr.index,
497ec681f3Smrg                            nir_src_for_ssa(idx));
507ec681f3Smrg   }
517ec681f3Smrg
527ec681f3Smrg   deref->dest.ssa.num_components = num_components;
537ec681f3Smrg   deref->dest.ssa.bit_size = bit_size;
547ec681f3Smrg
557ec681f3Smrg   return true;
567ec681f3Smrg}
577ec681f3Smrg
587ec681f3Smrgstatic bool
597ec681f3Smrglower_rt_io_derefs(nir_shader *shader)
607ec681f3Smrg{
617ec681f3Smrg   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
627ec681f3Smrg
637ec681f3Smrg   bool progress = false;
647ec681f3Smrg
657ec681f3Smrg   unsigned num_shader_call_vars = 0;
667ec681f3Smrg   nir_foreach_variable_with_modes(var, shader, nir_var_shader_call_data)
677ec681f3Smrg      num_shader_call_vars++;
687ec681f3Smrg
697ec681f3Smrg   unsigned num_ray_hit_attrib_vars = 0;
707ec681f3Smrg   nir_foreach_variable_with_modes(var, shader, nir_var_ray_hit_attrib)
717ec681f3Smrg      num_ray_hit_attrib_vars++;
727ec681f3Smrg
737ec681f3Smrg   /* At most one payload is allowed because it's an input.  Technically, this
747ec681f3Smrg    * is also true for hit attribute variables.  However, after we inline an
757ec681f3Smrg    * any-hit shader into an intersection shader, we can end up with multiple
767ec681f3Smrg    * hit attribute variables.  They'll end up mapping to a cast from the same
777ec681f3Smrg    * base pointer so this is fine.
787ec681f3Smrg    */
797ec681f3Smrg   assert(num_shader_call_vars <= 1);
807ec681f3Smrg
817ec681f3Smrg   nir_builder b;
827ec681f3Smrg   nir_builder_init(&b, impl);
837ec681f3Smrg
847ec681f3Smrg   b.cursor = nir_before_cf_list(&impl->body);
857ec681f3Smrg   nir_ssa_def *call_data_addr = NULL;
867ec681f3Smrg   if (num_shader_call_vars > 0) {
877ec681f3Smrg      assert(shader->scratch_size >= BRW_BTD_STACK_CALLEE_DATA_SIZE);
887ec681f3Smrg      call_data_addr =
897ec681f3Smrg         brw_nir_rt_load_scratch(&b, BRW_BTD_STACK_CALL_DATA_PTR_OFFSET, 8,
907ec681f3Smrg                                 1, 64);
917ec681f3Smrg      progress = true;
927ec681f3Smrg   }
937ec681f3Smrg
947ec681f3Smrg   gl_shader_stage stage = shader->info.stage;
957ec681f3Smrg   nir_ssa_def *hit_attrib_addr = NULL;
967ec681f3Smrg   if (num_ray_hit_attrib_vars > 0) {
977ec681f3Smrg      assert(stage == MESA_SHADER_ANY_HIT ||
987ec681f3Smrg             stage == MESA_SHADER_CLOSEST_HIT ||
997ec681f3Smrg             stage == MESA_SHADER_INTERSECTION);
1007ec681f3Smrg      nir_ssa_def *hit_addr =
1017ec681f3Smrg         brw_nir_rt_mem_hit_addr(&b, stage == MESA_SHADER_CLOSEST_HIT);
1027ec681f3Smrg      /* The vec2 barycentrics are in 2nd and 3rd dwords of MemHit */
1037ec681f3Smrg      nir_ssa_def *bary_addr = nir_iadd_imm(&b, hit_addr, 4);
1047ec681f3Smrg      hit_attrib_addr = nir_bcsel(&b, nir_load_leaf_procedural_intel(&b),
1057ec681f3Smrg                                      brw_nir_rt_hit_attrib_data_addr(&b),
1067ec681f3Smrg                                      bary_addr);
1077ec681f3Smrg      progress = true;
1087ec681f3Smrg   }
1097ec681f3Smrg
1107ec681f3Smrg   nir_foreach_block(block, impl) {
1117ec681f3Smrg      nir_foreach_instr_safe(instr, block) {
1127ec681f3Smrg         if (instr->type != nir_instr_type_deref)
1137ec681f3Smrg            continue;
1147ec681f3Smrg
1157ec681f3Smrg         nir_deref_instr *deref = nir_instr_as_deref(instr);
1167ec681f3Smrg         if (nir_deref_mode_is(deref, nir_var_shader_call_data)) {
1177ec681f3Smrg            deref->modes = nir_var_function_temp;
1187ec681f3Smrg            if (deref->deref_type == nir_deref_type_var) {
1197ec681f3Smrg               b.cursor = nir_before_instr(&deref->instr);
1207ec681f3Smrg               nir_deref_instr *cast =
1217ec681f3Smrg                  nir_build_deref_cast(&b, call_data_addr,
1227ec681f3Smrg                                       nir_var_function_temp,
1237ec681f3Smrg                                       deref->var->type, 0);
1247ec681f3Smrg               nir_ssa_def_rewrite_uses(&deref->dest.ssa,
1257ec681f3Smrg                                        &cast->dest.ssa);
1267ec681f3Smrg               nir_instr_remove(&deref->instr);
1277ec681f3Smrg               progress = true;
1287ec681f3Smrg            }
1297ec681f3Smrg         } else if (nir_deref_mode_is(deref, nir_var_ray_hit_attrib)) {
1307ec681f3Smrg            deref->modes = nir_var_function_temp;
1317ec681f3Smrg            if (deref->deref_type == nir_deref_type_var) {
1327ec681f3Smrg               b.cursor = nir_before_instr(&deref->instr);
1337ec681f3Smrg               nir_deref_instr *cast =
1347ec681f3Smrg                  nir_build_deref_cast(&b, hit_attrib_addr,
1357ec681f3Smrg                                       nir_var_function_temp,
1367ec681f3Smrg                                       deref->type, 0);
1377ec681f3Smrg               nir_ssa_def_rewrite_uses(&deref->dest.ssa,
1387ec681f3Smrg                                        &cast->dest.ssa);
1397ec681f3Smrg               nir_instr_remove(&deref->instr);
1407ec681f3Smrg               progress = true;
1417ec681f3Smrg            }
1427ec681f3Smrg         }
1437ec681f3Smrg
1447ec681f3Smrg         /* We're going to lower all function_temp memory to scratch using
1457ec681f3Smrg          * 64-bit addresses.  We need to resize all our derefs first or else
1467ec681f3Smrg          * nir_lower_explicit_io will have a fit.
1477ec681f3Smrg          */
1487ec681f3Smrg         if (nir_deref_mode_is(deref, nir_var_function_temp) &&
1497ec681f3Smrg             resize_deref(&b, deref, 1, 64))
1507ec681f3Smrg            progress = true;
1517ec681f3Smrg      }
1527ec681f3Smrg   }
1537ec681f3Smrg
1547ec681f3Smrg   if (progress) {
1557ec681f3Smrg      nir_metadata_preserve(impl, nir_metadata_block_index |
1567ec681f3Smrg                                  nir_metadata_dominance);
1577ec681f3Smrg   } else {
1587ec681f3Smrg      nir_metadata_preserve(impl, nir_metadata_all);
1597ec681f3Smrg   }
1607ec681f3Smrg
1617ec681f3Smrg   return progress;
1627ec681f3Smrg}
1637ec681f3Smrg
1647ec681f3Smrg/** Lowers ray-tracing shader I/O and scratch access
1657ec681f3Smrg *
1667ec681f3Smrg * SPV_KHR_ray_tracing adds three new types of I/O, each of which need their
1677ec681f3Smrg * own bit of special care:
1687ec681f3Smrg *
1697ec681f3Smrg *  - Shader payload data:  This is represented by the IncomingCallableData
1707ec681f3Smrg *    and IncomingRayPayload storage classes which are both represented by
1717ec681f3Smrg *    nir_var_call_data in NIR.  There is at most one of these per-shader and
1727ec681f3Smrg *    they contain payload data passed down the stack from the parent shader
1737ec681f3Smrg *    when it calls executeCallable() or traceRay().  In our implementation,
1747ec681f3Smrg *    the actual storage lives in the calling shader's scratch space and we're
1757ec681f3Smrg *    passed a pointer to it.
1767ec681f3Smrg *
1777ec681f3Smrg *  - Hit attribute data:  This is represented by the HitAttribute storage
1787ec681f3Smrg *    class in SPIR-V and nir_var_ray_hit_attrib in NIR.  For triangle
1797ec681f3Smrg *    geometry, it's supposed to contain two floats which are the barycentric
1807ec681f3Smrg *    coordinates.  For AABS/procedural geometry, it contains the hit data
1817ec681f3Smrg *    written out by the intersection shader.  In our implementation, it's a
1827ec681f3Smrg *    64-bit pointer which points either to the u/v area of the relevant
1837ec681f3Smrg *    MemHit data structure or the space right after the HW ray stack entry.
1847ec681f3Smrg *
1857ec681f3Smrg *  - Shader record buffer data:  This allows read-only access to the data
1867ec681f3Smrg *    stored in the SBT right after the bindless shader handles.  It's
1877ec681f3Smrg *    effectively a UBO with a magic address.  Coming out of spirv_to_nir,
1887ec681f3Smrg *    we get a nir_intrinsic_load_shader_record_ptr which is cast to a
1897ec681f3Smrg *    nir_var_mem_global deref and all access happens through that.  The
1907ec681f3Smrg *    shader_record_ptr system value is handled in brw_nir_lower_rt_intrinsics
1917ec681f3Smrg *    and we assume nir_lower_explicit_io is called elsewhere thanks to
1927ec681f3Smrg *    VK_KHR_buffer_device_address so there's really nothing to do here.
1937ec681f3Smrg *
1947ec681f3Smrg * We also handle lowering any remaining function_temp variables to scratch at
1957ec681f3Smrg * this point.  This gets rid of any remaining arrays and also takes care of
1967ec681f3Smrg * the sending side of ray payloads where we pass pointers to a function_temp
1977ec681f3Smrg * variable down the call stack.
1987ec681f3Smrg */
1997ec681f3Smrgstatic void
2007ec681f3Smrglower_rt_io_and_scratch(nir_shader *nir)
2017ec681f3Smrg{
2027ec681f3Smrg   /* First, we to ensure all the I/O variables have explicit types.  Because
2037ec681f3Smrg    * these are shader-internal and don't come in from outside, they don't
2047ec681f3Smrg    * have an explicit memory layout and we have to assign them one.
2057ec681f3Smrg    */
2067ec681f3Smrg   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
2077ec681f3Smrg              nir_var_function_temp |
2087ec681f3Smrg              nir_var_shader_call_data |
2097ec681f3Smrg              nir_var_ray_hit_attrib,
2107ec681f3Smrg              glsl_get_natural_size_align_bytes);
2117ec681f3Smrg
2127ec681f3Smrg   /* Now patch any derefs to I/O vars */
2137ec681f3Smrg   NIR_PASS_V(nir, lower_rt_io_derefs);
2147ec681f3Smrg
2157ec681f3Smrg   /* Finally, lower any remaining function_temp, mem_constant, or
2167ec681f3Smrg    * ray_hit_attrib access to 64-bit global memory access.
2177ec681f3Smrg    */
2187ec681f3Smrg   NIR_PASS_V(nir, nir_lower_explicit_io,
2197ec681f3Smrg              nir_var_function_temp |
2207ec681f3Smrg              nir_var_mem_constant |
2217ec681f3Smrg              nir_var_ray_hit_attrib,
2227ec681f3Smrg              nir_address_format_64bit_global);
2237ec681f3Smrg}
2247ec681f3Smrg
2257ec681f3Smrgstatic void
2267ec681f3Smrgbuild_terminate_ray(nir_builder *b)
2277ec681f3Smrg{
2287ec681f3Smrg   nir_ssa_def *skip_closest_hit =
2297ec681f3Smrg      nir_i2b(b, nir_iand_imm(b, nir_load_ray_flags(b),
2307ec681f3Smrg                              BRW_RT_RAY_FLAG_SKIP_CLOSEST_HIT_SHADER));
2317ec681f3Smrg   nir_push_if(b, skip_closest_hit);
2327ec681f3Smrg   {
2337ec681f3Smrg      /* The shader that calls traceRay() is unable to access any ray hit
2347ec681f3Smrg       * information except for that which is explicitly written into the ray
2357ec681f3Smrg       * payload by shaders invoked during the trace.  If there's no closest-
2367ec681f3Smrg       * hit shader, then accepting the hit has no observable effect; it's
2377ec681f3Smrg       * just extra memory traffic for no reason.
2387ec681f3Smrg       */
2397ec681f3Smrg      brw_nir_btd_return(b);
2407ec681f3Smrg      nir_jump(b, nir_jump_halt);
2417ec681f3Smrg   }
2427ec681f3Smrg   nir_push_else(b, NULL);
2437ec681f3Smrg   {
2447ec681f3Smrg      /* The closest hit shader is in the same shader group as the any-hit
2457ec681f3Smrg       * shader that we're currently in.  We can get the address for its SBT
2467ec681f3Smrg       * handle by looking at the shader record pointer and subtracting the
2477ec681f3Smrg       * size of a SBT handle.  The BINDLESS_SHADER_RECORD for a closest hit
2487ec681f3Smrg       * shader is the first one in the SBT handle.
2497ec681f3Smrg       */
2507ec681f3Smrg      nir_ssa_def *closest_hit =
2517ec681f3Smrg         nir_iadd_imm(b, nir_load_shader_record_ptr(b),
2527ec681f3Smrg                        -BRW_RT_SBT_HANDLE_SIZE);
2537ec681f3Smrg
2547ec681f3Smrg      brw_nir_rt_commit_hit(b);
2557ec681f3Smrg      brw_nir_btd_spawn(b, closest_hit);
2567ec681f3Smrg      nir_jump(b, nir_jump_halt);
2577ec681f3Smrg   }
2587ec681f3Smrg   nir_pop_if(b, NULL);
2597ec681f3Smrg}
2607ec681f3Smrg
2617ec681f3Smrg/** Lowers away ray walk intrinsics
2627ec681f3Smrg *
2637ec681f3Smrg * This lowers terminate_ray, ignore_ray_intersection, and the NIR-specific
2647ec681f3Smrg * accept_ray_intersection intrinsics to the appropriate Intel-specific
2657ec681f3Smrg * intrinsics.
2667ec681f3Smrg */
2677ec681f3Smrgstatic bool
2687ec681f3Smrglower_ray_walk_intrinsics(nir_shader *shader,
2697ec681f3Smrg                          const struct intel_device_info *devinfo)
2707ec681f3Smrg{
2717ec681f3Smrg   assert(shader->info.stage == MESA_SHADER_ANY_HIT ||
2727ec681f3Smrg          shader->info.stage == MESA_SHADER_INTERSECTION);
2737ec681f3Smrg
2747ec681f3Smrg   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
2757ec681f3Smrg
2767ec681f3Smrg   nir_builder b;
2777ec681f3Smrg   nir_builder_init(&b, impl);
2787ec681f3Smrg
2797ec681f3Smrg   bool progress = false;
2807ec681f3Smrg   nir_foreach_block_safe(block, impl) {
2817ec681f3Smrg      nir_foreach_instr_safe(instr, block) {
2827ec681f3Smrg         if (instr->type != nir_instr_type_intrinsic)
2837ec681f3Smrg            continue;
2847ec681f3Smrg
2857ec681f3Smrg         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
2867ec681f3Smrg
2877ec681f3Smrg         switch (intrin->intrinsic) {
2887ec681f3Smrg         case nir_intrinsic_ignore_ray_intersection: {
2897ec681f3Smrg            b.cursor = nir_instr_remove(&intrin->instr);
2907ec681f3Smrg
2917ec681f3Smrg            /* We put the newly emitted code inside a dummy if because it's
2927ec681f3Smrg             * going to contain a jump instruction and we don't want to deal
2937ec681f3Smrg             * with that mess here.  It'll get dealt with by our control-flow
2947ec681f3Smrg             * optimization passes.
2957ec681f3Smrg             */
2967ec681f3Smrg            nir_push_if(&b, nir_imm_true(&b));
2977ec681f3Smrg            nir_trace_ray_continue_intel(&b);
2987ec681f3Smrg            nir_jump(&b, nir_jump_halt);
2997ec681f3Smrg            nir_pop_if(&b, NULL);
3007ec681f3Smrg            progress = true;
3017ec681f3Smrg            break;
3027ec681f3Smrg         }
3037ec681f3Smrg
3047ec681f3Smrg         case nir_intrinsic_accept_ray_intersection: {
3057ec681f3Smrg            b.cursor = nir_instr_remove(&intrin->instr);
3067ec681f3Smrg
3077ec681f3Smrg            nir_ssa_def *terminate =
3087ec681f3Smrg               nir_i2b(&b, nir_iand_imm(&b, nir_load_ray_flags(&b),
3097ec681f3Smrg                                        BRW_RT_RAY_FLAG_TERMINATE_ON_FIRST_HIT));
3107ec681f3Smrg            nir_push_if(&b, terminate);
3117ec681f3Smrg            {
3127ec681f3Smrg               build_terminate_ray(&b);
3137ec681f3Smrg            }
3147ec681f3Smrg            nir_push_else(&b, NULL);
3157ec681f3Smrg            {
3167ec681f3Smrg               nir_trace_ray_commit_intel(&b);
3177ec681f3Smrg               nir_jump(&b, nir_jump_halt);
3187ec681f3Smrg            }
3197ec681f3Smrg            nir_pop_if(&b, NULL);
3207ec681f3Smrg            progress = true;
3217ec681f3Smrg            break;
3227ec681f3Smrg         }
3237ec681f3Smrg
3247ec681f3Smrg         case nir_intrinsic_terminate_ray: {
3257ec681f3Smrg            b.cursor = nir_instr_remove(&intrin->instr);
3267ec681f3Smrg            build_terminate_ray(&b);
3277ec681f3Smrg            progress = true;
3287ec681f3Smrg            break;
3297ec681f3Smrg         }
3307ec681f3Smrg
3317ec681f3Smrg         default:
3327ec681f3Smrg            break;
3337ec681f3Smrg         }
3347ec681f3Smrg      }
3357ec681f3Smrg   }
3367ec681f3Smrg
3377ec681f3Smrg   if (progress) {
3387ec681f3Smrg      nir_metadata_preserve(impl, nir_metadata_none);
3397ec681f3Smrg   } else {
3407ec681f3Smrg      nir_metadata_preserve(impl, nir_metadata_all);
3417ec681f3Smrg   }
3427ec681f3Smrg
3437ec681f3Smrg   return progress;
3447ec681f3Smrg}
3457ec681f3Smrg
3467ec681f3Smrgvoid
3477ec681f3Smrgbrw_nir_lower_raygen(nir_shader *nir)
3487ec681f3Smrg{
3497ec681f3Smrg   assert(nir->info.stage == MESA_SHADER_RAYGEN);
3507ec681f3Smrg   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
3517ec681f3Smrg   lower_rt_io_and_scratch(nir);
3527ec681f3Smrg}
3537ec681f3Smrg
3547ec681f3Smrgvoid
3557ec681f3Smrgbrw_nir_lower_any_hit(nir_shader *nir, const struct intel_device_info *devinfo)
3567ec681f3Smrg{
3577ec681f3Smrg   assert(nir->info.stage == MESA_SHADER_ANY_HIT);
3587ec681f3Smrg   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
3597ec681f3Smrg   NIR_PASS_V(nir, lower_ray_walk_intrinsics, devinfo);
3607ec681f3Smrg   lower_rt_io_and_scratch(nir);
3617ec681f3Smrg}
3627ec681f3Smrg
3637ec681f3Smrgvoid
3647ec681f3Smrgbrw_nir_lower_closest_hit(nir_shader *nir)
3657ec681f3Smrg{
3667ec681f3Smrg   assert(nir->info.stage == MESA_SHADER_CLOSEST_HIT);
3677ec681f3Smrg   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
3687ec681f3Smrg   lower_rt_io_and_scratch(nir);
3697ec681f3Smrg}
3707ec681f3Smrg
3717ec681f3Smrgvoid
3727ec681f3Smrgbrw_nir_lower_miss(nir_shader *nir)
3737ec681f3Smrg{
3747ec681f3Smrg   assert(nir->info.stage == MESA_SHADER_MISS);
3757ec681f3Smrg   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
3767ec681f3Smrg   lower_rt_io_and_scratch(nir);
3777ec681f3Smrg}
3787ec681f3Smrg
3797ec681f3Smrgvoid
3807ec681f3Smrgbrw_nir_lower_callable(nir_shader *nir)
3817ec681f3Smrg{
3827ec681f3Smrg   assert(nir->info.stage == MESA_SHADER_CALLABLE);
3837ec681f3Smrg   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
3847ec681f3Smrg   lower_rt_io_and_scratch(nir);
3857ec681f3Smrg}
3867ec681f3Smrg
3877ec681f3Smrgvoid
3887ec681f3Smrgbrw_nir_lower_combined_intersection_any_hit(nir_shader *intersection,
3897ec681f3Smrg                                            const nir_shader *any_hit,
3907ec681f3Smrg                                            const struct intel_device_info *devinfo)
3917ec681f3Smrg{
3927ec681f3Smrg   assert(intersection->info.stage == MESA_SHADER_INTERSECTION);
3937ec681f3Smrg   assert(any_hit == NULL || any_hit->info.stage == MESA_SHADER_ANY_HIT);
3947ec681f3Smrg   NIR_PASS_V(intersection, brw_nir_lower_shader_returns);
3957ec681f3Smrg   NIR_PASS_V(intersection, brw_nir_lower_intersection_shader,
3967ec681f3Smrg              any_hit, devinfo);
3977ec681f3Smrg   NIR_PASS_V(intersection, lower_ray_walk_intrinsics, devinfo);
3987ec681f3Smrg   lower_rt_io_and_scratch(intersection);
3997ec681f3Smrg}
4007ec681f3Smrg
4017ec681f3Smrgstatic nir_ssa_def *
4027ec681f3Smrgbuild_load_uniform(nir_builder *b, unsigned offset,
4037ec681f3Smrg                   unsigned num_components, unsigned bit_size)
4047ec681f3Smrg{
4057ec681f3Smrg   return nir_load_uniform(b, num_components, bit_size, nir_imm_int(b, 0),
4067ec681f3Smrg                           .base = offset,
4077ec681f3Smrg                           .range = num_components * bit_size / 8);
4087ec681f3Smrg}
4097ec681f3Smrg
4107ec681f3Smrg#define load_trampoline_param(b, name, num_components, bit_size) \
4117ec681f3Smrg   build_load_uniform((b), offsetof(struct brw_rt_raygen_trampoline_params, name), \
4127ec681f3Smrg                      (num_components), (bit_size))
4137ec681f3Smrg
4147ec681f3Smrgnir_shader *
4157ec681f3Smrgbrw_nir_create_raygen_trampoline(const struct brw_compiler *compiler,
4167ec681f3Smrg                                 void *mem_ctx)
4177ec681f3Smrg{
4187ec681f3Smrg   const struct intel_device_info *devinfo = compiler->devinfo;
4197ec681f3Smrg   const nir_shader_compiler_options *nir_options =
4207ec681f3Smrg      compiler->glsl_compiler_options[MESA_SHADER_COMPUTE].NirOptions;
4217ec681f3Smrg
4227ec681f3Smrg   STATIC_ASSERT(sizeof(struct brw_rt_raygen_trampoline_params) == 32);
4237ec681f3Smrg
4247ec681f3Smrg   nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_COMPUTE,
4257ec681f3Smrg                                                  nir_options,
4267ec681f3Smrg                                                  "RT Ray-Gen Trampoline");
4277ec681f3Smrg   ralloc_steal(mem_ctx, b.shader);
4287ec681f3Smrg
4297ec681f3Smrg   b.shader->info.workgroup_size_variable = true;
4307ec681f3Smrg
4317ec681f3Smrg   /* The RT global data and raygen BINDLESS_SHADER_RECORD addresses are
4327ec681f3Smrg    * passed in as push constants in the first register.  We deal with the
4337ec681f3Smrg    * raygen BSR address here; the global data we'll deal with later.
4347ec681f3Smrg    */
4357ec681f3Smrg   b.shader->num_uniforms = 32;
4367ec681f3Smrg   nir_ssa_def *raygen_bsr_addr =
4377ec681f3Smrg      load_trampoline_param(&b, raygen_bsr_addr, 1, 64);
4387ec681f3Smrg   nir_ssa_def *local_shift =
4397ec681f3Smrg      nir_u2u32(&b, load_trampoline_param(&b, local_group_size_log2, 3, 8));
4407ec681f3Smrg
4417ec681f3Smrg   nir_ssa_def *global_id = nir_load_workgroup_id(&b, 32);
4427ec681f3Smrg   nir_ssa_def *simd_channel = nir_load_subgroup_invocation(&b);
4437ec681f3Smrg   nir_ssa_def *local_x =
4447ec681f3Smrg      nir_ubfe(&b, simd_channel, nir_imm_int(&b, 0),
4457ec681f3Smrg                  nir_channel(&b, local_shift, 0));
4467ec681f3Smrg   nir_ssa_def *local_y =
4477ec681f3Smrg      nir_ubfe(&b, simd_channel, nir_channel(&b, local_shift, 0),
4487ec681f3Smrg                  nir_channel(&b, local_shift, 1));
4497ec681f3Smrg   nir_ssa_def *local_z =
4507ec681f3Smrg      nir_ubfe(&b, simd_channel,
4517ec681f3Smrg                  nir_iadd(&b, nir_channel(&b, local_shift, 0),
4527ec681f3Smrg                              nir_channel(&b, local_shift, 1)),
4537ec681f3Smrg                  nir_channel(&b, local_shift, 2));
4547ec681f3Smrg   nir_ssa_def *launch_id =
4557ec681f3Smrg      nir_iadd(&b, nir_ishl(&b, global_id, local_shift),
4567ec681f3Smrg                  nir_vec3(&b, local_x, local_y, local_z));
4577ec681f3Smrg
4587ec681f3Smrg   nir_ssa_def *launch_size = nir_load_ray_launch_size(&b);
4597ec681f3Smrg   nir_push_if(&b, nir_ball(&b, nir_ult(&b, launch_id, launch_size)));
4607ec681f3Smrg   {
4617ec681f3Smrg      nir_store_global(&b, brw_nir_rt_sw_hotzone_addr(&b, devinfo), 16,
4627ec681f3Smrg                       nir_vec4(&b, nir_imm_int(&b, 0), /* Stack ptr */
4637ec681f3Smrg                                    nir_channel(&b, launch_id, 0),
4647ec681f3Smrg                                    nir_channel(&b, launch_id, 1),
4657ec681f3Smrg                                    nir_channel(&b, launch_id, 2)),
4667ec681f3Smrg                       0xf /* write mask */);
4677ec681f3Smrg
4687ec681f3Smrg      brw_nir_btd_spawn(&b, raygen_bsr_addr);
4697ec681f3Smrg   }
4707ec681f3Smrg   nir_push_else(&b, NULL);
4717ec681f3Smrg   {
4727ec681f3Smrg      /* Even though these invocations aren't being used for anything, the
4737ec681f3Smrg       * hardware allocated stack IDs for them.  They need to retire them.
4747ec681f3Smrg       */
4757ec681f3Smrg      brw_nir_btd_retire(&b);
4767ec681f3Smrg   }
4777ec681f3Smrg   nir_pop_if(&b, NULL);
4787ec681f3Smrg
4797ec681f3Smrg   nir_shader *nir = b.shader;
4807ec681f3Smrg   nir->info.name = ralloc_strdup(nir, "RT: TraceRay trampoline");
4817ec681f3Smrg   nir_validate_shader(nir, "in brw_nir_create_raygen_trampoline");
4827ec681f3Smrg   brw_preprocess_nir(compiler, nir, NULL);
4837ec681f3Smrg
4847ec681f3Smrg   NIR_PASS_V(nir, brw_nir_lower_rt_intrinsics, devinfo);
4857ec681f3Smrg
4867ec681f3Smrg   /* brw_nir_lower_rt_intrinsics will leave us with a btd_global_arg_addr
4877ec681f3Smrg    * intrinsic which doesn't exist in compute shaders.  We also created one
4887ec681f3Smrg    * above when we generated the BTD spawn intrinsic.  Now we go through and
4897ec681f3Smrg    * replace them with a uniform load.
4907ec681f3Smrg    */
4917ec681f3Smrg   nir_foreach_block(block, b.impl) {
4927ec681f3Smrg      nir_foreach_instr_safe(instr, block) {
4937ec681f3Smrg         if (instr->type != nir_instr_type_intrinsic)
4947ec681f3Smrg            continue;
4957ec681f3Smrg
4967ec681f3Smrg         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
4977ec681f3Smrg         if (intrin->intrinsic != nir_intrinsic_load_btd_global_arg_addr_intel)
4987ec681f3Smrg            continue;
4997ec681f3Smrg
5007ec681f3Smrg         b.cursor = nir_before_instr(&intrin->instr);
5017ec681f3Smrg         nir_ssa_def *global_arg_addr =
5027ec681f3Smrg            load_trampoline_param(&b, rt_disp_globals_addr, 1, 64);
5037ec681f3Smrg         assert(intrin->dest.is_ssa);
5047ec681f3Smrg         nir_ssa_def_rewrite_uses(&intrin->dest.ssa,
5057ec681f3Smrg                                  global_arg_addr);
5067ec681f3Smrg         nir_instr_remove(instr);
5077ec681f3Smrg      }
5087ec681f3Smrg   }
5097ec681f3Smrg
5107ec681f3Smrg   NIR_PASS_V(nir, brw_nir_lower_cs_intrinsics);
5117ec681f3Smrg
5127ec681f3Smrg   brw_nir_optimize(nir, compiler, true, false);
5137ec681f3Smrg
5147ec681f3Smrg   return nir;
5157ec681f3Smrg}
516