/*
 * Copyright © 2020 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#ifndef BRW_NIR_RT_BUILDER_H
#define BRW_NIR_RT_BUILDER_H

#include "brw_rt.h"
#include "nir_builder.h"

/* We have our own load/store scratch helpers because they emit a global
 * memory read or write based on the scratch_base_ptr system value rather
 * than a load/store_scratch intrinsic.
 */
static inline nir_ssa_def *
brw_nir_rt_load_scratch(nir_builder *b, uint32_t offset, unsigned align,
                        unsigned num_components, unsigned bit_size)
{
   nir_ssa_def *addr =
      nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 64, 1), offset);
   return nir_load_global(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
                          num_components, bit_size);
}

static inline void
brw_nir_rt_store_scratch(nir_builder *b, uint32_t offset, unsigned align,
                         nir_ssa_def *value, nir_component_mask_t write_mask)
{
   nir_ssa_def *addr =
      nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 64, 1), offset);
   nir_store_global(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
                    value, write_mask);
}
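
/* Example (illustrative): spill a one-component, 64-bit value to the RT
 * scratch area and read it back.  The offset, alignment, and value here are
 * made up for the sake of the example.
 *
 *    brw_nir_rt_store_scratch(b, 8, 8, addr64, 0x1);
 *    nir_ssa_def *reload = brw_nir_rt_load_scratch(b, 8, 8, 1, 64);
 */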

static inline void
brw_nir_btd_spawn(nir_builder *b, nir_ssa_def *record_addr)
{
   nir_btd_spawn_intel(b, nir_load_btd_global_arg_addr_intel(b), record_addr);
}

static inline void
brw_nir_btd_retire(nir_builder *b)
{
   nir_btd_retire_intel(b);
}

/** This is a pseudo-op which does a bindless return
 *
 * It loads the return address from the stack and calls btd_spawn to spawn the
 * resume shader.
 */
static inline void
brw_nir_btd_return(struct nir_builder *b)
{
   assert(b->shader->scratch_size == BRW_BTD_STACK_CALLEE_DATA_SIZE);
   nir_ssa_def *resume_addr =
      brw_nir_rt_load_scratch(b, BRW_BTD_STACK_RESUME_BSR_ADDR_OFFSET,
                              8 /* align */, 1, 64);
   brw_nir_btd_spawn(b, resume_addr);
}
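
/* Sketch (illustrative) of the convention this implies: before a shader
 * suspends, it spills the resume shader record address at the canonical
 * offset so brw_nir_btd_return() can find it later:
 *
 *    brw_nir_rt_store_scratch(b, BRW_BTD_STACK_RESUME_BSR_ADDR_OFFSET,
 *                             8, resume_record_addr, 0x1);
 */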

static inline void
assert_def_size(nir_ssa_def *def, unsigned num_components, unsigned bit_size)
{
   assert(def->num_components == num_components);
   assert(def->bit_size == bit_size);
}

static inline nir_ssa_def *
brw_nir_num_rt_stacks(nir_builder *b,
                      const struct intel_device_info *devinfo)
{
   return nir_imul_imm(b, nir_load_ray_num_dss_rt_stacks_intel(b),
                          intel_device_info_num_dual_subslices(devinfo));
}

static inline nir_ssa_def *
brw_nir_rt_stack_id(nir_builder *b)
{
   return nir_iadd(b, nir_umul_32x16(b, nir_load_ray_num_dss_rt_stacks_intel(b),
                                     nir_load_btd_dss_id_intel(b)),
                      nir_load_btd_stack_id_intel(b));
}
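
/* In scalar terms (illustrative), the globally unique stack ID computed
 * above is
 *
 *    stack_id = num_dss_rt_stacks * dss_id + btd_stack_id
 *
 * umul_32x16 only consumes the low 16 bits of its second source, which is
 * fine here because the DSS ID is a small value.
 */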

static inline nir_ssa_def *
brw_nir_rt_sw_hotzone_addr(nir_builder *b,
                           const struct intel_device_info *devinfo)
{
   nir_ssa_def *offset32 =
      nir_imul_imm(b, brw_nir_rt_stack_id(b), BRW_RT_SIZEOF_HOTZONE);

   offset32 = nir_iadd(b, offset32, nir_ineg(b,
      nir_imul_imm(b, brw_nir_num_rt_stacks(b, devinfo),
                      BRW_RT_SIZEOF_HOTZONE)));

   return nir_iadd(b, nir_load_ray_base_mem_addr_intel(b),
                      nir_i2i64(b, offset32));
}
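
/* Worked example (illustrative): hotzones live immediately below the ray
 * base address, so with 4096 total stacks and stack_id 5 this computes
 *
 *    base + (5 - 4096) * BRW_RT_SIZEOF_HOTZONE
 *
 * The 32-bit offset is negative, which is why it must be sign-extended
 * with nir_i2i64 rather than zero-extended.
 */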

static inline nir_ssa_def *
brw_nir_rt_ray_addr(nir_builder *b)
{
   /* From the BSpec "Address Computation for Memory Based Data Structures:
    * Ray and TraversalStack (Async Ray Tracing)":
    *
    *    stackBase = RTDispatchGlobals.rtMemBasePtr
    *              + (DSSID * RTDispatchGlobals.numDSSRTStacks + stackID)
    *              * RTDispatchGlobals.stackSizePerRay // 64B aligned
    *
    * We assume that we can calculate a 32-bit offset first and then add it
    * to the 64-bit base address at the end.
    */
   nir_ssa_def *offset32 =
      nir_imul(b, brw_nir_rt_stack_id(b),
                  nir_load_ray_hw_stack_size_intel(b));
   return nir_iadd(b, nir_load_ray_base_mem_addr_intel(b),
                      nir_u2u64(b, offset32));
}
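
/* Worked example (illustrative): with a stackSizePerRay of 1024B and a
 * stack_id of 3, this returns rtMemBasePtr + 3072.  Unlike the hotzone
 * case above, the offset is always non-negative, so nir_u2u64
 * (zero-extend) is the right conversion.
 */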

static inline nir_ssa_def *
brw_nir_rt_mem_hit_addr(nir_builder *b, bool committed)
{
   return nir_iadd_imm(b, brw_nir_rt_ray_addr(b),
                          committed ? 0 : BRW_RT_SIZEOF_HIT_INFO);
}

static inline nir_ssa_def *
brw_nir_rt_hit_attrib_data_addr(nir_builder *b)
{
   return nir_iadd_imm(b, brw_nir_rt_ray_addr(b),
                          BRW_RT_OFFSETOF_HIT_ATTRIB_DATA);
}

static inline nir_ssa_def *
brw_nir_rt_mem_ray_addr(nir_builder *b,
                        enum brw_rt_bvh_level bvh_level)
{
   /* From the BSpec "Address Computation for Memory Based Data Structures:
    * Ray and TraversalStack (Async Ray Tracing)":
    *
    *    rayBase = stackBase + sizeof(HitInfo) * 2 // 64B aligned
    *    rayPtr  = rayBase + bvhLevel * sizeof(Ray); // 64B aligned
    *
    * In Vulkan, we always have exactly two levels of BVH: World and Object.
    */
   uint32_t offset = BRW_RT_SIZEOF_HIT_INFO * 2 +
                     bvh_level * BRW_RT_SIZEOF_RAY;
   return nir_iadd_imm(b, brw_nir_rt_ray_addr(b), offset);
}
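
/* The resulting per-stack layout (illustrative recap of the three helpers
 * above):
 *
 *    stackBase + 0                          committed HitInfo
 *    stackBase + BRW_RT_SIZEOF_HIT_INFO     potential HitInfo
 *    stackBase + 2 * BRW_RT_SIZEOF_HIT_INFO Ray, world-level BVH
 *    stackBase + 2 * BRW_RT_SIZEOF_HIT_INFO
 *              + BRW_RT_SIZEOF_RAY          Ray, object-level BVH
 */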

static inline nir_ssa_def *
brw_nir_rt_sw_stack_addr(nir_builder *b,
                         const struct intel_device_info *devinfo)
{
   nir_ssa_def *addr = nir_load_ray_base_mem_addr_intel(b);

   nir_ssa_def *offset32 = nir_imul(b, brw_nir_num_rt_stacks(b, devinfo),
                                       nir_load_ray_hw_stack_size_intel(b));
   addr = nir_iadd(b, addr, nir_u2u64(b, offset32));

   return nir_iadd(b, addr,
      nir_imul(b, nir_u2u64(b, brw_nir_rt_stack_id(b)),
                  nir_u2u64(b, nir_load_ray_sw_stack_size_intel(b))));
}

static inline nir_ssa_def *
nir_unpack_64_4x16_split_z(nir_builder *b, nir_ssa_def *val)
{
   return nir_unpack_32_2x16_split_x(b, nir_unpack_64_2x32_split_y(b, val));
}

struct brw_nir_rt_globals_defs {
   nir_ssa_def *base_mem_addr;
   nir_ssa_def *call_stack_handler_addr;
   nir_ssa_def *hw_stack_size;
   nir_ssa_def *num_dss_rt_stacks;
   nir_ssa_def *hit_sbt_addr;
   nir_ssa_def *hit_sbt_stride;
   nir_ssa_def *miss_sbt_addr;
   nir_ssa_def *miss_sbt_stride;
   nir_ssa_def *sw_stack_size;
   nir_ssa_def *launch_size;
   nir_ssa_def *call_sbt_addr;
   nir_ssa_def *call_sbt_stride;
   nir_ssa_def *resume_sbt_addr;
};

static inline void
brw_nir_rt_load_globals(nir_builder *b,
                        struct brw_nir_rt_globals_defs *defs)
{
   nir_ssa_def *addr = nir_load_btd_global_arg_addr_intel(b);

   nir_ssa_def *data;
   data = nir_load_global_const_block_intel(b, 16, addr, nir_imm_true(b));
   defs->base_mem_addr = nir_pack_64_2x32(b, nir_channels(b, data, 0x3));

   defs->call_stack_handler_addr =
      nir_pack_64_2x32(b, nir_channels(b, data, 0x3 << 2));

   defs->hw_stack_size = nir_channel(b, data, 4);
   defs->num_dss_rt_stacks = nir_iand_imm(b, nir_channel(b, data, 5), 0xffff);
   defs->hit_sbt_addr =
      nir_pack_64_2x32_split(b, nir_channel(b, data, 8),
                                nir_extract_i16(b, nir_channel(b, data, 9),
                                                   nir_imm_int(b, 0)));
   defs->hit_sbt_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 9));
   defs->miss_sbt_addr =
      nir_pack_64_2x32_split(b, nir_channel(b, data, 10),
                                nir_extract_i16(b, nir_channel(b, data, 11),
                                                   nir_imm_int(b, 0)));
   defs->miss_sbt_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 11));
   defs->sw_stack_size = nir_channel(b, data, 12);
   defs->launch_size = nir_channels(b, data, 0x7u << 13);

   data = nir_load_global_const_block_intel(b, 8, nir_iadd_imm(b, addr, 64),
                                                  nir_imm_true(b));
   defs->call_sbt_addr =
      nir_pack_64_2x32_split(b, nir_channel(b, data, 0),
                                nir_extract_i16(b, nir_channel(b, data, 1),
                                                   nir_imm_int(b, 0)));
   defs->call_sbt_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 1));

   defs->resume_sbt_addr =
      nir_pack_64_2x32(b, nir_channels(b, data, 0x3 << 2));
}
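
/* Example usage (illustrative): unpack RTDispatchGlobals once near the top
 * of a shader and use the fields from then on.
 *
 *    struct brw_nir_rt_globals_defs globals;
 *    brw_nir_rt_load_globals(b, &globals);
 *    nir_ssa_def *launch_size = globals.launch_size;
 *
 * Note the SBT strides come back as 16-bit values while the addresses are
 * 64-bit; callers need to widen the strides before doing address
 * arithmetic with them.
 */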

static inline nir_ssa_def *
brw_nir_rt_unpack_leaf_ptr(nir_builder *b, nir_ssa_def *vec2)
{
   /* Hit record leaf pointers are 42-bit and assumed to be in 64B chunks.
    * This leaves 22 bits at the top for other stuff.
    */
   nir_ssa_def *ptr64 = nir_imul_imm(b, nir_pack_64_2x32(b, vec2), 64);

   /* The top 16 bits (remember, we shifted by 6 already) contain garbage
    * that we need to get rid of.
    */
   nir_ssa_def *ptr_lo = nir_unpack_64_2x32_split_x(b, ptr64);
   nir_ssa_def *ptr_hi = nir_unpack_64_2x32_split_y(b, ptr64);
   ptr_hi = nir_extract_i16(b, ptr_hi, nir_imm_int(b, 0));
   return nir_pack_64_2x32_split(b, ptr_lo, ptr_hi);
}
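
/* Worked example (illustrative): a packed chunk index of 0x123 unpacks to
 * the byte address 0x123 * 64 = 0x48c0.  After the multiply, bits 48..63
 * hold whatever occupied the top 22 bits of the record; the extract_i16
 * above sign-extends the high dword from bit 15, replicating bit 47 of the
 * pointer to yield a canonical 48-bit address.
 */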

struct brw_nir_rt_mem_hit_defs {
   nir_ssa_def *t;
   nir_ssa_def *tri_bary; /**< Only valid for triangle geometry */
   nir_ssa_def *aabb_hit_kind; /**< Only valid for AABB geometry */
   nir_ssa_def *leaf_type;
   nir_ssa_def *prim_leaf_index;
   nir_ssa_def *front_face;
   nir_ssa_def *prim_leaf_ptr;
   nir_ssa_def *inst_leaf_ptr;
};

static inline void
brw_nir_rt_load_mem_hit(nir_builder *b,
                        struct brw_nir_rt_mem_hit_defs *defs,
                        bool committed)
{
   nir_ssa_def *hit_addr = brw_nir_rt_mem_hit_addr(b, committed);

   nir_ssa_def *data = nir_load_global(b, hit_addr, 16, 4, 32);
   defs->t = nir_channel(b, data, 0);
   defs->aabb_hit_kind = nir_channel(b, data, 1);
   defs->tri_bary = nir_channels(b, data, 0x6);
   nir_ssa_def *bitfield = nir_channel(b, data, 3);
   defs->leaf_type =
      nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 17), nir_imm_int(b, 3));
   defs->prim_leaf_index =
      nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 20), nir_imm_int(b, 4));
   defs->front_face = nir_i2b(b, nir_iand_imm(b, bitfield, 1 << 27));

   data = nir_load_global(b, nir_iadd_imm(b, hit_addr, 16), 16, 4, 32);
   defs->prim_leaf_ptr =
      brw_nir_rt_unpack_leaf_ptr(b, nir_channels(b, data, 0x3 << 0));
   defs->inst_leaf_ptr =
      brw_nir_rt_unpack_leaf_ptr(b, nir_channels(b, data, 0x3 << 2));
}
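
/* Example usage (illustrative): load the committed hit and pick out the
 * fields backing hit-T and facing built-ins.
 *
 *    struct brw_nir_rt_mem_hit_defs hit;
 *    brw_nir_rt_load_mem_hit(b, &hit, true);
 *    nir_ssa_def *hit_t = hit.t;
 *    nir_ssa_def *front = hit.front_face;
 */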

static inline void
brw_nir_memcpy_global(nir_builder *b,
                      nir_ssa_def *dst_addr, uint32_t dst_align,
                      nir_ssa_def *src_addr, uint32_t src_align,
                      uint32_t size)
{
   /* We're going to copy in 16B chunks */
   assert(size % 16 == 0);
   dst_align = MIN2(dst_align, 16);
   src_align = MIN2(src_align, 16);

   for (unsigned offset = 0; offset < size; offset += 16) {
      nir_ssa_def *data =
         nir_load_global(b, nir_iadd_imm(b, src_addr, offset), src_align,
                         4, 32);
      nir_store_global(b, nir_iadd_imm(b, dst_addr, offset), dst_align,
                       data, 0xf /* write_mask */);
   }
}
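
/* Example (illustrative): copy a 32B record between two 16B-aligned global
 * addresses.  A size that is not a multiple of 16 would trip the assert
 * above.
 *
 *    brw_nir_memcpy_global(b, dst_addr, 16, src_addr, 16, 32);
 */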

static inline void
brw_nir_rt_commit_hit(nir_builder *b)
{
   brw_nir_memcpy_global(b, brw_nir_rt_mem_hit_addr(b, true), 16,
                            brw_nir_rt_mem_hit_addr(b, false), 16,
                            BRW_RT_SIZEOF_HIT_INFO);
}

struct brw_nir_rt_mem_ray_defs {
   nir_ssa_def *orig;
   nir_ssa_def *dir;
   nir_ssa_def *t_near;
   nir_ssa_def *t_far;
   nir_ssa_def *root_node_ptr;
   nir_ssa_def *ray_flags;
   nir_ssa_def *hit_group_sr_base_ptr;
   nir_ssa_def *hit_group_sr_stride;
   nir_ssa_def *miss_sr_ptr;
   nir_ssa_def *shader_index_multiplier;
   nir_ssa_def *inst_leaf_ptr;
   nir_ssa_def *ray_mask;
};

static inline void
brw_nir_rt_store_mem_ray(nir_builder *b,
                         const struct brw_nir_rt_mem_ray_defs *defs,
                         enum brw_rt_bvh_level bvh_level)
{
   nir_ssa_def *ray_addr = brw_nir_rt_mem_ray_addr(b, bvh_level);

   assert_def_size(defs->orig, 3, 32);
   assert_def_size(defs->dir, 3, 32);
   nir_store_global(b, nir_iadd_imm(b, ray_addr, 0), 16,
      nir_vec4(b, nir_channel(b, defs->orig, 0),
                  nir_channel(b, defs->orig, 1),
                  nir_channel(b, defs->orig, 2),
                  nir_channel(b, defs->dir, 0)),
      ~0 /* write mask */);

   assert_def_size(defs->t_near, 1, 32);
   assert_def_size(defs->t_far, 1, 32);
   nir_store_global(b, nir_iadd_imm(b, ray_addr, 16), 16,
      nir_vec4(b, nir_channel(b, defs->dir, 1),
                  nir_channel(b, defs->dir, 2),
                  defs->t_near,
                  defs->t_far),
      ~0 /* write mask */);

   assert_def_size(defs->root_node_ptr, 1, 64);
   assert_def_size(defs->ray_flags, 1, 16);
   assert_def_size(defs->hit_group_sr_base_ptr, 1, 64);
   assert_def_size(defs->hit_group_sr_stride, 1, 16);
   nir_store_global(b, nir_iadd_imm(b, ray_addr, 32), 16,
      nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->root_node_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->root_node_ptr),
                     defs->ray_flags),
                  nir_unpack_64_2x32_split_x(b, defs->hit_group_sr_base_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->hit_group_sr_base_ptr),
                     defs->hit_group_sr_stride)),
      ~0 /* write mask */);

   /* leaf_ptr is optional */
   nir_ssa_def *inst_leaf_ptr;
   if (defs->inst_leaf_ptr) {
      inst_leaf_ptr = defs->inst_leaf_ptr;
   } else {
      inst_leaf_ptr = nir_imm_int64(b, 0);
   }

   assert_def_size(defs->miss_sr_ptr, 1, 64);
   assert_def_size(defs->shader_index_multiplier, 1, 32);
   assert_def_size(inst_leaf_ptr, 1, 64);
   assert_def_size(defs->ray_mask, 1, 32);
   nir_store_global(b, nir_iadd_imm(b, ray_addr, 48), 16,
      nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->miss_sr_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->miss_sr_ptr),
                     nir_unpack_32_2x16_split_x(b,
                        nir_ishl(b, defs->shader_index_multiplier,
                                    nir_imm_int(b, 8)))),
                  nir_unpack_64_2x32_split_x(b, inst_leaf_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, inst_leaf_ptr),
                     nir_unpack_32_2x16_split_x(b, defs->ray_mask))),
      ~0 /* write mask */);
}
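
/* Example usage (illustrative): fill out the defs struct with values of the
 * sizes the assert_def_size() checks above demand, then write a world-level
 * ray.  All the right-hand-side names are placeholders.
 *
 *    struct brw_nir_rt_mem_ray_defs ray = {
 *       .orig = ray_orig,                      // 3 components, 32-bit
 *       .dir = ray_dir,                        // 3 components, 32-bit
 *       .t_near = nir_imm_float(b, 0.0f),
 *       .t_far = t_max,
 *       .root_node_ptr = root_ptr,             // 1 component, 64-bit
 *       .ray_flags = flags,                    // 1 component, 16-bit
 *       .hit_group_sr_base_ptr = hit_sbt,
 *       .hit_group_sr_stride = hit_stride,
 *       .miss_sr_ptr = miss_record,
 *       .shader_index_multiplier = mult,
 *       .ray_mask = mask,
 *    };
 *    brw_nir_rt_store_mem_ray(b, &ray, BRW_RT_BVH_LEVEL_WORLD);
 *
 * inst_leaf_ptr may be left NULL for a world-level ray; the store falls
 * back to a zero pointer in that case.
 */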

static inline void
brw_nir_rt_load_mem_ray(nir_builder *b,
                        struct brw_nir_rt_mem_ray_defs *defs,
                        enum brw_rt_bvh_level bvh_level)
{
   nir_ssa_def *ray_addr = brw_nir_rt_mem_ray_addr(b, bvh_level);

   nir_ssa_def *data[4] = {
      nir_load_global(b, nir_iadd_imm(b, ray_addr,  0), 16, 4, 32),
      nir_load_global(b, nir_iadd_imm(b, ray_addr, 16), 16, 4, 32),
      nir_load_global(b, nir_iadd_imm(b, ray_addr, 32), 16, 4, 32),
      nir_load_global(b, nir_iadd_imm(b, ray_addr, 48), 16, 4, 32),
   };

   defs->orig = nir_channels(b, data[0], 0x7);
   defs->dir = nir_vec3(b, nir_channel(b, data[0], 3),
                           nir_channel(b, data[1], 0),
                           nir_channel(b, data[1], 1));
   defs->t_near = nir_channel(b, data[1], 2);
   defs->t_far = nir_channel(b, data[1], 3);
   defs->root_node_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[2], 0),
                                nir_extract_i16(b, nir_channel(b, data[2], 1),
                                                   nir_imm_int(b, 0)));
   defs->ray_flags =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[2], 1));
   defs->hit_group_sr_base_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[2], 2),
                                nir_extract_i16(b, nir_channel(b, data[2], 3),
                                                   nir_imm_int(b, 0)));
   defs->hit_group_sr_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[2], 3));
   defs->miss_sr_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[3], 0),
                                nir_extract_i16(b, nir_channel(b, data[3], 1),
                                                   nir_imm_int(b, 0)));
   defs->shader_index_multiplier =
      nir_ushr(b, nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 1)),
                  nir_imm_int(b, 8));
   defs->inst_leaf_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[3], 2),
                                nir_extract_i16(b, nir_channel(b, data[3], 3),
                                                   nir_imm_int(b, 0)));
   defs->ray_mask =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 3));
}

struct brw_nir_rt_bvh_instance_leaf_defs {
   nir_ssa_def *world_to_object[4];
   nir_ssa_def *instance_id;
   nir_ssa_def *instance_index;
   nir_ssa_def *object_to_world[4];
};

static inline void
brw_nir_rt_load_bvh_instance_leaf(nir_builder *b,
                                  struct brw_nir_rt_bvh_instance_leaf_defs *defs,
                                  nir_ssa_def *leaf_addr)
{
   /* We don't care about the first 16B of the leaf for now.  One day, we may
    * add code to decode it but none of that data is directly required for
    * implementing any ray-tracing built-ins.
    */

   defs->world_to_object[0] =
      nir_load_global(b, nir_iadd_imm(b, leaf_addr, 16), 4, 3, 32);
   defs->world_to_object[1] =
      nir_load_global(b, nir_iadd_imm(b, leaf_addr, 28), 4, 3, 32);
   defs->world_to_object[2] =
      nir_load_global(b, nir_iadd_imm(b, leaf_addr, 40), 4, 3, 32);
   /* The last columns of the two matrices are stored swapped, presumably
    * because that layout is easier or faster for the hardware to consume.
    */
   defs->object_to_world[3] =
      nir_load_global(b, nir_iadd_imm(b, leaf_addr, 52), 4, 3, 32);

   nir_ssa_def *data =
      nir_load_global(b, nir_iadd_imm(b, leaf_addr, 64), 4, 4, 32);
   defs->instance_id = nir_channel(b, data, 2);
   defs->instance_index = nir_channel(b, data, 3);

   defs->object_to_world[0] =
      nir_load_global(b, nir_iadd_imm(b, leaf_addr, 80), 4, 3, 32);
   defs->object_to_world[1] =
      nir_load_global(b, nir_iadd_imm(b, leaf_addr, 92), 4, 3, 32);
   defs->object_to_world[2] =
      nir_load_global(b, nir_iadd_imm(b, leaf_addr, 104), 4, 3, 32);
   defs->world_to_object[3] =
      nir_load_global(b, nir_iadd_imm(b, leaf_addr, 116), 4, 3, 32);
}
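
/* Example usage (illustrative): chase the instance leaf pointer out of a
 * hit record and pull out the per-instance data.
 *
 *    struct brw_nir_rt_mem_hit_defs hit;
 *    brw_nir_rt_load_mem_hit(b, &hit, true);
 *    struct brw_nir_rt_bvh_instance_leaf_defs leaf;
 *    brw_nir_rt_load_bvh_instance_leaf(b, &leaf, hit.inst_leaf_ptr);
 *    nir_ssa_def *id = leaf.instance_id;
 *
 * Each transform comes back as four 3-component vectors: [0..2] hold the
 * 3x3 part and [3] the translation, modulo the swap described above.
 */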

#endif /* BRW_NIR_RT_BUILDER_H */