radv_acceleration_structure.c revision 7ec681f3
/*
 * Copyright © 2021 Bas Nieuwenhuizen
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */
#include "radv_acceleration_structure.h"
#include "radv_private.h"

#include "util/format/format_utils.h"
#include "util/half_float.h"
#include "nir_builder.h"
#include "radv_cs.h"
#include "radv_meta.h"

void
radv_GetAccelerationStructureBuildSizesKHR(
   VkDevice _device, VkAccelerationStructureBuildTypeKHR buildType,
   const VkAccelerationStructureBuildGeometryInfoKHR *pBuildInfo,
   const uint32_t *pMaxPrimitiveCounts, VkAccelerationStructureBuildSizesInfoKHR *pSizeInfo)
{
   uint64_t triangles = 0, boxes = 0, instances = 0;

   STATIC_ASSERT(sizeof(struct radv_bvh_triangle_node) == 64);
   STATIC_ASSERT(sizeof(struct radv_bvh_aabb_node) == 64);
   STATIC_ASSERT(sizeof(struct radv_bvh_instance_node) == 128);
   STATIC_ASSERT(sizeof(struct radv_bvh_box16_node) == 64);
   STATIC_ASSERT(sizeof(struct radv_bvh_box32_node) == 128);

   for (uint32_t i = 0; i < pBuildInfo->geometryCount; ++i) {
      const VkAccelerationStructureGeometryKHR *geometry;
      if (pBuildInfo->pGeometries)
         geometry = &pBuildInfo->pGeometries[i];
      else
         geometry = pBuildInfo->ppGeometries[i];

      switch (geometry->geometryType) {
      case VK_GEOMETRY_TYPE_TRIANGLES_KHR:
         triangles += pMaxPrimitiveCounts[i];
         break;
      case VK_GEOMETRY_TYPE_AABBS_KHR:
         boxes += pMaxPrimitiveCounts[i];
         break;
      case VK_GEOMETRY_TYPE_INSTANCES_KHR:
         instances += pMaxPrimitiveCounts[i];
         break;
      case VK_GEOMETRY_TYPE_MAX_ENUM_KHR:
         unreachable("VK_GEOMETRY_TYPE_MAX_ENUM_KHR unhandled");
      }
   }

   uint64_t children = boxes + instances + triangles;
   uint64_t internal_nodes = 0;
   while (children > 1) {
      children = DIV_ROUND_UP(children, 4);
      internal_nodes += children;
   }

   /* The stray 192 bytes are to ensure we have space for a header
    * which we'd want to use for some metadata (like the
    * total AABB of the BVH). */
   uint64_t size = boxes * 128 + instances * 128 + triangles * 64 + internal_nodes * 128 + 192;
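   /* For example, a BLAS with 1000 triangles in a single geometry gets
    * 250 + 63 + 16 + 4 + 1 = 334 internal nodes, i.e.
    * 1000 * 64 + 334 * 128 + 192 = 106944 bytes. */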

   pSizeInfo->accelerationStructureSize = size;

   /* 2x the max number of nodes in a BVH layer (one uint32_t each) */
   pSizeInfo->updateScratchSize = pSizeInfo->buildScratchSize =
      MAX2(4096, 2 * (boxes + instances + triangles) * sizeof(uint32_t));
}

VkResult
radv_CreateAccelerationStructureKHR(VkDevice _device,
                                    const VkAccelerationStructureCreateInfoKHR *pCreateInfo,
                                    const VkAllocationCallbacks *pAllocator,
                                    VkAccelerationStructureKHR *pAccelerationStructure)
{
   RADV_FROM_HANDLE(radv_device, device, _device);
   RADV_FROM_HANDLE(radv_buffer, buffer, pCreateInfo->buffer);
   struct radv_acceleration_structure *accel;

   accel = vk_alloc2(&device->vk.alloc, pAllocator, sizeof(*accel), 8,
                     VK_SYSTEM_ALLOCATION_SCOPE_OBJECT);
   if (accel == NULL)
      return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);

   vk_object_base_init(&device->vk, &accel->base, VK_OBJECT_TYPE_ACCELERATION_STRUCTURE_KHR);

   accel->mem_offset = buffer->offset + pCreateInfo->offset;
   accel->size = pCreateInfo->size;
   accel->bo = buffer->bo;

   *pAccelerationStructure = radv_acceleration_structure_to_handle(accel);
   return VK_SUCCESS;
}

void
radv_DestroyAccelerationStructureKHR(VkDevice _device,
                                     VkAccelerationStructureKHR accelerationStructure,
                                     const VkAllocationCallbacks *pAllocator)
{
   RADV_FROM_HANDLE(radv_device, device, _device);
   RADV_FROM_HANDLE(radv_acceleration_structure, accel, accelerationStructure);

   if (!accel)
      return;

   vk_object_base_finish(&accel->base);
   vk_free2(&device->vk.alloc, pAllocator, accel);
}

VkDeviceAddress
radv_GetAccelerationStructureDeviceAddressKHR(
   VkDevice _device, const VkAccelerationStructureDeviceAddressInfoKHR *pInfo)
{
   RADV_FROM_HANDLE(radv_acceleration_structure, accel, pInfo->accelerationStructure);
   return radv_accel_struct_get_va(accel);
}

VkResult
radv_WriteAccelerationStructuresPropertiesKHR(
   VkDevice _device, uint32_t accelerationStructureCount,
   const VkAccelerationStructureKHR *pAccelerationStructures, VkQueryType queryType,
   size_t dataSize, void *pData, size_t stride)
{
   RADV_FROM_HANDLE(radv_device, device, _device);
   char *data_out = (char *)pData;

   for (uint32_t i = 0; i < accelerationStructureCount; ++i) {
      RADV_FROM_HANDLE(radv_acceleration_structure, accel, pAccelerationStructures[i]);
      const char *base_ptr = (const char *)device->ws->buffer_map(accel->bo);
      if (!base_ptr)
         return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);

      const struct radv_accel_struct_header *header = (const void *)(base_ptr + accel->mem_offset);
      if (stride * i + sizeof(VkDeviceSize) <= dataSize) {
         uint64_t value;
         switch (queryType) {
         case VK_QUERY_TYPE_ACCELERATION_STRUCTURE_COMPACTED_SIZE_KHR:
            value = header->compacted_size;
            break;
         case VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR:
            value = header->serialization_size;
            break;
         default:
            unreachable("Unhandled acceleration structure query");
         }
         *(VkDeviceSize *)(data_out + stride * i) = value;
      }
      device->ws->buffer_unmap(accel->bo);
   }
   return VK_SUCCESS;
}

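/* Host-build state: `base` points at the start of the acceleration structure
 * memory, `curr_ptr` at the next free node slot, and the id of every emitted
 * leaf node is appended through `write_scratch`. */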
struct radv_bvh_build_ctx {
   uint32_t *write_scratch;
   char *base;
   char *curr_ptr;
};

static void
build_triangles(struct radv_bvh_build_ctx *ctx, const VkAccelerationStructureGeometryKHR *geom,
                const VkAccelerationStructureBuildRangeInfoKHR *range, unsigned geometry_id)
{
   const VkAccelerationStructureGeometryTrianglesDataKHR *tri_data = &geom->geometry.triangles;
   VkTransformMatrixKHR matrix;
   const char *index_data = (const char *)tri_data->indexData.hostAddress + range->primitiveOffset;

   if (tri_data->transformData.hostAddress) {
      matrix = *(const VkTransformMatrixKHR *)((const char *)tri_data->transformData.hostAddress +
                                               range->transformOffset);
   } else {
      matrix = (VkTransformMatrixKHR){
         .matrix = {{1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 0.0}}};
   }

   for (uint32_t p = 0; p < range->primitiveCount; ++p, ctx->curr_ptr += 64) {
      struct radv_bvh_triangle_node *node = (void *)ctx->curr_ptr;
      uint32_t node_offset = ctx->curr_ptr - ctx->base;
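      /* Node ids are the node's byte offset divided by 8; the low 3 bits carry
       * the node type (0 for triangles, 5 for box32 nodes, 6 for instances,
       * 7 for AABBs, cf. compute_bounds()). */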
      uint32_t node_id = node_offset >> 3;
      *ctx->write_scratch++ = node_id;

      for (unsigned v = 0; v < 3; ++v) {
         uint32_t v_index = range->firstVertex;
         switch (tri_data->indexType) {
         case VK_INDEX_TYPE_NONE_KHR:
            v_index += p * 3 + v;
            break;
         case VK_INDEX_TYPE_UINT8_EXT:
            v_index += *(const uint8_t *)index_data;
            index_data += 1;
            break;
         case VK_INDEX_TYPE_UINT16:
            v_index += *(const uint16_t *)index_data;
            index_data += 2;
            break;
         case VK_INDEX_TYPE_UINT32:
            v_index += *(const uint32_t *)index_data;
            index_data += 4;
            break;
         case VK_INDEX_TYPE_MAX_ENUM:
            unreachable("Unhandled VK_INDEX_TYPE_MAX_ENUM");
            break;
         }

         const char *v_data =
            (const char *)tri_data->vertexData.hostAddress + v_index * tri_data->vertexStride;
         float coords[4];
         switch (tri_data->vertexFormat) {
         case VK_FORMAT_R32G32_SFLOAT:
            coords[0] = *(const float *)(v_data + 0);
            coords[1] = *(const float *)(v_data + 4);
            coords[2] = 0.0f;
            coords[3] = 1.0f;
            break;
         case VK_FORMAT_R32G32B32_SFLOAT:
            coords[0] = *(const float *)(v_data + 0);
            coords[1] = *(const float *)(v_data + 4);
            coords[2] = *(const float *)(v_data + 8);
            coords[3] = 1.0f;
            break;
         case VK_FORMAT_R32G32B32A32_SFLOAT:
            coords[0] = *(const float *)(v_data + 0);
            coords[1] = *(const float *)(v_data + 4);
            coords[2] = *(const float *)(v_data + 8);
            coords[3] = *(const float *)(v_data + 12);
            break;
         case VK_FORMAT_R16G16_SFLOAT:
            coords[0] = _mesa_half_to_float(*(const uint16_t *)(v_data + 0));
            coords[1] = _mesa_half_to_float(*(const uint16_t *)(v_data + 2));
            coords[2] = 0.0f;
            coords[3] = 1.0f;
            break;
         case VK_FORMAT_R16G16B16_SFLOAT:
            coords[0] = _mesa_half_to_float(*(const uint16_t *)(v_data + 0));
            coords[1] = _mesa_half_to_float(*(const uint16_t *)(v_data + 2));
            coords[2] = _mesa_half_to_float(*(const uint16_t *)(v_data + 4));
            coords[3] = 1.0f;
            break;
         case VK_FORMAT_R16G16B16A16_SFLOAT:
            coords[0] = _mesa_half_to_float(*(const uint16_t *)(v_data + 0));
            coords[1] = _mesa_half_to_float(*(const uint16_t *)(v_data + 2));
            coords[2] = _mesa_half_to_float(*(const uint16_t *)(v_data + 4));
            coords[3] = _mesa_half_to_float(*(const uint16_t *)(v_data + 6));
            break;
         case VK_FORMAT_R16G16_SNORM:
            coords[0] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 0), 16);
            coords[1] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 2), 16);
            coords[2] = 0.0f;
            coords[3] = 1.0f;
            break;
         case VK_FORMAT_R16G16B16A16_SNORM:
            coords[0] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 0), 16);
            coords[1] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 2), 16);
            coords[2] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 4), 16);
            coords[3] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 6), 16);
            break;
         case VK_FORMAT_R16G16B16A16_UNORM:
            coords[0] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 0), 16);
            coords[1] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 2), 16);
            coords[2] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 4), 16);
            coords[3] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 6), 16);
            break;
         default:
            unreachable("Unhandled vertex format in BVH build");
         }

         for (unsigned j = 0; j < 3; ++j) {
            float r = 0;
            for (unsigned k = 0; k < 4; ++k)
               r += matrix.matrix[j][k] * coords[k];
            node->coords[v][j] = r;
         }

         node->triangle_id = p;
         node->geometry_id_and_flags = geometry_id | (geom->flags << 28);

         /* Seems to be needed for IJ, otherwise I = J = ? */
         node->id = 9;
      }
   }
}

static VkResult
build_instances(struct radv_device *device, struct radv_bvh_build_ctx *ctx,
                const VkAccelerationStructureGeometryKHR *geom,
                const VkAccelerationStructureBuildRangeInfoKHR *range)
{
   const VkAccelerationStructureGeometryInstancesDataKHR *inst_data = &geom->geometry.instances;

   for (uint32_t p = 0; p < range->primitiveCount; ++p, ctx->curr_ptr += 128) {
      const VkAccelerationStructureInstanceKHR *instance =
         inst_data->arrayOfPointers
            ? (((const VkAccelerationStructureInstanceKHR *const *)inst_data->data.hostAddress)[p])
            : &((const VkAccelerationStructureInstanceKHR *)inst_data->data.hostAddress)[p];
      if (!instance->accelerationStructureReference) {
         continue;
      }

      struct radv_bvh_instance_node *node = (void *)ctx->curr_ptr;
      uint32_t node_offset = ctx->curr_ptr - ctx->base;
      uint32_t node_id = (node_offset >> 3) | 6;
      *ctx->write_scratch++ = node_id;

      float transform[16], inv_transform[16];
      memcpy(transform, &instance->transform.matrix, sizeof(instance->transform.matrix));
      transform[12] = transform[13] = transform[14] = 0.0f;
      transform[15] = 1.0f;

      util_invert_mat4x4(inv_transform, transform);
      memcpy(node->wto_matrix, inv_transform, sizeof(node->wto_matrix));
      node->wto_matrix[3] = transform[3];
      node->wto_matrix[7] = transform[7];
      node->wto_matrix[11] = transform[11];
      node->custom_instance_and_mask = instance->instanceCustomIndex | (instance->mask << 24);
      node->sbt_offset_and_flags =
         instance->instanceShaderBindingTableRecordOffset | (instance->flags << 24);
      node->instance_id = p;

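      /* Store the rotation part of the object-to-world transform, transposed
       * with respect to the row-major VkTransformMatrixKHR layout. */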
      for (unsigned i = 0; i < 3; ++i)
         for (unsigned j = 0; j < 3; ++j)
            node->otw_matrix[i * 3 + j] = instance->transform.matrix[j][i];

      RADV_FROM_HANDLE(radv_acceleration_structure, src_accel_struct,
                       (VkAccelerationStructureKHR)instance->accelerationStructureReference);
      const void *src_base = device->ws->buffer_map(src_accel_struct->bo);
      if (!src_base)
         return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);

      src_base = (const char *)src_base + src_accel_struct->mem_offset;
      const struct radv_accel_struct_header *src_header = src_base;
      node->base_ptr = radv_accel_struct_get_va(src_accel_struct) | src_header->root_node_offset;

      for (unsigned j = 0; j < 3; ++j) {
         node->aabb[0][j] = instance->transform.matrix[j][3];
         node->aabb[1][j] = instance->transform.matrix[j][3];
         for (unsigned k = 0; k < 3; ++k) {
            node->aabb[0][j] += MIN2(instance->transform.matrix[j][k] * src_header->aabb[0][k],
                                     instance->transform.matrix[j][k] * src_header->aabb[1][k]);
            node->aabb[1][j] += MAX2(instance->transform.matrix[j][k] * src_header->aabb[0][k],
                                     instance->transform.matrix[j][k] * src_header->aabb[1][k]);
         }
      }
      device->ws->buffer_unmap(src_accel_struct->bo);
   }
   return VK_SUCCESS;
}

static void
build_aabbs(struct radv_bvh_build_ctx *ctx, const VkAccelerationStructureGeometryKHR *geom,
            const VkAccelerationStructureBuildRangeInfoKHR *range, unsigned geometry_id)
{
   const VkAccelerationStructureGeometryAabbsDataKHR *aabb_data = &geom->geometry.aabbs;

   for (uint32_t p = 0; p < range->primitiveCount; ++p, ctx->curr_ptr += 64) {
      struct radv_bvh_aabb_node *node = (void *)ctx->curr_ptr;
      uint32_t node_offset = ctx->curr_ptr - ctx->base;
      uint32_t node_id = (node_offset >> 3) | 7;
      *ctx->write_scratch++ = node_id;

      const VkAabbPositionsKHR *aabb =
         (const VkAabbPositionsKHR *)((const char *)aabb_data->data.hostAddress +
                                      p * aabb_data->stride);

      node->aabb[0][0] = aabb->minX;
      node->aabb[0][1] = aabb->minY;
      node->aabb[0][2] = aabb->minZ;
      node->aabb[1][0] = aabb->maxX;
      node->aabb[1][1] = aabb->maxY;
      node->aabb[1][2] = aabb->maxZ;
      node->primitive_id = p;
      node->geometry_id_and_flags = geometry_id;
   }
}

static uint32_t
leaf_node_count(const VkAccelerationStructureBuildGeometryInfoKHR *info,
                const VkAccelerationStructureBuildRangeInfoKHR *ranges)
{
   uint32_t count = 0;
   for (uint32_t i = 0; i < info->geometryCount; ++i) {
      count += ranges[i].primitiveCount;
   }
   return count;
}

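/* Compute the AABB of a single node: bounds[0..2] receive the minimum and
 * bounds[3..5] the maximum corner. For box32 nodes the union of all non-NaN
 * children is returned. */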
static void
compute_bounds(const char *base_ptr, uint32_t node_id, float *bounds)
{
   for (unsigned i = 0; i < 3; ++i)
      bounds[i] = INFINITY;
   for (unsigned i = 0; i < 3; ++i)
      bounds[3 + i] = -INFINITY;

   switch (node_id & 7) {
   case 0: {
      const struct radv_bvh_triangle_node *node = (const void *)(base_ptr + (node_id / 8 * 64));
      for (unsigned v = 0; v < 3; ++v) {
         for (unsigned j = 0; j < 3; ++j) {
            bounds[j] = MIN2(bounds[j], node->coords[v][j]);
            bounds[3 + j] = MAX2(bounds[3 + j], node->coords[v][j]);
         }
      }
      break;
   }
   case 5: {
      const struct radv_bvh_box32_node *node = (const void *)(base_ptr + (node_id / 8 * 64));
      for (unsigned c2 = 0; c2 < 4; ++c2) {
         if (isnan(node->coords[c2][0][0]))
            continue;
         for (unsigned j = 0; j < 3; ++j) {
            bounds[j] = MIN2(bounds[j], node->coords[c2][0][j]);
            bounds[3 + j] = MAX2(bounds[3 + j], node->coords[c2][1][j]);
         }
      }
      break;
   }
   case 6: {
      const struct radv_bvh_instance_node *node = (const void *)(base_ptr + (node_id / 8 * 64));
      for (unsigned j = 0; j < 3; ++j) {
         bounds[j] = MIN2(bounds[j], node->aabb[0][j]);
         bounds[3 + j] = MAX2(bounds[3 + j], node->aabb[1][j]);
      }
      break;
   }
   case 7: {
      const struct radv_bvh_aabb_node *node = (const void *)(base_ptr + (node_id / 8 * 64));
      for (unsigned j = 0; j < 3; ++j) {
         bounds[j] = MIN2(bounds[j], node->aabb[0][j]);
         bounds[3 + j] = MAX2(bounds[3 + j], node->aabb[1][j]);
      }
      break;
   }
   }
}

struct bvh_opt_entry {
   uint64_t key;
   uint32_t node_id;
};

static int
bvh_opt_compare(const void *_a, const void *_b)
{
   const struct bvh_opt_entry *a = _a;
   const struct bvh_opt_entry *b = _b;

   if (a->key < b->key)
      return -1;
   if (a->key > b->key)
      return 1;
   if (a->node_id < b->node_id)
      return -1;
   if (a->node_id > b->node_id)
      return 1;
   return 0;
}

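/* Sort the leaf node ids along a Morton curve: each node's centroid is
 * quantized to 21 bits per axis within the total AABB and the bits of the
 * three axes are interleaved into a 63-bit key, so nodes that end up next to
 * each other in the array are also close together in space. */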
static void
optimize_bvh(const char *base_ptr, uint32_t *node_ids, uint32_t node_count)
{
   float bounds[6];
   for (unsigned i = 0; i < 3; ++i)
      bounds[i] = INFINITY;
   for (unsigned i = 0; i < 3; ++i)
      bounds[3 + i] = -INFINITY;

   for (uint32_t i = 0; i < node_count; ++i) {
      float node_bounds[6];
      compute_bounds(base_ptr, node_ids[i], node_bounds);
      for (unsigned j = 0; j < 3; ++j)
         bounds[j] = MIN2(bounds[j], node_bounds[j]);
      for (unsigned j = 0; j < 3; ++j)
         bounds[3 + j] = MAX2(bounds[3 + j], node_bounds[3 + j]);
   }

   struct bvh_opt_entry *entries = calloc(node_count, sizeof(struct bvh_opt_entry));
   if (!entries)
      return;

   for (uint32_t i = 0; i < node_count; ++i) {
      float node_bounds[6];
      compute_bounds(base_ptr, node_ids[i], node_bounds);
      float node_coords[3];
      for (unsigned j = 0; j < 3; ++j)
         node_coords[j] = (node_bounds[j] + node_bounds[3 + j]) * 0.5;
      int32_t coords[3];
      for (unsigned j = 0; j < 3; ++j)
         coords[j] = MAX2(
            MIN2((int32_t)((node_coords[j] - bounds[j]) / (bounds[3 + j] - bounds[j]) * (1 << 21)),
                 (1 << 21) - 1),
            0);
      uint64_t key = 0;
      for (unsigned j = 0; j < 21; ++j)
         for (unsigned k = 0; k < 3; ++k)
            key |= (uint64_t)((coords[k] >> j) & 1) << (j * 3 + k);
      entries[i].key = key;
      entries[i].node_id = node_ids[i];
   }

   qsort(entries, node_count, sizeof(entries[0]), bvh_opt_compare);
   for (unsigned i = 0; i < node_count; ++i)
      node_ids[i] = entries[i].node_id;

   free(entries);
}

static VkResult
build_bvh(struct radv_device *device, const VkAccelerationStructureBuildGeometryInfoKHR *info,
          const VkAccelerationStructureBuildRangeInfoKHR *ranges)
{
   RADV_FROM_HANDLE(radv_acceleration_structure, accel, info->dstAccelerationStructure);
   VkResult result = VK_SUCCESS;

   uint32_t *scratch[2];
   scratch[0] = info->scratchData.hostAddress;
   scratch[1] = scratch[0] + leaf_node_count(info, ranges);

   char *base_ptr = (char *)device->ws->buffer_map(accel->bo);
   if (!base_ptr)
      return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);

   base_ptr = base_ptr + accel->mem_offset;
   struct radv_accel_struct_header *header = (void *)base_ptr;
   void *first_node_ptr = (char *)base_ptr + ALIGN(sizeof(*header), 64);

   struct radv_bvh_build_ctx ctx = {.write_scratch = scratch[0],
                                    .base = base_ptr,
                                    .curr_ptr = (char *)first_node_ptr + 128};

   uint64_t instance_offset = (const char *)ctx.curr_ptr - (const char *)base_ptr;
   uint64_t instance_count = 0;

   /* This initializes the leaf nodes of the BVH all at the same level. */
   for (int inst = 1; inst >= 0; --inst) {
      for (uint32_t i = 0; i < info->geometryCount; ++i) {
         const VkAccelerationStructureGeometryKHR *geom =
            info->pGeometries ? &info->pGeometries[i] : info->ppGeometries[i];

         if ((inst && geom->geometryType != VK_GEOMETRY_TYPE_INSTANCES_KHR) ||
             (!inst && geom->geometryType == VK_GEOMETRY_TYPE_INSTANCES_KHR))
            continue;

         switch (geom->geometryType) {
         case VK_GEOMETRY_TYPE_TRIANGLES_KHR:
            build_triangles(&ctx, geom, ranges + i, i);
            break;
         case VK_GEOMETRY_TYPE_AABBS_KHR:
            build_aabbs(&ctx, geom, ranges + i, i);
            break;
         case VK_GEOMETRY_TYPE_INSTANCES_KHR: {
            result = build_instances(device, &ctx, geom, ranges + i);
            if (result != VK_SUCCESS)
               goto fail;

            instance_count += ranges[i].primitiveCount;
            break;
         }
         case VK_GEOMETRY_TYPE_MAX_ENUM_KHR:
            unreachable("VK_GEOMETRY_TYPE_MAX_ENUM_KHR unhandled");
         }
      }
   }

   uint32_t node_counts[2] = {ctx.write_scratch - scratch[0], 0};
   optimize_bvh(base_ptr, scratch[0], node_counts[0]);
   unsigned d;

   /*
    * This is the most naive BVH building algorithm I could think of:
    * it just iteratively builds each level from bottom to top with
    * the children of each node being in-order and tightly packed.
    *
    * It is probably terrible for traversal but should be easy to build an
    * equivalent GPU version.
    */
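   /* With n leaf nodes the successive levels contain DIV_ROUND_UP(n, 4),
    * DIV_ROUND_UP(n, 16), ... nodes until a single root remains. */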
   for (d = 0; node_counts[d & 1] > 1 || d == 0; ++d) {
      uint32_t child_count = node_counts[d & 1];
      const uint32_t *children = scratch[d & 1];
      uint32_t *dst_ids = scratch[(d & 1) ^ 1];
      unsigned dst_count;
      unsigned child_idx = 0;
      for (dst_count = 0; child_idx < MAX2(1, child_count); ++dst_count, child_idx += 4) {
         unsigned local_child_count = MIN2(4, child_count - child_idx);
         uint32_t child_ids[4];
         float bounds[4][6];

         for (unsigned c = 0; c < local_child_count; ++c) {
            uint32_t id = children[child_idx + c];
            child_ids[c] = id;

            compute_bounds(base_ptr, id, bounds[c]);
         }

         struct radv_bvh_box32_node *node;

         /* Put the root node right after the header so it sits at a fixed,
          * known offset, which allows some traversal optimizations. */
         if (child_idx == 0 && local_child_count == child_count) {
            node = first_node_ptr;
            header->root_node_offset = ((char *)first_node_ptr - (char *)base_ptr) / 64 * 8 + 5;
         } else {
            uint32_t dst_id = (ctx.curr_ptr - base_ptr) / 64;
            dst_ids[dst_count] = dst_id * 8 + 5;

            node = (void *)ctx.curr_ptr;
            ctx.curr_ptr += 128;
         }

         for (unsigned c = 0; c < local_child_count; ++c) {
            node->children[c] = child_ids[c];
            for (unsigned i = 0; i < 2; ++i)
               for (unsigned j = 0; j < 3; ++j)
                  node->coords[c][i][j] = bounds[c][i * 3 + j];
         }
         for (unsigned c = local_child_count; c < 4; ++c) {
            for (unsigned i = 0; i < 2; ++i)
               for (unsigned j = 0; j < 3; ++j)
                  node->coords[c][i][j] = NAN;
         }
      }

      node_counts[(d & 1) ^ 1] = dst_count;
   }

   compute_bounds(base_ptr, header->root_node_offset, &header->aabb[0][0]);

   header->instance_offset = instance_offset;
   header->instance_count = instance_count;
   header->compacted_size = (char *)ctx.curr_ptr - base_ptr;

   /* 16 bytes per invocation, 64 invocations per workgroup */
   header->copy_dispatch_size[0] = DIV_ROUND_UP(header->compacted_size, 16 * 64);
   header->copy_dispatch_size[1] = 1;
   header->copy_dispatch_size[2] = 1;

   header->serialization_size =
      header->compacted_size + align(sizeof(struct radv_accel_struct_serialization_header) +
                                        sizeof(uint64_t) * header->instance_count,
                                     128);

fail:
   device->ws->buffer_unmap(accel->bo);
   return result;
}

VkResult
radv_BuildAccelerationStructuresKHR(
   VkDevice _device, VkDeferredOperationKHR deferredOperation, uint32_t infoCount,
   const VkAccelerationStructureBuildGeometryInfoKHR *pInfos,
   const VkAccelerationStructureBuildRangeInfoKHR *const *ppBuildRangeInfos)
{
   RADV_FROM_HANDLE(radv_device, device, _device);
   VkResult result = VK_SUCCESS;

   for (uint32_t i = 0; i < infoCount; ++i) {
      result = build_bvh(device, pInfos + i, ppBuildRangeInfos[i]);
      if (result != VK_SUCCESS)
         break;
   }
   return result;
}

VkResult
radv_CopyAccelerationStructureKHR(VkDevice _device, VkDeferredOperationKHR deferredOperation,
                                  const VkCopyAccelerationStructureInfoKHR *pInfo)
{
   RADV_FROM_HANDLE(radv_device, device, _device);
   RADV_FROM_HANDLE(radv_acceleration_structure, src_struct, pInfo->src);
   RADV_FROM_HANDLE(radv_acceleration_structure, dst_struct, pInfo->dst);

   char *src_ptr = (char *)device->ws->buffer_map(src_struct->bo);
   if (!src_ptr)
      return VK_ERROR_OUT_OF_HOST_MEMORY;

   char *dst_ptr = (char *)device->ws->buffer_map(dst_struct->bo);
   if (!dst_ptr) {
      device->ws->buffer_unmap(src_struct->bo);
      return VK_ERROR_OUT_OF_HOST_MEMORY;
   }

   src_ptr += src_struct->mem_offset;
   dst_ptr += dst_struct->mem_offset;

   const struct radv_accel_struct_header *header = (const void *)src_ptr;
   memcpy(dst_ptr, src_ptr, header->compacted_size);

   device->ws->buffer_unmap(src_struct->bo);
   device->ws->buffer_unmap(dst_struct->bo);
   return VK_SUCCESS;
}

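/* Return the three vertex indices of triangle `id` as a uvec3, handling
 * 8-, 16- and 32-bit index buffers as well as VK_INDEX_TYPE_NONE_KHR (where
 * the indices are simply 3 * id + {0, 1, 2}). */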
static nir_ssa_def *
get_indices(nir_builder *b, nir_ssa_def *addr, nir_ssa_def *type, nir_ssa_def *id)
{
   const struct glsl_type *uvec3_type = glsl_vector_type(GLSL_TYPE_UINT, 3);
   nir_variable *result =
      nir_variable_create(b->shader, nir_var_shader_temp, uvec3_type, "indices");

   nir_push_if(b, nir_ult(b, type, nir_imm_int(b, 2)));
   nir_push_if(b, nir_ieq(b, type, nir_imm_int(b, VK_INDEX_TYPE_UINT16)));
   {
      nir_ssa_def *index_id = nir_umul24(b, id, nir_imm_int(b, 6));
      nir_ssa_def *indices[3];
      for (unsigned i = 0; i < 3; ++i) {
         indices[i] = nir_build_load_global(
            b, 1, 16, nir_iadd(b, addr, nir_u2u64(b, nir_iadd(b, index_id, nir_imm_int(b, 2 * i)))),
            .align_mul = 2, .align_offset = 0);
      }
      nir_store_var(b, result, nir_u2u32(b, nir_vec(b, indices, 3)), 7);
   }
   nir_push_else(b, NULL);
   {
      nir_ssa_def *index_id = nir_umul24(b, id, nir_imm_int(b, 12));
      nir_ssa_def *indices = nir_build_load_global(
         b, 3, 32, nir_iadd(b, addr, nir_u2u64(b, index_id)), .align_mul = 4, .align_offset = 0);
      nir_store_var(b, result, indices, 7);
   }
   nir_pop_if(b, NULL);
   nir_push_else(b, NULL);
   {
      nir_ssa_def *index_id = nir_umul24(b, id, nir_imm_int(b, 3));
      nir_ssa_def *indices[] = {
         index_id,
         nir_iadd(b, index_id, nir_imm_int(b, 1)),
         nir_iadd(b, index_id, nir_imm_int(b, 2)),
      };

      nir_push_if(b, nir_ieq(b, type, nir_imm_int(b, VK_INDEX_TYPE_NONE_KHR)));
      {
         nir_store_var(b, result, nir_vec(b, indices, 3), 7);
      }
      nir_push_else(b, NULL);
      {
         for (unsigned i = 0; i < 3; ++i) {
            indices[i] = nir_build_load_global(b, 1, 8, nir_iadd(b, addr, nir_u2u64(b, indices[i])),
                                               .align_mul = 1, .align_offset = 0);
         }
         nir_store_var(b, result, nir_u2u32(b, nir_vec(b, indices, 3)), 7);
      }
      nir_pop_if(b, NULL);
   }
   nir_pop_if(b, NULL);
   return nir_load_var(b, result);
}

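/* Load one position per triangle vertex from the three given addresses and
 * convert it to a vec3 according to the run-time vertex format, expanding
 * 16-bit float and SNORM/UNORM components to 32-bit float. */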
static void
get_vertices(nir_builder *b, nir_ssa_def *addresses, nir_ssa_def *format, nir_ssa_def *positions[3])
{
   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
   nir_variable *results[3] = {
      nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "vertex0"),
      nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "vertex1"),
      nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "vertex2")};

   VkFormat formats[] = {
      VK_FORMAT_R32G32B32_SFLOAT,    VK_FORMAT_R32G32B32A32_SFLOAT, VK_FORMAT_R16G16B16_SFLOAT,
      VK_FORMAT_R16G16B16A16_SFLOAT, VK_FORMAT_R16G16_SFLOAT,       VK_FORMAT_R32G32_SFLOAT,
      VK_FORMAT_R16G16_SNORM,        VK_FORMAT_R16G16B16A16_SNORM,  VK_FORMAT_R16G16B16A16_UNORM,
   };

   for (unsigned f = 0; f < ARRAY_SIZE(formats); ++f) {
      if (f + 1 < ARRAY_SIZE(formats))
         nir_push_if(b, nir_ieq(b, format, nir_imm_int(b, formats[f])));

      for (unsigned i = 0; i < 3; ++i) {
         switch (formats[f]) {
         case VK_FORMAT_R32G32B32_SFLOAT:
         case VK_FORMAT_R32G32B32A32_SFLOAT:
            nir_store_var(b, results[i],
                          nir_build_load_global(b, 3, 32, nir_channel(b, addresses, i),
                                                .align_mul = 4, .align_offset = 0),
                          7);
            break;
         case VK_FORMAT_R32G32_SFLOAT:
         case VK_FORMAT_R16G16_SFLOAT:
         case VK_FORMAT_R16G16B16_SFLOAT:
         case VK_FORMAT_R16G16B16A16_SFLOAT:
         case VK_FORMAT_R16G16_SNORM:
         case VK_FORMAT_R16G16B16A16_SNORM:
         case VK_FORMAT_R16G16B16A16_UNORM: {
            unsigned components = MIN2(3, vk_format_get_nr_components(formats[f]));
            unsigned comp_bits =
               vk_format_get_blocksizebits(formats[f]) / vk_format_get_nr_components(formats[f]);
            unsigned comp_bytes = comp_bits / 8;
            nir_ssa_def *values[3];
            nir_ssa_def *addr = nir_channel(b, addresses, i);
            for (unsigned j = 0; j < components; ++j)
               values[j] = nir_build_load_global(
                  b, 1, comp_bits, nir_iadd(b, addr, nir_imm_int64(b, j * comp_bytes)),
                  .align_mul = comp_bytes, .align_offset = 0);

            for (unsigned j = components; j < 3; ++j)
               values[j] = nir_imm_intN_t(b, 0, comp_bits);

            nir_ssa_def *vec;
            if (util_format_is_snorm(vk_format_to_pipe_format(formats[f]))) {
               for (unsigned j = 0; j < 3; ++j) {
                  values[j] = nir_fdiv(b, nir_i2f32(b, values[j]),
                                       nir_imm_float(b, (1u << (comp_bits - 1)) - 1));
                  values[j] = nir_fmax(b, values[j], nir_imm_float(b, -1.0));
               }
               vec = nir_vec(b, values, 3);
            } else if (util_format_is_unorm(vk_format_to_pipe_format(formats[f]))) {
               for (unsigned j = 0; j < 3; ++j) {
                  values[j] =
                     nir_fdiv(b, nir_u2f32(b, values[j]), nir_imm_float(b, (1u << comp_bits) - 1));
                  values[j] = nir_fmin(b, values[j], nir_imm_float(b, 1.0));
               }
               vec = nir_vec(b, values, 3);
            } else if (comp_bits == 16)
               vec = nir_f2f32(b, nir_vec(b, values, 3));
            else
               vec = nir_vec(b, values, 3);
            nir_store_var(b, results[i], vec, 7);
            break;
         }
         default:
            unreachable("Unhandled format");
         }
      }
      if (f + 1 < ARRAY_SIZE(formats))
         nir_push_else(b, NULL);
   }
   for (unsigned f = 1; f < ARRAY_SIZE(formats); ++f) {
      nir_pop_if(b, NULL);
   }

   for (unsigned i = 0; i < 3; ++i)
      positions[i] = nir_load_var(b, results[i]);
}

struct build_primitive_constants {
   uint64_t node_dst_addr;
   uint64_t scratch_addr;
   uint32_t dst_offset;
   uint32_t dst_scratch_offset;
   uint32_t geometry_type;
   uint32_t geometry_id;

   union {
      struct {
         uint64_t vertex_addr;
         uint64_t index_addr;
         uint64_t transform_addr;
         uint32_t vertex_stride;
         uint32_t vertex_format;
         uint32_t index_format;
      };
      struct {
         uint64_t instance_data;
         uint32_t array_of_pointers;
      };
      struct {
         uint64_t aabb_addr;
         uint32_t aabb_stride;
      };
   };
};

struct build_internal_constants {
   uint64_t node_dst_addr;
   uint64_t scratch_addr;
   uint32_t dst_offset;
   uint32_t dst_scratch_offset;
   uint32_t src_scratch_offset;
   uint32_t fill_header;
};

/* This inverts a 3x3 matrix using cofactors, as in e.g.
 * https://www.mathsisfun.com/algebra/matrix-inverse-minors-cofactors-adjugate.html */
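/* det(M) = sum_i M[0][i] * C[0][i] and inv(M)[i][j] = C[j][i] / det(M), where
 * C[i][j] is the cofactor of M[i][j]; the code below builds exactly these
 * expressions. */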
static void
nir_invert_3x3(nir_builder *b, nir_ssa_def *in[3][3], nir_ssa_def *out[3][3])
{
   nir_ssa_def *cofactors[3][3];
   for (unsigned i = 0; i < 3; ++i) {
      for (unsigned j = 0; j < 3; ++j) {
         cofactors[i][j] =
            nir_fsub(b, nir_fmul(b, in[(i + 1) % 3][(j + 1) % 3], in[(i + 2) % 3][(j + 2) % 3]),
                     nir_fmul(b, in[(i + 1) % 3][(j + 2) % 3], in[(i + 2) % 3][(j + 1) % 3]));
      }
   }

   nir_ssa_def *det = NULL;
   for (unsigned i = 0; i < 3; ++i) {
      nir_ssa_def *det_part = nir_fmul(b, in[0][i], cofactors[0][i]);
      det = det ? nir_fadd(b, det, det_part) : det_part;
   }

   nir_ssa_def *det_inv = nir_frcp(b, det);
   for (unsigned i = 0; i < 3; ++i) {
      for (unsigned j = 0; j < 3; ++j) {
         out[i][j] = nir_fmul(b, cofactors[j][i], det_inv);
      }
   }
}

static nir_shader *
build_leaf_shader(struct radv_device *dev)
{
   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
   nir_builder b =
      nir_builder_init_simple_shader(MESA_SHADER_COMPUTE, NULL, "accel_build_leaf_shader");

   b.shader->info.workgroup_size[0] = 64;
   b.shader->info.workgroup_size[1] = 1;
   b.shader->info.workgroup_size[2] = 1;

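   /* The push constants follow the layout of struct build_primitive_constants
    * above: pconst0 holds the destination node and scratch addresses, pconst1
    * the destination offsets plus geometry type and id, and pconst2..4 cover
    * the per-geometry-type union members. */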
   nir_ssa_def *pconst0 =
      nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 0, .range = 16);
   nir_ssa_def *pconst1 =
      nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 16, .range = 16);
   nir_ssa_def *pconst2 =
      nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 32, .range = 16);
   nir_ssa_def *pconst3 =
      nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 48, .range = 16);
   nir_ssa_def *pconst4 =
      nir_load_push_constant(&b, 1, 32, nir_imm_int(&b, 0), .base = 64, .range = 4);

   nir_ssa_def *geom_type = nir_channel(&b, pconst1, 2);
   nir_ssa_def *node_dst_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 3));
   nir_ssa_def *scratch_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 12));
   nir_ssa_def *node_dst_offset = nir_channel(&b, pconst1, 0);
   nir_ssa_def *scratch_offset = nir_channel(&b, pconst1, 1);
   nir_ssa_def *geometry_id = nir_channel(&b, pconst1, 3);

   nir_ssa_def *global_id =
      nir_iadd(&b,
               nir_umul24(&b, nir_channels(&b, nir_load_workgroup_id(&b, 32), 1),
                          nir_imm_int(&b, b.shader->info.workgroup_size[0])),
               nir_channels(&b, nir_load_local_invocation_id(&b), 1));
   scratch_addr = nir_iadd(
      &b, scratch_addr,
      nir_u2u64(&b, nir_iadd(&b, scratch_offset, nir_umul24(&b, global_id, nir_imm_int(&b, 4)))));
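   /* Each invocation writes the id of the leaf node it builds into its own
    * 4-byte slot of the scratch buffer. */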

   nir_push_if(&b, nir_ieq(&b, geom_type, nir_imm_int(&b, VK_GEOMETRY_TYPE_TRIANGLES_KHR)));
   { /* Triangles */
      nir_ssa_def *vertex_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 3));
      nir_ssa_def *index_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 12));
      nir_ssa_def *transform_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst3, 3));
      nir_ssa_def *vertex_stride = nir_channel(&b, pconst3, 2);
      nir_ssa_def *vertex_format = nir_channel(&b, pconst3, 3);
      nir_ssa_def *index_format = nir_channel(&b, pconst4, 0);
      unsigned repl_swizzle[4] = {0, 0, 0, 0};

      nir_ssa_def *node_offset =
         nir_iadd(&b, node_dst_offset, nir_umul24(&b, global_id, nir_imm_int(&b, 64)));
      nir_ssa_def *triangle_node_dst_addr = nir_iadd(&b, node_dst_addr, nir_u2u64(&b, node_offset));

      nir_ssa_def *indices = get_indices(&b, index_addr, index_format, global_id);
      nir_ssa_def *vertex_addresses = nir_iadd(
         &b, nir_u2u64(&b, nir_imul(&b, indices, nir_swizzle(&b, vertex_stride, repl_swizzle, 3))),
         nir_swizzle(&b, vertex_addr, repl_swizzle, 3));
      nir_ssa_def *positions[3];
      get_vertices(&b, vertex_addresses, vertex_format, positions);

      nir_ssa_def *node_data[16];
      memset(node_data, 0, sizeof(node_data));

      nir_variable *transform[] = {
         nir_variable_create(b.shader, nir_var_shader_temp, glsl_vec4_type(), "transform0"),
         nir_variable_create(b.shader, nir_var_shader_temp, glsl_vec4_type(), "transform1"),
         nir_variable_create(b.shader, nir_var_shader_temp, glsl_vec4_type(), "transform2"),
      };
      nir_store_var(&b, transform[0], nir_imm_vec4(&b, 1.0, 0.0, 0.0, 0.0), 0xf);
      nir_store_var(&b, transform[1], nir_imm_vec4(&b, 0.0, 1.0, 0.0, 0.0), 0xf);
      nir_store_var(&b, transform[2], nir_imm_vec4(&b, 0.0, 0.0, 1.0, 0.0), 0xf);

      nir_push_if(&b, nir_ine(&b, transform_addr, nir_imm_int64(&b, 0)));
      nir_store_var(
         &b, transform[0],
         nir_build_load_global(&b, 4, 32, nir_iadd(&b, transform_addr, nir_imm_int64(&b, 0)),
                               .align_mul = 4, .align_offset = 0),
         0xf);
      nir_store_var(
         &b, transform[1],
         nir_build_load_global(&b, 4, 32, nir_iadd(&b, transform_addr, nir_imm_int64(&b, 16)),
                               .align_mul = 4, .align_offset = 0),
         0xf);
      nir_store_var(
         &b, transform[2],
         nir_build_load_global(&b, 4, 32, nir_iadd(&b, transform_addr, nir_imm_int64(&b, 32)),
                               .align_mul = 4, .align_offset = 0),
         0xf);
      nir_pop_if(&b, NULL);

      for (unsigned i = 0; i < 3; ++i)
         for (unsigned j = 0; j < 3; ++j)
            node_data[i * 3 + j] = nir_fdph(&b, positions[i], nir_load_var(&b, transform[j]));

      node_data[12] = global_id;
      node_data[13] = geometry_id;
      node_data[15] = nir_imm_int(&b, 9);
      for (unsigned i = 0; i < ARRAY_SIZE(node_data); ++i)
         if (!node_data[i])
            node_data[i] = nir_imm_int(&b, 0);

      for (unsigned i = 0; i < 4; ++i) {
         nir_build_store_global(&b, nir_vec(&b, node_data + i * 4, 4),
                                nir_iadd(&b, triangle_node_dst_addr, nir_imm_int64(&b, i * 16)),
                                .write_mask = 15, .align_mul = 16, .align_offset = 0);
      }

      nir_ssa_def *node_id = nir_ushr(&b, node_offset, nir_imm_int(&b, 3));
      nir_build_store_global(&b, node_id, scratch_addr, .write_mask = 1, .align_mul = 4,
                             .align_offset = 0);
   }
   nir_push_else(&b, NULL);
   nir_push_if(&b, nir_ieq(&b, geom_type, nir_imm_int(&b, VK_GEOMETRY_TYPE_AABBS_KHR)));
   { /* AABBs */
      nir_ssa_def *aabb_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 3));
      nir_ssa_def *aabb_stride = nir_channel(&b, pconst2, 2);

      nir_ssa_def *node_offset =
         nir_iadd(&b, node_dst_offset, nir_umul24(&b, global_id, nir_imm_int(&b, 64)));
      nir_ssa_def *aabb_node_dst_addr = nir_iadd(&b, node_dst_addr, nir_u2u64(&b, node_offset));
      nir_ssa_def *node_id =
         nir_iadd(&b, nir_ushr(&b, node_offset, nir_imm_int(&b, 3)), nir_imm_int(&b, 7));
      nir_build_store_global(&b, node_id, scratch_addr, .write_mask = 1, .align_mul = 4,
                             .align_offset = 0);

      aabb_addr = nir_iadd(&b, aabb_addr, nir_u2u64(&b, nir_imul(&b, aabb_stride, global_id)));

      nir_ssa_def *min_bound =
         nir_build_load_global(&b, 3, 32, nir_iadd(&b, aabb_addr, nir_imm_int64(&b, 0)),
                               .align_mul = 4, .align_offset = 0);
      nir_ssa_def *max_bound =
         nir_build_load_global(&b, 3, 32, nir_iadd(&b, aabb_addr, nir_imm_int64(&b, 12)),
                               .align_mul = 4, .align_offset = 0);

      nir_ssa_def *values[] = {nir_channel(&b, min_bound, 0),
                               nir_channel(&b, min_bound, 1),
                               nir_channel(&b, min_bound, 2),
                               nir_channel(&b, max_bound, 0),
                               nir_channel(&b, max_bound, 1),
                               nir_channel(&b, max_bound, 2),
                               global_id,
                               geometry_id};

      nir_build_store_global(&b, nir_vec(&b, values + 0, 4),
                             nir_iadd(&b, aabb_node_dst_addr, nir_imm_int64(&b, 0)),
                             .write_mask = 15, .align_mul = 16, .align_offset = 0);
      nir_build_store_global(&b, nir_vec(&b, values + 4, 4),
                             nir_iadd(&b, aabb_node_dst_addr, nir_imm_int64(&b, 16)),
                             .write_mask = 15, .align_mul = 16, .align_offset = 0);
   }
   nir_push_else(&b, NULL);
   { /* Instances */

      nir_variable *instance_addr_var =
         nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
      nir_push_if(&b, nir_ine(&b, nir_channel(&b, pconst2, 2), nir_imm_int(&b, 0)));
      {
         nir_ssa_def *ptr = nir_iadd(&b, nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 3)),
                                     nir_u2u64(&b, nir_imul(&b, global_id, nir_imm_int(&b, 8))));
         nir_ssa_def *addr = nir_pack_64_2x32(
            &b, nir_build_load_global(&b, 2, 32, ptr, .align_mul = 8, .align_offset = 0));
         nir_store_var(&b, instance_addr_var, addr, 1);
      }
      nir_push_else(&b, NULL);
      {
         nir_ssa_def *addr = nir_iadd(&b, nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 3)),
                                      nir_u2u64(&b, nir_imul(&b, global_id, nir_imm_int(&b, 64))));
         nir_store_var(&b, instance_addr_var, addr, 1);
      }
      nir_pop_if(&b, NULL);
      nir_ssa_def *instance_addr = nir_load_var(&b, instance_addr_var);

      nir_ssa_def *inst_transform[] = {
         nir_build_load_global(&b, 4, 32, nir_iadd(&b, instance_addr, nir_imm_int64(&b, 0)),
                               .align_mul = 4, .align_offset = 0),
         nir_build_load_global(&b, 4, 32, nir_iadd(&b, instance_addr, nir_imm_int64(&b, 16)),
                               .align_mul = 4, .align_offset = 0),
         nir_build_load_global(&b, 4, 32, nir_iadd(&b, instance_addr, nir_imm_int64(&b, 32)),
                               .align_mul = 4, .align_offset = 0)};
      nir_ssa_def *inst3 =
         nir_build_load_global(&b, 4, 32, nir_iadd(&b, instance_addr, nir_imm_int64(&b, 48)),
                               .align_mul = 4, .align_offset = 0);

      nir_ssa_def *node_offset =
         nir_iadd(&b, node_dst_offset, nir_umul24(&b, global_id, nir_imm_int(&b, 128)));
      node_dst_addr = nir_iadd(&b, node_dst_addr, nir_u2u64(&b, node_offset));
      nir_ssa_def *node_id =
         nir_iadd(&b, nir_ushr(&b, node_offset, nir_imm_int(&b, 3)), nir_imm_int(&b, 6));
      nir_build_store_global(&b, node_id, scratch_addr, .write_mask = 1, .align_mul = 4,
                             .align_offset = 0);

      nir_variable *bounds[2] = {
         nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "min_bound"),
         nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "max_bound"),
      };

      nir_store_var(&b, bounds[0], nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7), 7);
      nir_store_var(&b, bounds[1], nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7), 7);

      nir_ssa_def *header_addr = nir_pack_64_2x32(&b, nir_channels(&b, inst3, 12));
      nir_push_if(&b, nir_ine(&b, header_addr, nir_imm_int64(&b, 0)));
      nir_ssa_def *header_root_offset =
         nir_build_load_global(&b, 1, 32, nir_iadd(&b, header_addr, nir_imm_int64(&b, 0)),
                               .align_mul = 4, .align_offset = 0);
      nir_ssa_def *header_min =
         nir_build_load_global(&b, 3, 32, nir_iadd(&b, header_addr, nir_imm_int64(&b, 8)),
                               .align_mul = 4, .align_offset = 0);
      nir_ssa_def *header_max =
         nir_build_load_global(&b, 3, 32, nir_iadd(&b, header_addr, nir_imm_int64(&b, 20)),
                               .align_mul = 4, .align_offset = 0);

      nir_ssa_def *bound_defs[2][3];
      for (unsigned i = 0; i < 3; ++i) {
         bound_defs[0][i] = bound_defs[1][i] = nir_channel(&b, inst_transform[i], 3);

         nir_ssa_def *mul_a = nir_fmul(&b, nir_channels(&b, inst_transform[i], 7), header_min);
         nir_ssa_def *mul_b = nir_fmul(&b, nir_channels(&b, inst_transform[i], 7), header_max);
         nir_ssa_def *mi = nir_fmin(&b, mul_a, mul_b);
         nir_ssa_def *ma = nir_fmax(&b, mul_a, mul_b);
         for (unsigned j = 0; j < 3; ++j) {
            bound_defs[0][i] = nir_fadd(&b, bound_defs[0][i], nir_channel(&b, mi, j));
            bound_defs[1][i] = nir_fadd(&b, bound_defs[1][i], nir_channel(&b, ma, j));
         }
      }

      nir_store_var(&b, bounds[0], nir_vec(&b, bound_defs[0], 3), 7);
      nir_store_var(&b, bounds[1], nir_vec(&b, bound_defs[1], 3), 7);

      /* Store object to world matrix */
      for (unsigned i = 0; i < 3; ++i) {
         nir_ssa_def *vals[3];
         for (unsigned j = 0; j < 3; ++j)
            vals[j] = nir_channel(&b, inst_transform[j], i);

         nir_build_store_global(&b, nir_vec(&b, vals, 3),
                                nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 92 + 12 * i)),
                                .write_mask = 0x7, .align_mul = 4, .align_offset = 0);
      }

      nir_ssa_def *m_in[3][3], *m_out[3][3], *m_vec[3][4];
      for (unsigned i = 0; i < 3; ++i)
         for (unsigned j = 0; j < 3; ++j)
            m_in[i][j] = nir_channel(&b, inst_transform[i], j);
      nir_invert_3x3(&b, m_in, m_out);
      for (unsigned i = 0; i < 3; ++i) {
         for (unsigned j = 0; j < 3; ++j)
            m_vec[i][j] = m_out[i][j];
         m_vec[i][3] = nir_channel(&b, inst_transform[i], 3);
      }

      for (unsigned i = 0; i < 3; ++i) {
         nir_build_store_global(&b, nir_vec(&b, m_vec[i], 4),
                                nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 16 + 16 * i)),
                                .write_mask = 0xf, .align_mul = 4, .align_offset = 0);
      }

      nir_ssa_def *out0[4] = {
         nir_ior(&b, nir_channel(&b, nir_unpack_64_2x32(&b, header_addr), 0), header_root_offset),
         nir_channel(&b, nir_unpack_64_2x32(&b, header_addr), 1), nir_channel(&b, inst3, 0),
         nir_channel(&b, inst3, 1)};
      nir_build_store_global(&b, nir_vec(&b, out0, 4),
                             nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 0)), .write_mask = 0xf,
                             .align_mul = 4, .align_offset = 0);
      nir_build_store_global(&b, global_id, nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 88)),
                             .write_mask = 0x1, .align_mul = 4, .align_offset = 0);
      nir_pop_if(&b, NULL);
      nir_build_store_global(&b, nir_load_var(&b, bounds[0]),
                             nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 64)), .write_mask = 0x7,
                             .align_mul = 4, .align_offset = 0);
      nir_build_store_global(&b, nir_load_var(&b, bounds[1]),
                             nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 76)), .write_mask = 0x7,
                             .align_mul = 4, .align_offset = 0);
   }
   nir_pop_if(&b, NULL);
   nir_pop_if(&b, NULL);

   return b.shader;
}

static void
determine_bounds(nir_builder *b, nir_ssa_def *node_addr, nir_ssa_def *node_id,
                 nir_variable *bounds_vars[2])
{
   nir_ssa_def *node_type = nir_iand(b, node_id, nir_imm_int(b, 7));
   node_addr = nir_iadd(
      b, node_addr,
      nir_u2u64(b, nir_ishl(b, nir_iand(b, node_id, nir_imm_int(b, ~7u)), nir_imm_int(b, 3))));
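   /* i.e. clear the 3 type bits and multiply by 8 to turn the id back into a
    * byte offset from the start of the acceleration structure. */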

   nir_push_if(b, nir_ieq(b, node_type, nir_imm_int(b, 0)));
   {
      nir_ssa_def *positions[3];
      for (unsigned i = 0; i < 3; ++i)
         positions[i] =
            nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, i * 12)),
                                  .align_mul = 4, .align_offset = 0);
      nir_ssa_def *bounds[] = {positions[0], positions[0]};
      for (unsigned i = 1; i < 3; ++i) {
         bounds[0] = nir_fmin(b, bounds[0], positions[i]);
         bounds[1] = nir_fmax(b, bounds[1], positions[i]);
      }
      nir_store_var(b, bounds_vars[0], bounds[0], 7);
      nir_store_var(b, bounds_vars[1], bounds[1], 7);
   }
   nir_push_else(b, NULL);
   nir_push_if(b, nir_ieq(b, node_type, nir_imm_int(b, 5)));
   {
      nir_ssa_def *input_bounds[4][2];
      for (unsigned i = 0; i < 4; ++i)
         for (unsigned j = 0; j < 2; ++j)
            input_bounds[i][j] = nir_build_load_global(
               b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, 16 + i * 24 + j * 12)),
               .align_mul = 4, .align_offset = 0);
      nir_ssa_def *bounds[] = {input_bounds[0][0], input_bounds[0][1]};
      for (unsigned i = 1; i < 4; ++i) {
         bounds[0] = nir_fmin(b, bounds[0], input_bounds[i][0]);
         bounds[1] = nir_fmax(b, bounds[1], input_bounds[i][1]);
      }

      nir_store_var(b, bounds_vars[0], bounds[0], 7);
      nir_store_var(b, bounds_vars[1], bounds[1], 7);
   }
   nir_push_else(b, NULL);
   nir_push_if(b, nir_ieq(b, node_type, nir_imm_int(b, 6)));
   { /* Instances */
      nir_ssa_def *bounds[2];
      for (unsigned i = 0; i < 2; ++i)
         bounds[i] =
            nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, 64 + i * 12)),
                                  .align_mul = 4, .align_offset = 0);
      nir_store_var(b, bounds_vars[0], bounds[0], 7);
      nir_store_var(b, bounds_vars[1], bounds[1], 7);
   }
   nir_push_else(b, NULL);
   { /* AABBs */
      nir_ssa_def *bounds[2];
      for (unsigned i = 0; i < 2; ++i)
         bounds[i] =
            nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, i * 12)),
                                  .align_mul = 4, .align_offset = 0);
      nir_store_var(b, bounds_vars[0], bounds[0], 7);
      nir_store_var(b, bounds_vars[1], bounds[1], 7);
   }
   nir_pop_if(b, NULL);
   nir_pop_if(b, NULL);
   nir_pop_if(b, NULL);
}

static nir_shader *
build_internal_shader(struct radv_device *dev)
{
   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
   nir_builder b =
      nir_builder_init_simple_shader(MESA_SHADER_COMPUTE, NULL, "accel_build_internal_shader");

   b.shader->info.workgroup_size[0] = 64;
   b.shader->info.workgroup_size[1] = 1;
   b.shader->info.workgroup_size[2] = 1;

   /*
    * push constants:
    *   i32 x 2: node dst address
    *   i32 x 2: scratch address
    *   i32: dst offset
    *   i32: dst scratch offset
    *   i32: src scratch offset
    *   i32: src_node_count | (fill_header << 31)
    */
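   /* This mirrors struct build_internal_constants above; the final dword
    * carries the source node count with the fill_header flag in its top bit. */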
1278   nir_ssa_def *pconst0 =
1279      nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 0, .range = 16);
1280   nir_ssa_def *pconst1 =
1281      nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 16, .range = 16);
1282
1283   nir_ssa_def *node_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 3));
1284   nir_ssa_def *scratch_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 12));
1285   nir_ssa_def *node_dst_offset = nir_channel(&b, pconst1, 0);
1286   nir_ssa_def *dst_scratch_offset = nir_channel(&b, pconst1, 1);
1287   nir_ssa_def *src_scratch_offset = nir_channel(&b, pconst1, 2);
1288   nir_ssa_def *src_node_count =
1289      nir_iand(&b, nir_channel(&b, pconst1, 3), nir_imm_int(&b, 0x7FFFFFFFU));
1290   nir_ssa_def *fill_header =
1291      nir_ine(&b, nir_iand(&b, nir_channel(&b, pconst1, 3), nir_imm_int(&b, 0x80000000U)),
1292              nir_imm_int(&b, 0));
1293
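   /* Flat 1D dispatch: one invocation per destination internal node, i.e. per
    * group of up to four source nodes. */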
1294   nir_ssa_def *global_id =
1295      nir_iadd(&b,
1296               nir_umul24(&b, nir_channels(&b, nir_load_workgroup_id(&b, 32), 1),
1297                          nir_imm_int(&b, b.shader->info.workgroup_size[0])),
1298               nir_channels(&b, nir_load_local_invocation_id(&b), 1));
1299   nir_ssa_def *src_idx = nir_imul(&b, global_id, nir_imm_int(&b, 4));
1300   nir_ssa_def *src_count = nir_umin(&b, nir_imm_int(&b, 4), nir_isub(&b, src_node_count, src_idx));
1301
1302   nir_ssa_def *node_offset =
1303      nir_iadd(&b, node_dst_offset, nir_ishl(&b, global_id, nir_imm_int(&b, 7)));
1304   nir_ssa_def *node_dst_addr = nir_iadd(&b, node_addr, nir_u2u64(&b, node_offset));
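   /* Fetch the four child node IDs written by the previous pass (16 bytes per
    * destination node) and store them verbatim into the child slots at the start
    * of the new box32 node. */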
1305   nir_ssa_def *src_nodes = nir_build_load_global(
1306      &b, 4, 32,
1307      nir_iadd(&b, scratch_addr,
1308               nir_u2u64(&b, nir_iadd(&b, src_scratch_offset,
1309                                      nir_ishl(&b, global_id, nir_imm_int(&b, 4))))),
1310      .align_mul = 4, .align_offset = 0);
1311
1312   nir_build_store_global(&b, src_nodes, nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 0)),
1313                          .write_mask = 0xf, .align_mul = 4, .align_offset = 0);
1314
1315   nir_ssa_def *total_bounds[2] = {
1316      nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7),
1317      nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7),
1318   };
1319
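   /* Child bounds start out as NaN; children past src_count keep these NaN
    * bounds, which presumably lets the min/max combines below ignore them. */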
1320   for (unsigned i = 0; i < 4; ++i) {
1321      nir_variable *bounds[2] = {
1322         nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "min_bound"),
1323         nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "max_bound"),
1324      };
1325      nir_store_var(&b, bounds[0], nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7), 7);
1326      nir_store_var(&b, bounds[1], nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7), 7);
1327
1328      nir_push_if(&b, nir_ilt(&b, nir_imm_int(&b, i), src_count));
1329      determine_bounds(&b, node_addr, nir_channel(&b, src_nodes, i), bounds);
1330      nir_pop_if(&b, NULL);
1331      nir_build_store_global(&b, nir_load_var(&b, bounds[0]),
1332                             nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 16 + 24 * i)),
1333                             .write_mask = 0x7, .align_mul = 4, .align_offset = 0);
1334      nir_build_store_global(&b, nir_load_var(&b, bounds[1]),
1335                             nir_iadd(&b, node_dst_addr, nir_imm_int64(&b, 28 + 24 * i)),
1336                             .write_mask = 0x7, .align_mul = 4, .align_offset = 0);
1337      total_bounds[0] = nir_fmin(&b, total_bounds[0], nir_load_var(&b, bounds[0]));
1338      total_bounds[1] = nir_fmax(&b, total_bounds[1], nir_load_var(&b, bounds[1]));
1339   }
1340
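   /* The node ID encodes the node's offset in 8-byte units with the node type in
    * the low bits; adding 5 tags it as a box32 node (the offset is at least
    * 64-byte aligned, so the low bits are free). */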
1341   nir_ssa_def *node_id =
1342      nir_iadd(&b, nir_ushr(&b, node_offset, nir_imm_int(&b, 3)), nir_imm_int(&b, 5));
1343   nir_ssa_def *dst_scratch_addr = nir_iadd(
1344      &b, scratch_addr,
1345      nir_u2u64(&b, nir_iadd(&b, dst_scratch_offset, nir_ishl(&b, global_id, nir_imm_int(&b, 2)))));
1346   nir_build_store_global(&b, node_id, dst_scratch_addr, .write_mask = 1, .align_mul = 4,
1347                          .align_offset = 0);
1348
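   /* Final pass only: write the root node ID at the start of the structure and
    * the accumulated bounds at offsets 8 and 20 of the header. */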
1349   nir_push_if(&b, fill_header);
1350   nir_build_store_global(&b, node_id, node_addr, .write_mask = 1, .align_mul = 4,
1351                          .align_offset = 0);
1352   nir_build_store_global(&b, total_bounds[0], nir_iadd(&b, node_addr, nir_imm_int64(&b, 8)),
1353                          .write_mask = 7, .align_mul = 4, .align_offset = 0);
1354   nir_build_store_global(&b, total_bounds[1], nir_iadd(&b, node_addr, nir_imm_int64(&b, 20)),
1355                          .write_mask = 7, .align_mul = 4, .align_offset = 0);
1356   nir_pop_if(&b, NULL);
1357   return b.shader;
1358}
1359
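/* Operating modes of the generic copy shader below: a plain structure-to-structure
 * copy, serialization into a host-readable blob (serialization header, instance
 * pointer table, then the raw BVH), and the inverse deserialization. */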
1360enum copy_mode {
1361   COPY_MODE_COPY,
1362   COPY_MODE_SERIALIZE,
1363   COPY_MODE_DESERIALIZE,
1364};
1365
1366struct copy_constants {
1367   uint64_t src_addr;
1368   uint64_t dst_addr;
1369   uint32_t mode;
1370};
1371
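/* Generic copy shader used for copy, serialize and deserialize. Each invocation
 * moves 16 bytes per iteration of a grid-strided loop; in the serialize and
 * deserialize modes the 64-bit BLAS addresses embedded in instance nodes are
 * additionally routed through the pointer table that follows the serialization
 * header. */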
1372static nir_shader *
1373build_copy_shader(struct radv_device *dev)
1374{
1375   nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_COMPUTE, NULL, "accel_copy");
1376   b.shader->info.workgroup_size[0] = 64;
1377   b.shader->info.workgroup_size[1] = 1;
1378   b.shader->info.workgroup_size[2] = 1;
1379
1380   nir_ssa_def *invoc_id = nir_load_local_invocation_id(&b);
1381   nir_ssa_def *wg_id = nir_load_workgroup_id(&b, 32);
1382   nir_ssa_def *block_size =
1383      nir_imm_ivec4(&b, b.shader->info.workgroup_size[0], b.shader->info.workgroup_size[1],
1384                    b.shader->info.workgroup_size[2], 0);
1385
1386   nir_ssa_def *global_id =
1387      nir_channel(&b, nir_iadd(&b, nir_imul(&b, wg_id, block_size), invoc_id), 0);
1388
1389   nir_variable *offset_var =
1390      nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "offset");
1391   nir_ssa_def *offset = nir_imul(&b, global_id, nir_imm_int(&b, 16));
1392   nir_store_var(&b, offset_var, offset, 1);
1393
1394   nir_ssa_def *increment = nir_imul(&b, nir_channel(&b, nir_load_num_workgroups(&b, 32), 0),
1395                                     nir_imm_int(&b, b.shader->info.workgroup_size[0] * 16));
1396
1397   nir_ssa_def *pconst0 =
1398      nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 0, .range = 16);
1399   nir_ssa_def *pconst1 =
1400      nir_load_push_constant(&b, 1, 32, nir_imm_int(&b, 0), .base = 16, .range = 4);
1401   nir_ssa_def *src_base_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 3));
1402   nir_ssa_def *dst_base_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0xc));
1403   nir_ssa_def *mode = nir_channel(&b, pconst1, 0);
1404
1405   nir_variable *compacted_size_var =
1406      nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint64_t_type(), "compacted_size");
1407   nir_variable *src_offset_var =
1408      nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "src_offset");
1409   nir_variable *dst_offset_var =
1410      nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "dst_offset");
1411   nir_variable *instance_offset_var =
1412      nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "instance_offset");
1413   nir_variable *instance_count_var =
1414      nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "instance_count");
1415   nir_variable *value_var =
1416      nir_variable_create(b.shader, nir_var_shader_temp, glsl_vec4_type(), "value");
1417
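   /* Per-mode setup: derive the source/destination byte offsets and the instance
    * region; in serialize mode, invocation 0 also writes the serialization header
    * fields. */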
1418   nir_push_if(&b, nir_ieq(&b, mode, nir_imm_int(&b, COPY_MODE_SERIALIZE)));
1419   {
1420      nir_ssa_def *instance_count = nir_build_load_global(
1421         &b, 1, 32,
1422         nir_iadd(&b, src_base_addr,
1423                  nir_imm_int64(&b, offsetof(struct radv_accel_struct_header, instance_count))),
1424         .align_mul = 4, .align_offset = 0);
1425      nir_ssa_def *compacted_size = nir_build_load_global(
1426         &b, 1, 64,
1427         nir_iadd(&b, src_base_addr,
1428                  nir_imm_int64(&b, offsetof(struct radv_accel_struct_header, compacted_size))),
1429         .align_mul = 8, .align_offset = 0);
1430      nir_ssa_def *serialization_size = nir_build_load_global(
1431         &b, 1, 64,
1432         nir_iadd(&b, src_base_addr,
1433                  nir_imm_int64(&b, offsetof(struct radv_accel_struct_header, serialization_size))),
1434         .align_mul = 8, .align_offset = 0);
1435
1436      nir_store_var(&b, compacted_size_var, compacted_size, 1);
1437      nir_store_var(
1438         &b, instance_offset_var,
1439         nir_build_load_global(
1440            &b, 1, 32,
1441            nir_iadd(&b, src_base_addr,
1442                     nir_imm_int64(&b, offsetof(struct radv_accel_struct_header, instance_offset))),
1443            .align_mul = 4, .align_offset = 0),
1444         1);
1445      nir_store_var(&b, instance_count_var, instance_count, 1);
1446
1447      nir_ssa_def *dst_offset =
1448         nir_iadd(&b, nir_imm_int(&b, sizeof(struct radv_accel_struct_serialization_header)),
1449                  nir_imul(&b, instance_count, nir_imm_int(&b, sizeof(uint64_t))));
1450      nir_store_var(&b, src_offset_var, nir_imm_int(&b, 0), 1);
1451      nir_store_var(&b, dst_offset_var, dst_offset, 1);
1452
1453      nir_push_if(&b, nir_ieq(&b, global_id, nir_imm_int(&b, 0)));
1454      {
1455         nir_build_store_global(
1456            &b, serialization_size,
1457            nir_iadd(&b, dst_base_addr,
1458                     nir_imm_int64(&b, offsetof(struct radv_accel_struct_serialization_header,
1459                                                serialization_size))),
1460            .write_mask = 0x1, .align_mul = 8, .align_offset = 0);
1461         nir_build_store_global(
1462            &b, compacted_size,
1463            nir_iadd(&b, dst_base_addr,
1464                     nir_imm_int64(&b, offsetof(struct radv_accel_struct_serialization_header,
1465                                                compacted_size))),
1466            .write_mask = 0x1, .align_mul = 8, .align_offset = 0);
1467         nir_build_store_global(
1468            &b, nir_u2u64(&b, instance_count),
1469            nir_iadd(&b, dst_base_addr,
1470                     nir_imm_int64(&b, offsetof(struct radv_accel_struct_serialization_header,
1471                                                instance_count))),
1472            .write_mask = 0x1, .align_mul = 8, .align_offset = 0);
1473      }
1474      nir_pop_if(&b, NULL);
1475   }
1476   nir_push_else(&b, NULL);
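   /* Deserialize: the BVH payload starts after the serialization header and the
    * instance pointer table (8 bytes per instance). */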
1477   nir_push_if(&b, nir_ieq(&b, mode, nir_imm_int(&b, COPY_MODE_DESERIALIZE)));
1478   {
1479      nir_ssa_def *instance_count = nir_build_load_global(
1480         &b, 1, 32,
1481         nir_iadd(&b, src_base_addr,
1482                  nir_imm_int64(
1483                     &b, offsetof(struct radv_accel_struct_serialization_header, instance_count))),
1484         .align_mul = 4, .align_offset = 0);
1485      nir_ssa_def *src_offset =
1486         nir_iadd(&b, nir_imm_int(&b, sizeof(struct radv_accel_struct_serialization_header)),
1487                  nir_imul(&b, instance_count, nir_imm_int(&b, sizeof(uint64_t))));
1488
1489      nir_ssa_def *header_addr = nir_iadd(&b, src_base_addr, nir_u2u64(&b, src_offset));
1490      nir_store_var(
1491         &b, compacted_size_var,
1492         nir_build_load_global(
1493            &b, 1, 64,
1494            nir_iadd(&b, header_addr,
1495                     nir_imm_int64(&b, offsetof(struct radv_accel_struct_header, compacted_size))),
1496            .align_mul = 8, .align_offset = 0),
1497         1);
1498      nir_store_var(
1499         &b, instance_offset_var,
1500         nir_build_load_global(
1501            &b, 1, 32,
1502            nir_iadd(&b, header_addr,
1503                     nir_imm_int64(&b, offsetof(struct radv_accel_struct_header, instance_offset))),
1504            .align_mul = 4, .align_offset = 0),
1505         1);
1506      nir_store_var(&b, instance_count_var, instance_count, 1);
1507      nir_store_var(&b, src_offset_var, src_offset, 1);
1508      nir_store_var(&b, dst_offset_var, nir_imm_int(&b, 0), 1);
1509   }
1510   nir_push_else(&b, NULL); /* COPY_MODE_COPY */
1511   {
1512      nir_store_var(
1513         &b, compacted_size_var,
1514         nir_build_load_global(
1515            &b, 1, 64,
1516            nir_iadd(&b, src_base_addr,
1517                     nir_imm_int64(&b, offsetof(struct radv_accel_struct_header, compacted_size))),
1518            .align_mul = 8, .align_offset = 0),
1519         1);
1520
1521      nir_store_var(&b, src_offset_var, nir_imm_int(&b, 0), 1);
1522      nir_store_var(&b, dst_offset_var, nir_imm_int(&b, 0), 1);
1523      nir_store_var(&b, instance_offset_var, nir_imm_int(&b, 0), 1);
1524      nir_store_var(&b, instance_count_var, nir_imm_int(&b, 0), 1);
1525   }
1526   nir_pop_if(&b, NULL);
1527   nir_pop_if(&b, NULL);
1528
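   /* Main copy loop: grid-stride over the structure in 16-byte chunks until
    * compacted_size (only its low 32 bits are read here) is reached. Within the
    * instance region, the first 8 bytes of every 128-byte instance node hold the
    * BLAS address and are redirected through the pointer table. */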
1529   nir_ssa_def *instance_bound =
1530      nir_imul(&b, nir_imm_int(&b, sizeof(struct radv_bvh_instance_node)),
1531               nir_load_var(&b, instance_count_var));
1532   nir_ssa_def *compacted_size = nir_build_load_global(
1533      &b, 1, 32,
1534      nir_iadd(&b, src_base_addr,
1535               nir_imm_int64(&b, offsetof(struct radv_accel_struct_header, compacted_size))),
1536      .align_mul = 4, .align_offset = 0);
1537
1538   nir_push_loop(&b);
1539   {
1540      offset = nir_load_var(&b, offset_var);
1541      nir_push_if(&b, nir_ilt(&b, offset, compacted_size));
1542      {
1543         nir_ssa_def *src_offset = nir_iadd(&b, offset, nir_load_var(&b, src_offset_var));
1544         nir_ssa_def *dst_offset = nir_iadd(&b, offset, nir_load_var(&b, dst_offset_var));
1545         nir_ssa_def *src_addr = nir_iadd(&b, src_base_addr, nir_u2u64(&b, src_offset));
1546         nir_ssa_def *dst_addr = nir_iadd(&b, dst_base_addr, nir_u2u64(&b, dst_offset));
1547
1548         nir_ssa_def *value =
1549            nir_build_load_global(&b, 4, 32, src_addr, .align_mul = 16, .align_offset = 0);
1550         nir_store_var(&b, value_var, value, 0xf);
1551
1552         nir_ssa_def *instance_offset = nir_isub(&b, offset, nir_load_var(&b, instance_offset_var));
1553         nir_ssa_def *in_instance_bound =
1554            nir_iand(&b, nir_uge(&b, offset, nir_load_var(&b, instance_offset_var)),
1555                     nir_ult(&b, instance_offset, instance_bound));
1556         nir_ssa_def *instance_start =
1557            nir_ieq(&b,
1558                    nir_iand(&b, instance_offset,
1559                             nir_imm_int(&b, sizeof(struct radv_bvh_instance_node) - 1)),
1560                    nir_imm_int(&b, 0));
1561
1562         nir_push_if(&b, nir_iand(&b, in_instance_bound, instance_start));
1563         {
1564            nir_ssa_def *instance_id = nir_ushr(&b, instance_offset, nir_imm_int(&b, 7));
1565
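            /* Serialize: record the BLAS address in the pointer table. Otherwise
             * (deserialize) substitute the address stored in the table; plain
             * copies never get here because instance_count is zero. */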
1566            nir_push_if(&b, nir_ieq(&b, mode, nir_imm_int(&b, COPY_MODE_SERIALIZE)));
1567            {
1568               nir_ssa_def *instance_addr =
1569                  nir_imul(&b, instance_id, nir_imm_int(&b, sizeof(uint64_t)));
1570               instance_addr =
1571                  nir_iadd(&b, instance_addr,
1572                           nir_imm_int(&b, sizeof(struct radv_accel_struct_serialization_header)));
1573               instance_addr = nir_iadd(&b, dst_base_addr, nir_u2u64(&b, instance_addr));
1574
1575               nir_build_store_global(&b, nir_channels(&b, value, 3), instance_addr,
1576                                      .write_mask = 3, .align_mul = 8, .align_offset = 0);
1577            }
1578            nir_push_else(&b, NULL);
1579            {
1580               nir_ssa_def *instance_addr =
1581                  nir_imul(&b, instance_id, nir_imm_int(&b, sizeof(uint64_t)));
1582               instance_addr =
1583                  nir_iadd(&b, instance_addr,
1584                           nir_imm_int(&b, sizeof(struct radv_accel_struct_serialization_header)));
1585               instance_addr = nir_iadd(&b, src_base_addr, nir_u2u64(&b, instance_addr));
1586
1587               nir_ssa_def *instance_value = nir_build_load_global(
1588                  &b, 2, 32, instance_addr, .align_mul = 8, .align_offset = 0);
1589
1590               nir_ssa_def *values[] = {
1591                  nir_channel(&b, instance_value, 0),
1592                  nir_channel(&b, instance_value, 1),
1593                  nir_channel(&b, value, 2),
1594                  nir_channel(&b, value, 3),
1595               };
1596
1597               nir_store_var(&b, value_var, nir_vec(&b, values, 4), 0xf);
1598            }
1599            nir_pop_if(&b, NULL);
1600         }
1601         nir_pop_if(&b, NULL);
1602
1603         nir_store_var(&b, offset_var, nir_iadd(&b, offset, increment), 1);
1604
1605         nir_build_store_global(&b, nir_load_var(&b, value_var), dst_addr, .write_mask = 0xf,
1606                                .align_mul = 16, .align_offset = 0);
1607      }
1608      nir_push_else(&b, NULL);
1609      {
1610         nir_jump(&b, nir_jump_break);
1611      }
1612      nir_pop_if(&b, NULL);
1613   }
1614   nir_pop_loop(&b, NULL);
1615   return b.shader;
1616}
1617
1618void
1619radv_device_finish_accel_struct_build_state(struct radv_device *device)
1620{
1621   struct radv_meta_state *state = &device->meta_state;
1622   radv_DestroyPipeline(radv_device_to_handle(device), state->accel_struct_build.copy_pipeline,
1623                        &state->alloc);
1624   radv_DestroyPipeline(radv_device_to_handle(device), state->accel_struct_build.internal_pipeline,
1625                        &state->alloc);
1626   radv_DestroyPipeline(radv_device_to_handle(device), state->accel_struct_build.leaf_pipeline,
1627                        &state->alloc);
1628   radv_DestroyPipelineLayout(radv_device_to_handle(device),
1629                              state->accel_struct_build.copy_p_layout, &state->alloc);
1630   radv_DestroyPipelineLayout(radv_device_to_handle(device),
1631                              state->accel_struct_build.internal_p_layout, &state->alloc);
1632   radv_DestroyPipelineLayout(radv_device_to_handle(device),
1633                              state->accel_struct_build.leaf_p_layout, &state->alloc);
1634}
1635
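/* Create the pipeline layouts and compute pipelines for the three meta shaders
 * (leaf, internal, copy) used by acceleration-structure builds and copies. On
 * failure, anything already created is torn down again via the finish function
 * above. */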
1636VkResult
1637radv_device_init_accel_struct_build_state(struct radv_device *device)
1638{
1639   VkResult result;
1640   nir_shader *leaf_cs = build_leaf_shader(device);
1641   nir_shader *internal_cs = build_internal_shader(device);
1642   nir_shader *copy_cs = build_copy_shader(device);
1643
1644   const VkPipelineLayoutCreateInfo leaf_pl_create_info = {
1645      .sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
1646      .setLayoutCount = 0,
1647      .pushConstantRangeCount = 1,
1648      .pPushConstantRanges = &(VkPushConstantRange){VK_SHADER_STAGE_COMPUTE_BIT, 0,
1649                                                    sizeof(struct build_primitive_constants)},
1650   };
1651
1652   result = radv_CreatePipelineLayout(radv_device_to_handle(device), &leaf_pl_create_info,
1653                                      &device->meta_state.alloc,
1654                                      &device->meta_state.accel_struct_build.leaf_p_layout);
1655   if (result != VK_SUCCESS)
1656      goto fail;
1657
1658   VkPipelineShaderStageCreateInfo leaf_shader_stage = {
1659      .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
1660      .stage = VK_SHADER_STAGE_COMPUTE_BIT,
1661      .module = vk_shader_module_handle_from_nir(leaf_cs),
1662      .pName = "main",
1663      .pSpecializationInfo = NULL,
1664   };
1665
1666   VkComputePipelineCreateInfo leaf_pipeline_info = {
1667      .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
1668      .stage = leaf_shader_stage,
1669      .flags = 0,
1670      .layout = device->meta_state.accel_struct_build.leaf_p_layout,
1671   };
1672
1673   result = radv_CreateComputePipelines(
1674      radv_device_to_handle(device), radv_pipeline_cache_to_handle(&device->meta_state.cache), 1,
1675      &leaf_pipeline_info, NULL, &device->meta_state.accel_struct_build.leaf_pipeline);
1676   if (result != VK_SUCCESS)
1677      goto fail;
1678
1679   const VkPipelineLayoutCreateInfo internal_pl_create_info = {
1680      .sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
1681      .setLayoutCount = 0,
1682      .pushConstantRangeCount = 1,
1683      .pPushConstantRanges = &(VkPushConstantRange){VK_SHADER_STAGE_COMPUTE_BIT, 0,
1684                                                    sizeof(struct build_internal_constants)},
1685   };
1686
1687   result = radv_CreatePipelineLayout(radv_device_to_handle(device), &internal_pl_create_info,
1688                                      &device->meta_state.alloc,
1689                                      &device->meta_state.accel_struct_build.internal_p_layout);
1690   if (result != VK_SUCCESS)
1691      goto fail;
1692
1693   VkPipelineShaderStageCreateInfo internal_shader_stage = {
1694      .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
1695      .stage = VK_SHADER_STAGE_COMPUTE_BIT,
1696      .module = vk_shader_module_handle_from_nir(internal_cs),
1697      .pName = "main",
1698      .pSpecializationInfo = NULL,
1699   };
1700
1701   VkComputePipelineCreateInfo internal_pipeline_info = {
1702      .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
1703      .stage = internal_shader_stage,
1704      .flags = 0,
1705      .layout = device->meta_state.accel_struct_build.internal_p_layout,
1706   };
1707
1708   result = radv_CreateComputePipelines(
1709      radv_device_to_handle(device), radv_pipeline_cache_to_handle(&device->meta_state.cache), 1,
1710      &internal_pipeline_info, NULL, &device->meta_state.accel_struct_build.internal_pipeline);
1711   if (result != VK_SUCCESS)
1712      goto fail;
1713
1714   const VkPipelineLayoutCreateInfo copy_pl_create_info = {
1715      .sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
1716      .setLayoutCount = 0,
1717      .pushConstantRangeCount = 1,
1718      .pPushConstantRanges =
1719         &(VkPushConstantRange){VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(struct copy_constants)},
1720   };
1721
1722   result = radv_CreatePipelineLayout(radv_device_to_handle(device), &copy_pl_create_info,
1723                                      &device->meta_state.alloc,
1724                                      &device->meta_state.accel_struct_build.copy_p_layout);
1725   if (result != VK_SUCCESS)
1726      goto fail;
1727
1728   VkPipelineShaderStageCreateInfo copy_shader_stage = {
1729      .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
1730      .stage = VK_SHADER_STAGE_COMPUTE_BIT,
1731      .module = vk_shader_module_handle_from_nir(copy_cs),
1732      .pName = "main",
1733      .pSpecializationInfo = NULL,
1734   };
1735
1736   VkComputePipelineCreateInfo copy_pipeline_info = {
1737      .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
1738      .stage = copy_shader_stage,
1739      .flags = 0,
1740      .layout = device->meta_state.accel_struct_build.copy_p_layout,
1741   };
1742
1743   result = radv_CreateComputePipelines(
1744      radv_device_to_handle(device), radv_pipeline_cache_to_handle(&device->meta_state.cache), 1,
1745      &copy_pipeline_info, NULL, &device->meta_state.accel_struct_build.copy_pipeline);
1746   if (result != VK_SUCCESS)
1747      goto fail;
1748
1749   ralloc_free(copy_cs);
1750   ralloc_free(internal_cs);
1751   ralloc_free(leaf_cs);
1752
1753   return VK_SUCCESS;
1754
1755fail:
1756   radv_device_finish_accel_struct_build_state(device);
1757   ralloc_free(copy_cs);
1758   ralloc_free(internal_cs);
1759   ralloc_free(leaf_cs);
1760   return result;
1761}
1762
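/* Per-build state carried across the leaf and internal dispatch passes of
 * radv_CmdBuildAccelerationStructuresKHR. */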
1763struct bvh_state {
1764   uint32_t node_offset;
1765   uint32_t node_count;
1766   uint32_t scratch_offset;
1767
1768   uint32_t instance_offset;
1769   uint32_t instance_count;
1770};
1771
1772void
1773radv_CmdBuildAccelerationStructuresKHR(
1774   VkCommandBuffer commandBuffer, uint32_t infoCount,
1775   const VkAccelerationStructureBuildGeometryInfoKHR *pInfos,
1776   const VkAccelerationStructureBuildRangeInfoKHR *const *ppBuildRangeInfos)
1777{
1778   RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
1779   struct radv_meta_saved_state saved_state;
1780
1781   radv_meta_save(
1782      &saved_state, cmd_buffer,
1783      RADV_META_SAVE_COMPUTE_PIPELINE | RADV_META_SAVE_DESCRIPTORS | RADV_META_SAVE_CONSTANTS);
1784   struct bvh_state *bvh_states = calloc(infoCount, sizeof(struct bvh_state));
1785
1786   radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
1787                        cmd_buffer->device->meta_state.accel_struct_build.leaf_pipeline);
1788
1789   for (uint32_t i = 0; i < infoCount; ++i) {
1790      RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
1791                       pInfos[i].dstAccelerationStructure);
1792
1793      struct build_primitive_constants prim_consts = {
1794         .node_dst_addr = radv_accel_struct_get_va(accel_struct),
1795         .scratch_addr = pInfos[i].scratchData.deviceAddress,
1796         .dst_offset = ALIGN(sizeof(struct radv_accel_struct_header), 64) + 128,
1797         .dst_scratch_offset = 0,
1798      };
1799      bvh_states[i].node_offset = prim_consts.dst_offset;
1800      bvh_states[i].instance_offset = prim_consts.dst_offset;
1801
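      /* Leaf pass: emit leaf nodes for every geometry. Instances are handled
       * first (inst == 1), which keeps all 128-byte instance nodes contiguous
       * starting at instance_offset. */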
1802      for (int inst = 1; inst >= 0; --inst) {
1803         for (unsigned j = 0; j < pInfos[i].geometryCount; ++j) {
1804            const VkAccelerationStructureGeometryKHR *geom =
1805               pInfos[i].pGeometries ? &pInfos[i].pGeometries[j] : pInfos[i].ppGeometries[j];
1806
1807            if ((inst && geom->geometryType != VK_GEOMETRY_TYPE_INSTANCES_KHR) ||
1808                (!inst && geom->geometryType == VK_GEOMETRY_TYPE_INSTANCES_KHR))
1809               continue;
1810
1811            prim_consts.geometry_type = geom->geometryType;
1812            prim_consts.geometry_id = j | (geom->flags << 28);
1813            unsigned prim_size;
1814            switch (geom->geometryType) {
1815            case VK_GEOMETRY_TYPE_TRIANGLES_KHR:
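               /* Per the Vulkan spec, primitiveOffset offsets vertexData for
                * non-indexed geometry and indexData for indexed geometry, while
                * firstVertex always applies to the vertex fetch. */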
1816               prim_consts.vertex_addr =
1817                  geom->geometry.triangles.vertexData.deviceAddress +
1818                  ppBuildRangeInfos[i][j].firstVertex * geom->geometry.triangles.vertexStride +
1819                  (geom->geometry.triangles.indexType == VK_INDEX_TYPE_NONE_KHR
1820                      ? ppBuildRangeInfos[i][j].primitiveOffset
1821                      : 0);
1822               prim_consts.index_addr = geom->geometry.triangles.indexData.deviceAddress +
1823                                        ppBuildRangeInfos[i][j].primitiveOffset;
1824               prim_consts.transform_addr = geom->geometry.triangles.transformData.deviceAddress +
1825                                            ppBuildRangeInfos[i][j].transformOffset;
1826               prim_consts.vertex_stride = geom->geometry.triangles.vertexStride;
1827               prim_consts.vertex_format = geom->geometry.triangles.vertexFormat;
1828               prim_consts.index_format = geom->geometry.triangles.indexType;
1829               prim_size = 64;
1830               break;
1831            case VK_GEOMETRY_TYPE_AABBS_KHR:
1832               prim_consts.aabb_addr =
1833                  geom->geometry.aabbs.data.deviceAddress + ppBuildRangeInfos[i][j].primitiveOffset;
1834               prim_consts.aabb_stride = geom->geometry.aabbs.stride;
1835               prim_size = 64;
1836               break;
1837            case VK_GEOMETRY_TYPE_INSTANCES_KHR:
1838               prim_consts.instance_data = geom->geometry.instances.data.deviceAddress;
1839               prim_consts.array_of_pointers = geom->geometry.instances.arrayOfPointers ? 1 : 0;
1840               prim_size = 128;
1841               bvh_states[i].instance_count += ppBuildRangeInfos[i][j].primitiveCount;
1842               break;
1843            default:
1844               unreachable("Unknown geometryType");
1845            }
1846
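            /* One dispatch per geometry, one invocation per primitive; afterwards
             * advance the node and scratch offsets by the per-primitive sizes. */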
1847            radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
1848                                  cmd_buffer->device->meta_state.accel_struct_build.leaf_p_layout,
1849                                  VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(prim_consts),
1850                                  &prim_consts);
1851            radv_unaligned_dispatch(cmd_buffer, ppBuildRangeInfos[i][j].primitiveCount, 1, 1);
1852            prim_consts.dst_offset += prim_size * ppBuildRangeInfos[i][j].primitiveCount;
1853            prim_consts.dst_scratch_offset += 4 * ppBuildRangeInfos[i][j].primitiveCount;
1854         }
1855      }
1856      bvh_states[i].node_offset = prim_consts.dst_offset;
1857      bvh_states[i].node_count = prim_consts.dst_scratch_offset / 4;
1858   }
1859
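   /* Internal-node passes: repeatedly fold groups of four node IDs, read from one
    * half of the scratch buffer and written to the other, into 128-byte box32
    * nodes until a single root remains. The root node is placed right after the
    * header, and the header's root ID and bounds are filled on that final pass. */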
1860   radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
1861                        cmd_buffer->device->meta_state.accel_struct_build.internal_pipeline);
1862   bool progress = true;
1863   for (unsigned iter = 0; progress; ++iter) {
1864      progress = false;
1865      for (uint32_t i = 0; i < infoCount; ++i) {
1866         RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
1867                          pInfos[i].dstAccelerationStructure);
1868
1869         if (iter && bvh_states[i].node_count == 1)
1870            continue;
1871
1872         if (!progress) {
1873            cmd_buffer->state.flush_bits |=
1874               RADV_CMD_FLAG_CS_PARTIAL_FLUSH |
1875               radv_src_access_flush(cmd_buffer, VK_ACCESS_SHADER_WRITE_BIT, NULL) |
1876               radv_dst_access_flush(cmd_buffer,
1877                                     VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT, NULL);
1878         }
1879         progress = true;
1880         uint32_t dst_node_count = MAX2(1, DIV_ROUND_UP(bvh_states[i].node_count, 4));
1881         bool final_iter = dst_node_count == 1;
1882         uint32_t src_scratch_offset = bvh_states[i].scratch_offset;
1883         uint32_t dst_scratch_offset = src_scratch_offset ? 0 : bvh_states[i].node_count * 4;
1884         uint32_t dst_node_offset = bvh_states[i].node_offset;
1885         if (final_iter)
1886            dst_node_offset = ALIGN(sizeof(struct radv_accel_struct_header), 64);
1887
1888         const struct build_internal_constants consts = {
1889            .node_dst_addr = radv_accel_struct_get_va(accel_struct),
1890            .scratch_addr = pInfos[i].scratchData.deviceAddress,
1891            .dst_offset = dst_node_offset,
1892            .dst_scratch_offset = dst_scratch_offset,
1893            .src_scratch_offset = src_scratch_offset,
1894            .fill_header = bvh_states[i].node_count | (final_iter ? 0x80000000U : 0),
1895         };
1896
1897         radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
1898                               cmd_buffer->device->meta_state.accel_struct_build.internal_p_layout,
1899                               VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
1900         radv_unaligned_dispatch(cmd_buffer, dst_node_count, 1, 1);
1901         if (!final_iter)
1902            bvh_states[i].node_offset += dst_node_count * 128;
1903         bvh_states[i].node_count = dst_node_count;
1904         bvh_states[i].scratch_offset = dst_scratch_offset;
1905      }
1906   }
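   /* Write the remaining header fields (instance info, compacted and
    * serialization sizes, and the indirect dispatch size consumed by the copy
    * shader) with a CP copy. */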
1907   for (uint32_t i = 0; i < infoCount; ++i) {
1908      RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
1909                       pInfos[i].dstAccelerationStructure);
1910      const size_t base = offsetof(struct radv_accel_struct_header, compacted_size);
1911      struct radv_accel_struct_header header;
1912
1913      header.instance_offset = bvh_states[i].instance_offset;
1914      header.instance_count = bvh_states[i].instance_count;
1915      header.compacted_size = bvh_states[i].node_offset;
1916
1917      /* 16 bytes per invocation, 64 invocations per workgroup */
1918      header.copy_dispatch_size[0] = DIV_ROUND_UP(header.compacted_size, 16 * 64);
1919      header.copy_dispatch_size[1] = 1;
1920      header.copy_dispatch_size[2] = 1;
1921
1922      header.serialization_size =
1923         header.compacted_size + align(sizeof(struct radv_accel_struct_serialization_header) +
1924                                          sizeof(uint64_t) * header.instance_count,
1925                                       128);
1926
1927      radv_update_buffer_cp(cmd_buffer,
1928                            radv_buffer_get_va(accel_struct->bo) + accel_struct->mem_offset + base,
1929                            (const char *)&header + base, sizeof(header) - base);
1930   }
1931   free(bvh_states);
1932   radv_meta_restore(&saved_state, cmd_buffer);
1933}
1934
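/* Device-side copy: the dispatch is sized indirectly from copy_dispatch_size in
 * the source acceleration structure's header, so no host read-back is needed. */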
1935void
1936radv_CmdCopyAccelerationStructureKHR(VkCommandBuffer commandBuffer,
1937                                     const VkCopyAccelerationStructureInfoKHR *pInfo)
1938{
1939   RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
1940   RADV_FROM_HANDLE(radv_acceleration_structure, src, pInfo->src);
1941   RADV_FROM_HANDLE(radv_acceleration_structure, dst, pInfo->dst);
1942   struct radv_meta_saved_state saved_state;
1943
1944   radv_meta_save(
1945      &saved_state, cmd_buffer,
1946      RADV_META_SAVE_COMPUTE_PIPELINE | RADV_META_SAVE_DESCRIPTORS | RADV_META_SAVE_CONSTANTS);
1947
1948   uint64_t src_addr = radv_accel_struct_get_va(src);
1949   uint64_t dst_addr = radv_accel_struct_get_va(dst);
1950
1951   radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
1952                        cmd_buffer->device->meta_state.accel_struct_build.copy_pipeline);
1953
1954   const struct copy_constants consts = {
1955      .src_addr = src_addr,
1956      .dst_addr = dst_addr,
1957      .mode = COPY_MODE_COPY,
1958   };
1959
1960   radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
1961                         cmd_buffer->device->meta_state.accel_struct_build.copy_p_layout,
1962                         VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
1963
1964   radv_indirect_dispatch(cmd_buffer, src->bo,
1965                          src_addr + offsetof(struct radv_accel_struct_header, copy_dispatch_size));
1966   radv_meta_restore(&saved_state, cmd_buffer);
1967}
1968
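/* Serialized acceleration-structure data is tagged with the driver UUID followed
 * by a zeroed "compatibility" UUID (see the serialization paths in this file);
 * anything else is reported as incompatible. */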
1969void
1970radv_GetDeviceAccelerationStructureCompatibilityKHR(
1971   VkDevice _device, const VkAccelerationStructureVersionInfoKHR *pVersionInfo,
1972   VkAccelerationStructureCompatibilityKHR *pCompatibility)
1973{
1974   RADV_FROM_HANDLE(radv_device, device, _device);
1975   uint8_t zero[VK_UUID_SIZE] = {
1976      0,
1977   };
1978   bool compat =
1979      memcmp(pVersionInfo->pVersionData, device->physical_device->driver_uuid, VK_UUID_SIZE) == 0 &&
1980      memcmp(pVersionInfo->pVersionData + VK_UUID_SIZE, zero, VK_UUID_SIZE) == 0;
1981   *pCompatibility = compat ? VK_ACCELERATION_STRUCTURE_COMPATIBILITY_COMPATIBLE_KHR
1982                            : VK_ACCELERATION_STRUCTURE_COMPATIBILITY_INCOMPATIBLE_KHR;
1983}
1984
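/* Host-side deserialization: copy the raw BVH back into place, then restore each
 * instance node's BLAS pointer from the table following the serialization header
 * while preserving the low 6 flag bits. */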
1985VkResult
1986radv_CopyMemoryToAccelerationStructureKHR(VkDevice _device,
1987                                          VkDeferredOperationKHR deferredOperation,
1988                                          const VkCopyMemoryToAccelerationStructureInfoKHR *pInfo)
1989{
1990   RADV_FROM_HANDLE(radv_device, device, _device);
1991   RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct, pInfo->dst);
1992
1993   char *base = device->ws->buffer_map(accel_struct->bo);
1994   if (!base)
1995      return VK_ERROR_OUT_OF_HOST_MEMORY;
1996
1997   base += accel_struct->mem_offset;
1998   const struct radv_accel_struct_header *header = (const struct radv_accel_struct_header *)base;
1999
2000   const char *src = pInfo->src.hostAddress;
2001   struct radv_accel_struct_serialization_header *src_header = (void *)src;
2002   src += sizeof(*src_header) + sizeof(uint64_t) * src_header->instance_count;
2003
2004   memcpy(base, src, src_header->compacted_size);
2005
2006   for (unsigned i = 0; i < src_header->instance_count; ++i) {
2007      uint64_t *p = (uint64_t *)(base + i * 128 + header->instance_offset);
2008      *p = (*p & 63) | src_header->instances[i];
2009   }
2010
2011   device->ws->buffer_unmap(accel_struct->bo);
2012   return VK_SUCCESS;
2013}
2014
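/* Host-side serialization: fill in the serialization header (UUIDs, sizes and
 * instance count), copy out the raw BVH, and store every instance's BLAS pointer
 * with its low 6 flag bits cleared into the pointer table. */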
2015VkResult
2016radv_CopyAccelerationStructureToMemoryKHR(VkDevice _device,
2017                                          VkDeferredOperationKHR deferredOperation,
2018                                          const VkCopyAccelerationStructureToMemoryInfoKHR *pInfo)
2019{
2020   RADV_FROM_HANDLE(radv_device, device, _device);
2021   RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct, pInfo->src);
2022
2023   const char *base = device->ws->buffer_map(accel_struct->bo);
2024   if (!base)
2025      return VK_ERROR_OUT_OF_HOST_MEMORY;
2026
2027   base += accel_struct->mem_offset;
2028   const struct radv_accel_struct_header *header = (const struct radv_accel_struct_header *)base;
2029
2030   char *dst = pInfo->dst.hostAddress;
2031   struct radv_accel_struct_serialization_header *dst_header = (void *)dst;
2032   dst += sizeof(*dst_header) + sizeof(uint64_t) * header->instance_count;
2033
2034   memcpy(dst_header->driver_uuid, device->physical_device->driver_uuid, VK_UUID_SIZE);
2035   memset(dst_header->accel_struct_compat, 0, VK_UUID_SIZE);
2036
2037   dst_header->serialization_size = header->serialization_size;
2038   dst_header->compacted_size = header->compacted_size;
2039   dst_header->instance_count = header->instance_count;
2040
2041   memcpy(dst, base, header->compacted_size);
2042
2043   for (unsigned i = 0; i < header->instance_count; ++i) {
2044      dst_header->instances[i] =
2045         *(const uint64_t *)(base + i * 128 + header->instance_offset) & ~63ull;
2046   }
2047
2048   device->ws->buffer_unmap(accel_struct->bo);
2049   return VK_SUCCESS;
2050}
2051
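/* Device-side deserialization: the copy shader's loop is grid-strided over the
 * whole structure, so a fixed 512-workgroup dispatch is sufficient here where no
 * indirect dispatch size is available. */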
2052void
2053radv_CmdCopyMemoryToAccelerationStructureKHR(
2054   VkCommandBuffer commandBuffer, const VkCopyMemoryToAccelerationStructureInfoKHR *pInfo)
2055{
2056   RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
2057   RADV_FROM_HANDLE(radv_acceleration_structure, dst, pInfo->dst);
2058   struct radv_meta_saved_state saved_state;
2059
2060   radv_meta_save(
2061      &saved_state, cmd_buffer,
2062      RADV_META_SAVE_COMPUTE_PIPELINE | RADV_META_SAVE_DESCRIPTORS | RADV_META_SAVE_CONSTANTS);
2063
2064   uint64_t dst_addr = radv_accel_struct_get_va(dst);
2065
2066   radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
2067                        cmd_buffer->device->meta_state.accel_struct_build.copy_pipeline);
2068
2069   const struct copy_constants consts = {
2070      .src_addr = pInfo->src.deviceAddress,
2071      .dst_addr = dst_addr,
2072      .mode = COPY_MODE_DESERIALIZE,
2073   };
2074
2075   radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
2076                         cmd_buffer->device->meta_state.accel_struct_build.copy_p_layout,
2077                         VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
2078
2079   radv_CmdDispatch(commandBuffer, 512, 1, 1);
2080   radv_meta_restore(&saved_state, cmd_buffer);
2081}
2082
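/* Device-side serialization: dispatch indirectly from the source header's
 * copy_dispatch_size, then append the version UUIDs with a CP write below. */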
2083void
2084radv_CmdCopyAccelerationStructureToMemoryKHR(
2085   VkCommandBuffer commandBuffer, const VkCopyAccelerationStructureToMemoryInfoKHR *pInfo)
2086{
2087   RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
2088   RADV_FROM_HANDLE(radv_acceleration_structure, src, pInfo->src);
2089   struct radv_meta_saved_state saved_state;
2090
2091   radv_meta_save(
2092      &saved_state, cmd_buffer,
2093      RADV_META_SAVE_COMPUTE_PIPELINE | RADV_META_SAVE_DESCRIPTORS | RADV_META_SAVE_CONSTANTS);
2094
2095   uint64_t src_addr = radv_accel_struct_get_va(src);
2096
2097   radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
2098                        cmd_buffer->device->meta_state.accel_struct_build.copy_pipeline);
2099
2100   const struct copy_constants consts = {
2101      .src_addr = src_addr,
2102      .dst_addr = pInfo->dst.deviceAddress,
2103      .mode = COPY_MODE_SERIALIZE,
2104   };
2105
2106   radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
2107                         cmd_buffer->device->meta_state.accel_struct_build.copy_p_layout,
2108                         VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
2109
2110   radv_indirect_dispatch(cmd_buffer, src->bo,
2111                          src_addr + offsetof(struct radv_accel_struct_header, copy_dispatch_size));
2112   radv_meta_restore(&saved_state, cmd_buffer);
2113
2114   /* Write the version information of the serialized data: the driver UUID
    * followed by a zeroed compatibility UUID. The size and instance-count fields
    * of the serialization header are written by the copy shader itself. */
2115   uint8_t header_data[2 * VK_UUID_SIZE] = {0};
2116   memcpy(header_data, cmd_buffer->device->physical_device->driver_uuid, VK_UUID_SIZE);
2117
2118   radv_update_buffer_cp(cmd_buffer, pInfo->dst.deviceAddress, header_data, sizeof(header_data));
2119}
2120