/*
 * Copyright © Microsoft Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "dxil_nir.h"

#include "nir_builder.h"
#include "nir_deref.h"
#include "nir_to_dxil.h"
#include "util/u_math.h"

static void
cl_type_size_align(const struct glsl_type *type, unsigned *size,
                   unsigned *align)
{
   *size = glsl_get_cl_size(type);
   *align = glsl_get_cl_alignment(type);
}

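/* Worked example (illustrative only, not emitted by the pass): with
 * dst_bit_size == 16, a 2-component 32-bit vector holding the packed,
 * little-endian words
 *
 *    vec32 = (0xBBBBAAAA, 0xDDDDCCCC)
 *
 * is unpacked into
 *
 *    dst_comps = { 0xAAAA, 0xBBBB, 0xCCCC, 0xDDDD }
 *
 * i.e. the low half of each 32-bit word becomes the earlier component.
 */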
static void
extract_comps_from_vec32(nir_builder *b, nir_ssa_def *vec32,
                         unsigned dst_bit_size,
                         nir_ssa_def **dst_comps,
                         unsigned num_dst_comps)
{
   unsigned step = DIV_ROUND_UP(dst_bit_size, 32);
   unsigned comps_per32b = 32 / dst_bit_size;
   nir_ssa_def *tmp;

   for (unsigned i = 0; i < vec32->num_components; i += step) {
      switch (dst_bit_size) {
      case 64:
         tmp = nir_pack_64_2x32_split(b, nir_channel(b, vec32, i),
                                         nir_channel(b, vec32, i + 1));
         dst_comps[i / 2] = tmp;
         break;
      case 32:
         dst_comps[i] = nir_channel(b, vec32, i);
         break;
      case 16:
      case 8: {
         unsigned dst_offs = i * comps_per32b;

         tmp = nir_unpack_bits(b, nir_channel(b, vec32, i), dst_bit_size);
         for (unsigned j = 0; j < comps_per32b && dst_offs + j < num_dst_comps; j++)
            dst_comps[dst_offs + j] = nir_channel(b, tmp, j);
         break;
      }
      }
   }
}

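/* Worked example (illustrative only): the inverse of the above. With
 * src_bit_size == 8, the components { 0x11, 0x22, 0x33, 0x44 } are packed
 * little-endian into a single 32-bit word:
 *
 *    vec32comps[0] = 0x11 | (0x22 << 8) | (0x33 << 16) | (0x44 << 24)
 *                  = 0x44332211
 */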
static nir_ssa_def *
load_comps_to_vec32(nir_builder *b, unsigned src_bit_size,
                    nir_ssa_def **src_comps, unsigned num_src_comps)
{
   unsigned num_vec32comps = DIV_ROUND_UP(num_src_comps * src_bit_size, 32);
   unsigned step = DIV_ROUND_UP(src_bit_size, 32);
   unsigned comps_per32b = 32 / src_bit_size;
   nir_ssa_def *vec32comps[4];

   for (unsigned i = 0; i < num_vec32comps; i += step) {
      switch (src_bit_size) {
      case 64:
         vec32comps[i] = nir_unpack_64_2x32_split_x(b, src_comps[i / 2]);
         vec32comps[i + 1] = nir_unpack_64_2x32_split_y(b, src_comps[i / 2]);
         break;
      case 32:
         vec32comps[i] = src_comps[i];
         break;
      case 16:
      case 8: {
         unsigned src_offs = i * comps_per32b;

         vec32comps[i] = nir_u2u32(b, src_comps[src_offs]);
         for (unsigned j = 1; j < comps_per32b && src_offs + j < num_src_comps; j++) {
            nir_ssa_def *tmp = nir_ishl(b, nir_u2u32(b, src_comps[src_offs + j]),
                                           nir_imm_int(b, j * src_bit_size));
            vec32comps[i] = nir_ior(b, vec32comps[i], tmp);
         }
         break;
      }
      }
   }

   return nir_vec(b, vec32comps, num_vec32comps);
}

static nir_ssa_def *
build_load_ptr_dxil(nir_builder *b, nir_deref_instr *deref, nir_ssa_def *idx)
{
   return nir_load_ptr_dxil(b, 1, 32, &deref->dest.ssa, idx);
}

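/* Worked example (illustrative only): a 16-bit load at byte offset 6 from a
 * shader-temp array is lowered by the function below to
 *
 *    ptr      = 6
 *    offset   = ptr & ~3      = 4
 *    base_idx = offset >> 2   = 1    (index into the backing i32 array)
 *    shift    = (ptr & 3) * 8 = 16
 *
 * i.e. load the second 32-bit word, shift it right by 16, then unpack the
 * low 16 bits.
 */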
static bool
lower_load_deref(nir_builder *b, nir_intrinsic_instr *intr)
{
   assert(intr->dest.is_ssa);

   b->cursor = nir_before_instr(&intr->instr);

   nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
   if (!nir_deref_mode_is(deref, nir_var_shader_temp))
      return false;
   nir_ssa_def *ptr = nir_u2u32(b, nir_build_deref_offset(b, deref, cl_type_size_align));
   nir_ssa_def *offset = nir_iand(b, ptr, nir_inot(b, nir_imm_int(b, 3)));

   unsigned num_components = nir_dest_num_components(intr->dest);
   unsigned bit_size = nir_dest_bit_size(intr->dest);
   unsigned load_size = MAX2(32, bit_size);
   unsigned num_bits = num_components * bit_size;
   nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
   unsigned comp_idx = 0;

   nir_deref_path path;
   nir_deref_path_init(&path, deref, NULL);
   nir_ssa_def *base_idx = nir_ishr(b, offset, nir_imm_int(b, 2 /* log2(32 / 8) */));

   /* Split loads into 32-bit chunks */
   for (unsigned i = 0; i < num_bits; i += load_size) {
      unsigned subload_num_bits = MIN2(num_bits - i, load_size);
      nir_ssa_def *idx = nir_iadd(b, base_idx, nir_imm_int(b, i / 32));
      nir_ssa_def *vec32 = build_load_ptr_dxil(b, path.path[0], idx);

      if (load_size == 64) {
         idx = nir_iadd(b, idx, nir_imm_int(b, 1));
         vec32 = nir_vec2(b, vec32,
                             build_load_ptr_dxil(b, path.path[0], idx));
      }

      /* If we have 2 bytes or less to load we need to adjust the u32 value so
       * we can always extract the LSB.
       */
      if (subload_num_bits <= 16) {
         nir_ssa_def *shift = nir_imul(b, nir_iand(b, ptr, nir_imm_int(b, 3)),
                                          nir_imm_int(b, 8));
         vec32 = nir_ushr(b, vec32, shift);
      }

      /* And now comes the pack/unpack step to match the original type. */
      extract_comps_from_vec32(b, vec32, bit_size, &comps[comp_idx],
                               subload_num_bits / bit_size);
      comp_idx += subload_num_bits / bit_size;
   }

   nir_deref_path_finish(&path);
   assert(comp_idx == num_components);
   nir_ssa_def *result = nir_vec(b, comps, num_components);
   nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
   nir_instr_remove(&intr->instr);
   return true;
}

static nir_ssa_def *
ubo_load_select_32b_comps(nir_builder *b, nir_ssa_def *vec32,
                          nir_ssa_def *offset, unsigned num_bytes)
{
   assert(num_bytes == 16 || num_bytes == 12 || num_bytes == 8 ||
          num_bytes == 4 || num_bytes == 3 || num_bytes == 2 ||
          num_bytes == 1);
   assert(vec32->num_components == 4);

   /* 16 and 12 byte types are always aligned on 16 bytes. */
   if (num_bytes > 8)
      return vec32;

   nir_ssa_def *comps[4];
   nir_ssa_def *cond;

   for (unsigned i = 0; i < 4; i++)
      comps[i] = nir_channel(b, vec32, i);

   /* If we have 8 bytes or less to load, select which half of the vec4
    * should be used.
    */
   cond = nir_ine(b, nir_iand(b, offset, nir_imm_int(b, 0x8)),
                     nir_imm_int(b, 0));

   comps[0] = nir_bcsel(b, cond, comps[2], comps[0]);
   comps[1] = nir_bcsel(b, cond, comps[3], comps[1]);

   /* Thanks to the CL alignment constraints, if we want 8 bytes we're done. */
   if (num_bytes == 8)
      return nir_vec(b, comps, 2);

   /* If 4 bytes or less are needed, select which 32-bit component should be
    * used and return it. The sub-32-bit split is handled in
    * extract_comps_from_vec32().
    */
   cond = nir_ine(b, nir_iand(b, offset, nir_imm_int(b, 0x4)),
                     nir_imm_int(b, 0));
   return nir_bcsel(b, cond, comps[1], comps[0]);
}

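/* Worked example (illustrative only): an 8-byte (e.g. ulong) UBO load at
 * byte offset 24 becomes
 *
 *    idx          = 24 >> 4 = 1   (16-byte cBufferLoadLegacy() row)
 *    offset & 0x8 = 8             (use the upper half of the row)
 *
 * so the result is assembled from components 2 and 3 of the loaded vec4.
 */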
nir_ssa_def *
build_load_ubo_dxil(nir_builder *b, nir_ssa_def *buffer,
                    nir_ssa_def *offset, unsigned num_components,
                    unsigned bit_size)
{
   nir_ssa_def *idx = nir_ushr(b, offset, nir_imm_int(b, 4));
   nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
   unsigned num_bits = num_components * bit_size;
   unsigned comp_idx = 0;

   /* We need to split loads into 16-byte chunks because that's the
    * granularity of cBufferLoadLegacy().
    */
   for (unsigned i = 0; i < num_bits; i += (16 * 8)) {
      /* For each 16-byte chunk (or smaller) we generate a 32-bit ubo vec
       * load.
       */
      unsigned subload_num_bits = MIN2(num_bits - i, 16 * 8);
      nir_ssa_def *vec32 =
         nir_load_ubo_dxil(b, 4, 32, buffer, nir_iadd(b, idx, nir_imm_int(b, i / (16 * 8))));

      /* First re-arrange the vec32 to account for the intra 16-byte offset. */
      vec32 = ubo_load_select_32b_comps(b, vec32, offset, subload_num_bits / 8);

      /* If we have 2 bytes or less to load we need to adjust the u32 value so
       * we can always extract the LSB.
       */
      if (subload_num_bits <= 16) {
         nir_ssa_def *shift = nir_imul(b, nir_iand(b, offset,
                                                      nir_imm_int(b, 3)),
                                          nir_imm_int(b, 8));
         vec32 = nir_ushr(b, vec32, shift);
      }

      /* And now comes the pack/unpack step to match the original type. */
      extract_comps_from_vec32(b, vec32, bit_size, &comps[comp_idx],
                               subload_num_bits / bit_size);
      comp_idx += subload_num_bits / bit_size;
   }

   assert(comp_idx == num_components);
   return nir_vec(b, comps, num_components);
}

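/* Worked example (illustrative only): an 8-component 32-bit SSBO load
 * (256 bits) is split by the function below into two 4-component 32-bit
 * loads:
 *
 *    vec32 #0 = load_ssbo(buffer, offset)      -> comps[0..3]
 *    vec32 #1 = load_ssbo(buffer, offset + 16) -> comps[4..7]
 */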
static bool
lower_load_ssbo(nir_builder *b, nir_intrinsic_instr *intr)
{
   assert(intr->dest.is_ssa);
   assert(intr->src[0].is_ssa);
   assert(intr->src[1].is_ssa);

   b->cursor = nir_before_instr(&intr->instr);

   nir_ssa_def *buffer = intr->src[0].ssa;
   nir_ssa_def *offset = nir_iand(b, intr->src[1].ssa, nir_imm_int(b, ~3));
   enum gl_access_qualifier access = nir_intrinsic_access(intr);
   unsigned bit_size = nir_dest_bit_size(intr->dest);
   unsigned num_components = nir_dest_num_components(intr->dest);
   unsigned num_bits = num_components * bit_size;

   nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
   unsigned comp_idx = 0;

   /* We need to split loads into 16-byte chunks because that's the optimal
    * granularity of bufferLoad(). The minimum alignment is 4 bytes, which
    * spares us the extra complexity of extracting components of 32 bits or
    * more.
    */
   for (unsigned i = 0; i < num_bits; i += 4 * 32) {
      /* For each 16-byte chunk (or smaller) we generate a 32-bit ssbo vec
       * load.
       */
      unsigned subload_num_bits = MIN2(num_bits - i, 4 * 32);

      /* The number of components to load depends on the number of bytes. */
      nir_ssa_def *vec32 =
         nir_load_ssbo(b, DIV_ROUND_UP(subload_num_bits, 32), 32,
                       buffer, nir_iadd(b, offset, nir_imm_int(b, i / 8)),
                       .align_mul = 4,
                       .align_offset = 0,
                       .access = access);

      /* If we have 2 bytes or less to load we need to adjust the u32 value so
       * we can always extract the LSB.
       */
      if (subload_num_bits <= 16) {
         nir_ssa_def *shift = nir_imul(b, nir_iand(b, intr->src[1].ssa, nir_imm_int(b, 3)),
                                          nir_imm_int(b, 8));
         vec32 = nir_ushr(b, vec32, shift);
      }

      /* And now comes the pack/unpack step to match the original type. */
      extract_comps_from_vec32(b, vec32, bit_size, &comps[comp_idx],
                               subload_num_bits / bit_size);
      comp_idx += subload_num_bits / bit_size;
   }

   assert(comp_idx == num_components);
   nir_ssa_def *result = nir_vec(b, comps, num_components);
   nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
   nir_instr_remove(&intr->instr);
   return true;
}

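/* Worked example (illustrative only): storing a single 16-bit value at byte
 * offset 2 takes the masked path below:
 *
 *    mask  = 0xFFFF << ((2 & 3) * 8) = 0xFFFF0000
 *    vec32 = value << 16
 *
 * and store_ssbo_masked_dxil(vec32, ~mask, buffer, 0) only touches the
 * upper half of the 32-bit slot.
 */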
static bool
lower_store_ssbo(nir_builder *b, nir_intrinsic_instr *intr)
{
   b->cursor = nir_before_instr(&intr->instr);

   assert(intr->src[0].is_ssa);
   assert(intr->src[1].is_ssa);
   assert(intr->src[2].is_ssa);

   nir_ssa_def *val = intr->src[0].ssa;
   nir_ssa_def *buffer = intr->src[1].ssa;
   nir_ssa_def *offset = nir_iand(b, intr->src[2].ssa, nir_imm_int(b, ~3));

   unsigned bit_size = val->bit_size;
   unsigned num_components = val->num_components;
   unsigned num_bits = num_components * bit_size;

   nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
   unsigned comp_idx = 0;

   for (unsigned i = 0; i < num_components; i++)
      comps[i] = nir_channel(b, val, i);

   /* We split stores into 16-byte chunks because that's the optimal
    * granularity of bufferStore(). The minimum alignment is 4 bytes, which
    * spares us the extra complexity of storing components of 32 bits or
    * more.
    */
   for (unsigned i = 0; i < num_bits; i += 4 * 32) {
      /* For each 16-byte chunk (or smaller) we generate a 32-bit ssbo vec
       * store.
       */
      unsigned substore_num_bits = MIN2(num_bits - i, 4 * 32);
      nir_ssa_def *local_offset = nir_iadd(b, offset, nir_imm_int(b, i / 8));
      nir_ssa_def *vec32 = load_comps_to_vec32(b, bit_size, &comps[comp_idx],
                                               substore_num_bits / bit_size);
      nir_intrinsic_instr *store;

      if (substore_num_bits < 32) {
         nir_ssa_def *mask = nir_imm_int(b, (1 << substore_num_bits) - 1);

         /* If we have 16 bits or less to store we need to place them
          * correctly in the u32 component. Anything greater than 16 bits
          * (including uchar3) is naturally aligned on 32 bits.
          */
         if (substore_num_bits <= 16) {
            nir_ssa_def *pos = nir_iand(b, intr->src[2].ssa, nir_imm_int(b, 3));
            nir_ssa_def *shift = nir_imul_imm(b, pos, 8);

            vec32 = nir_ishl(b, vec32, shift);
            mask = nir_ishl(b, mask, shift);
         }

         store = nir_intrinsic_instr_create(b->shader,
                                            nir_intrinsic_store_ssbo_masked_dxil);
         store->src[0] = nir_src_for_ssa(vec32);
         store->src[1] = nir_src_for_ssa(nir_inot(b, mask));
         store->src[2] = nir_src_for_ssa(buffer);
         store->src[3] = nir_src_for_ssa(local_offset);
      } else {
         store = nir_intrinsic_instr_create(b->shader,
                                            nir_intrinsic_store_ssbo);
         store->src[0] = nir_src_for_ssa(vec32);
         store->src[1] = nir_src_for_ssa(buffer);
         store->src[2] = nir_src_for_ssa(local_offset);

         nir_intrinsic_set_align(store, 4, 0);
      }

      /* The number of components to store depends on the number of bits. */
      store->num_components = DIV_ROUND_UP(substore_num_bits, 32);
      nir_builder_instr_insert(b, &store->instr);
      comp_idx += substore_num_bits / bit_size;
   }

   nir_instr_remove(&intr->instr);
   return true;
}

static void
lower_load_vec32(nir_builder *b, nir_ssa_def *index, unsigned num_comps, nir_ssa_def **comps, nir_intrinsic_op op)
{
   for (unsigned i = 0; i < num_comps; i++) {
      nir_intrinsic_instr *load =
         nir_intrinsic_instr_create(b->shader, op);

      load->num_components = 1;
      load->src[0] = nir_src_for_ssa(nir_iadd(b, index, nir_imm_int(b, i)));
      nir_ssa_dest_init(&load->instr, &load->dest, 1, 32, NULL);
      nir_builder_instr_insert(b, &load->instr);
      comps[i] = &load->dest.ssa;
   }
}

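/* Worked example (illustrative only): a 6-component 32-bit shared-memory
 * load (192 bits) needs six i32 loads, which are then consumed four at a
 * time:
 *
 *    pass #0: vec32 = vec4(comps_32bit[0..3]) -> comps[0..3]
 *    pass #1: vec32 = vec2(comps_32bit[4..5]) -> comps[4..5]
 */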
static bool
lower_32b_offset_load(nir_builder *b, nir_intrinsic_instr *intr)
{
   assert(intr->dest.is_ssa);
   unsigned bit_size = nir_dest_bit_size(intr->dest);
   unsigned num_components = nir_dest_num_components(intr->dest);
   unsigned num_bits = num_components * bit_size;

   b->cursor = nir_before_instr(&intr->instr);
   nir_intrinsic_op op = intr->intrinsic;

   assert(intr->src[0].is_ssa);
   nir_ssa_def *offset = intr->src[0].ssa;
   if (op == nir_intrinsic_load_shared) {
      offset = nir_iadd(b, offset, nir_imm_int(b, nir_intrinsic_base(intr)));
      op = nir_intrinsic_load_shared_dxil;
   } else {
      offset = nir_u2u32(b, offset);
      op = nir_intrinsic_load_scratch_dxil;
   }
   nir_ssa_def *index = nir_ushr(b, offset, nir_imm_int(b, 2));
   nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
   nir_ssa_def *comps_32bit[NIR_MAX_VEC_COMPONENTS * 2];

   /* We need to split loads into 32-bit accesses because the buffer
    * is an i32 array and DXIL does not support type casts.
    */
   unsigned num_32bit_comps = DIV_ROUND_UP(num_bits, 32);
   lower_load_vec32(b, index, num_32bit_comps, comps_32bit, op);
   unsigned num_comps_per_pass = MIN2(num_32bit_comps, 4);

   for (unsigned i = 0; i < num_32bit_comps; i += num_comps_per_pass) {
      unsigned num_vec32_comps = MIN2(num_32bit_comps - i, 4);
      unsigned num_dest_comps = num_vec32_comps * 32 / bit_size;
      nir_ssa_def *vec32 = nir_vec(b, &comps_32bit[i], num_vec32_comps);

      /* If we have 16 bits or less to load we need to adjust the u32 value so
       * we can always extract the LSB.
       */
      if (num_bits <= 16) {
         nir_ssa_def *shift =
            nir_imul(b, nir_iand(b, offset, nir_imm_int(b, 3)),
                        nir_imm_int(b, 8));
         vec32 = nir_ushr(b, vec32, shift);
      }

      /* And now comes the pack/unpack step to match the original type. */
      unsigned dest_index = i * 32 / bit_size;
      extract_comps_from_vec32(b, vec32, bit_size, &comps[dest_index], num_dest_comps);
   }

   nir_ssa_def *result = nir_vec(b, comps, num_components);
   nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
   nir_instr_remove(&intr->instr);

   return true;
}

static void
lower_store_vec32(nir_builder *b, nir_ssa_def *index, nir_ssa_def *vec32, nir_intrinsic_op op)
{
   for (unsigned i = 0; i < vec32->num_components; i++) {
      nir_intrinsic_instr *store =
         nir_intrinsic_instr_create(b->shader, op);

      store->src[0] = nir_src_for_ssa(nir_channel(b, vec32, i));
      store->src[1] = nir_src_for_ssa(nir_iadd(b, index, nir_imm_int(b, i)));
      store->num_components = 1;
      nir_builder_instr_insert(b, &store->instr);
   }
}

static void
lower_masked_store_vec32(nir_builder *b, nir_ssa_def *offset, nir_ssa_def *index,
                         nir_ssa_def *vec32, unsigned num_bits, nir_intrinsic_op op)
{
   nir_ssa_def *mask = nir_imm_int(b, (1 << num_bits) - 1);

   /* If we have 16 bits or less to store we need to place them correctly in
    * the u32 component. Anything greater than 16 bits (including uchar3) is
    * naturally aligned on 32bits.
    */
   if (num_bits <= 16) {
      nir_ssa_def *shift =
         nir_imul_imm(b, nir_iand(b, offset, nir_imm_int(b, 3)), 8);

      vec32 = nir_ishl(b, vec32, shift);
      mask = nir_ishl(b, mask, shift);
   }

   if (op == nir_intrinsic_store_shared_dxil) {
      /* Use the dedicated masked intrinsic */
      nir_store_shared_masked_dxil(b, vec32, nir_inot(b, mask), index);
   } else {
      /* For scratch, since we don't need atomics, just generate the
       * read-modify-write in NIR.
       */
      nir_ssa_def *load = nir_load_scratch_dxil(b, 1, 32, index);

      nir_ssa_def *new_val = nir_ior(b, vec32,
                                     nir_iand(b,
                                              nir_inot(b, mask),
                                              load));

      lower_store_vec32(b, index, new_val, op);
   }
}

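/* Worked example (illustrative only): an 8-bit store to shared memory at
 * byte offset 5 becomes a masked store:
 *
 *    index = 5 >> 2        = 1
 *    shift = (5 & 3) * 8   = 8
 *    mask  = 0xFF << shift = 0xFF00
 *
 * so store_shared_masked_dxil() preserves the other three bytes of the
 * 32-bit slot.
 */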
static bool
lower_32b_offset_store(nir_builder *b, nir_intrinsic_instr *intr)
{
   assert(intr->src[0].is_ssa);
   unsigned num_components = nir_src_num_components(intr->src[0]);
   unsigned bit_size = nir_src_bit_size(intr->src[0]);
   unsigned num_bits = num_components * bit_size;

   b->cursor = nir_before_instr(&intr->instr);
   nir_intrinsic_op op = intr->intrinsic;

   nir_ssa_def *offset = intr->src[1].ssa;
   if (op == nir_intrinsic_store_shared) {
      offset = nir_iadd(b, offset, nir_imm_int(b, nir_intrinsic_base(intr)));
      op = nir_intrinsic_store_shared_dxil;
   } else {
      offset = nir_u2u32(b, offset);
      op = nir_intrinsic_store_scratch_dxil;
   }
   nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];

   unsigned comp_idx = 0;
   for (unsigned i = 0; i < num_components; i++)
      comps[i] = nir_channel(b, intr->src[0].ssa, i);

   for (unsigned i = 0; i < num_bits; i += 4 * 32) {
      /* For each 16-byte chunk (or smaller) we generate up to four 32-bit
       * scalar stores.
       */
      unsigned substore_num_bits = MIN2(num_bits - i, 4 * 32);
      nir_ssa_def *local_offset = nir_iadd(b, offset, nir_imm_int(b, i / 8));
      nir_ssa_def *vec32 = load_comps_to_vec32(b, bit_size, &comps[comp_idx],
                                               substore_num_bits / bit_size);
      nir_ssa_def *index = nir_ushr(b, local_offset, nir_imm_int(b, 2));

      /* For anything less than 32 bits we need to use the masked version of
       * the intrinsic to preserve data living in the same 32-bit slot.
       */
      if (substore_num_bits < 32) {
         lower_masked_store_vec32(b, local_offset, index, vec32,
                                  substore_num_bits, op);
      } else {
         lower_store_vec32(b, index, vec32, op);
      }

      comp_idx += substore_num_bits / bit_size;
   }

   nir_instr_remove(&intr->instr);

   return true;
}

static void
ubo_to_temp_patch_deref_mode(nir_deref_instr *deref)
{
   deref->modes = nir_var_shader_temp;
   nir_foreach_use(use_src, &deref->dest.ssa) {
      if (use_src->parent_instr->type != nir_instr_type_deref)
         continue;

      nir_deref_instr *parent = nir_instr_as_deref(use_src->parent_instr);
      ubo_to_temp_patch_deref_mode(parent);
   }
}

static void
ubo_to_temp_update_entry(nir_deref_instr *deref, struct hash_entry *he)
{
   assert(nir_deref_mode_is(deref, nir_var_mem_constant));
   assert(deref->dest.is_ssa);
   assert(he->data);

   nir_foreach_use(use_src, &deref->dest.ssa) {
      if (use_src->parent_instr->type == nir_instr_type_deref) {
         ubo_to_temp_update_entry(nir_instr_as_deref(use_src->parent_instr), he);
      } else if (use_src->parent_instr->type == nir_instr_type_intrinsic) {
         nir_intrinsic_instr *intr = nir_instr_as_intrinsic(use_src->parent_instr);
         if (intr->intrinsic != nir_intrinsic_load_deref)
            he->data = NULL;
      } else {
         he->data = NULL;
      }

      if (!he->data)
         break;
   }
}

bool
dxil_nir_lower_ubo_to_temp(nir_shader *nir)
{
   struct hash_table *ubo_to_temp = _mesa_pointer_hash_table_create(NULL);
   bool progress = false;

   /* First pass: collect all UBO accesses that could be turned into
    * shader temp accesses.
    */
   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_deref)
               continue;

            nir_deref_instr *deref = nir_instr_as_deref(instr);
            if (!nir_deref_mode_is(deref, nir_var_mem_constant) ||
                deref->deref_type != nir_deref_type_var)
               continue;

            struct hash_entry *he =
               _mesa_hash_table_search(ubo_to_temp, deref->var);

            if (!he)
               he = _mesa_hash_table_insert(ubo_to_temp, deref->var, deref->var);

            if (!he->data)
               continue;

            ubo_to_temp_update_entry(deref, he);
         }
      }
   }

   hash_table_foreach(ubo_to_temp, he) {
      nir_variable *var = he->data;

      if (!var)
         continue;

      /* Change the variable mode. */
      var->data.mode = nir_var_shader_temp;

      /* Make sure the variable has a name.
       * DXIL variables must have names.
       */
      if (!var->name)
         var->name = ralloc_asprintf(nir, "global_%d", exec_list_length(&nir->variables));

      progress = true;
   }
   _mesa_hash_table_destroy(ubo_to_temp, NULL);

   /* Second pass: patch all derefs that were accessing the converted UBO
    * variables.
    */
   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_deref)
               continue;

            nir_deref_instr *deref = nir_instr_as_deref(instr);
            if (nir_deref_mode_is(deref, nir_var_mem_constant) &&
                deref->deref_type == nir_deref_type_var &&
                deref->var->data.mode == nir_var_shader_temp)
               ubo_to_temp_patch_deref_mode(deref);
         }
      }
   }

   return progress;
}

static bool
lower_load_ubo(nir_builder *b, nir_intrinsic_instr *intr)
{
   assert(intr->dest.is_ssa);
   assert(intr->src[0].is_ssa);
   assert(intr->src[1].is_ssa);

   b->cursor = nir_before_instr(&intr->instr);

   nir_ssa_def *result =
      build_load_ubo_dxil(b, intr->src[0].ssa, intr->src[1].ssa,
                             nir_dest_num_components(intr->dest),
                             nir_dest_bit_size(intr->dest));

   nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
   nir_instr_remove(&intr->instr);
   return true;
}

bool
dxil_nir_lower_loads_stores_to_dxil(nir_shader *nir)
{
   bool progress = false;

   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_builder b;
      nir_builder_init(&b, func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;
            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

            switch (intr->intrinsic) {
            case nir_intrinsic_load_deref:
               progress |= lower_load_deref(&b, intr);
               break;
            case nir_intrinsic_load_shared:
            case nir_intrinsic_load_scratch:
               progress |= lower_32b_offset_load(&b, intr);
               break;
            case nir_intrinsic_load_ssbo:
               progress |= lower_load_ssbo(&b, intr);
               break;
            case nir_intrinsic_load_ubo:
               progress |= lower_load_ubo(&b, intr);
               break;
            case nir_intrinsic_store_shared:
            case nir_intrinsic_store_scratch:
               progress |= lower_32b_offset_store(&b, intr);
               break;
            case nir_intrinsic_store_ssbo:
               progress |= lower_store_ssbo(&b, intr);
               break;
            default:
               break;
            }
         }
      }
   }

   return progress;
}

static bool
lower_shared_atomic(nir_builder *b, nir_intrinsic_instr *intr,
                    nir_intrinsic_op dxil_op)
{
   b->cursor = nir_before_instr(&intr->instr);

   assert(intr->src[0].is_ssa);
   nir_ssa_def *offset =
      nir_iadd(b, intr->src[0].ssa, nir_imm_int(b, nir_intrinsic_base(intr)));
   nir_ssa_def *index = nir_ushr(b, offset, nir_imm_int(b, 2));

   nir_intrinsic_instr *atomic = nir_intrinsic_instr_create(b->shader, dxil_op);
   atomic->src[0] = nir_src_for_ssa(index);
   assert(intr->src[1].is_ssa);
   atomic->src[1] = nir_src_for_ssa(intr->src[1].ssa);
   if (dxil_op == nir_intrinsic_shared_atomic_comp_swap_dxil) {
      assert(intr->src[2].is_ssa);
      atomic->src[2] = nir_src_for_ssa(intr->src[2].ssa);
   }
   atomic->num_components = 0;
   nir_ssa_dest_init(&atomic->instr, &atomic->dest, 1, 32, NULL);

   nir_builder_instr_insert(b, &atomic->instr);
   nir_ssa_def_rewrite_uses(&intr->dest.ssa, &atomic->dest.ssa);
   nir_instr_remove(&intr->instr);
   return true;
}

bool
dxil_nir_lower_atomics_to_dxil(nir_shader *nir)
{
   bool progress = false;

   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_builder b;
      nir_builder_init(&b, func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;
            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

            switch (intr->intrinsic) {

#define ATOMIC(op)                                                             \
   case nir_intrinsic_shared_atomic_##op:                                      \
      progress |= lower_shared_atomic(&b, intr,                                \
                                      nir_intrinsic_shared_atomic_##op##_dxil); \
      break

            ATOMIC(add);
            ATOMIC(imin);
            ATOMIC(umin);
            ATOMIC(imax);
            ATOMIC(umax);
            ATOMIC(and);
            ATOMIC(or);
            ATOMIC(xor);
            ATOMIC(exchange);
            ATOMIC(comp_swap);

#undef ATOMIC
            default:
               break;
            }
         }
      }
   }

   return progress;
}

static bool
lower_deref_ssbo(nir_builder *b, nir_deref_instr *deref)
{
   assert(nir_deref_mode_is(deref, nir_var_mem_ssbo));
   assert(deref->deref_type == nir_deref_type_var ||
          deref->deref_type == nir_deref_type_cast);
   nir_variable *var = deref->var;

   b->cursor = nir_before_instr(&deref->instr);

   if (deref->deref_type == nir_deref_type_var) {
      /* We turn all deref_var into deref_cast and build a pointer value
       * based on the var binding, which encodes the UAV id.
       */
      nir_ssa_def *ptr = nir_imm_int64(b, (uint64_t)var->data.binding << 32);
      nir_deref_instr *deref_cast =
         nir_build_deref_cast(b, ptr, nir_var_mem_ssbo, deref->type,
                              glsl_get_explicit_stride(var->type));
      nir_ssa_def_rewrite_uses(&deref->dest.ssa,
                               &deref_cast->dest.ssa);
      nir_instr_remove(&deref->instr);
      return true;
   }
   return false;
}

bool
dxil_nir_lower_deref_ssbo(nir_shader *nir)
{
   bool progress = false;

   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_builder b;
      nir_builder_init(&b, func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_deref)
               continue;

            nir_deref_instr *deref = nir_instr_as_deref(instr);

            if (!nir_deref_mode_is(deref, nir_var_mem_ssbo) ||
                (deref->deref_type != nir_deref_type_var &&
                 deref->deref_type != nir_deref_type_cast))
               continue;

            progress |= lower_deref_ssbo(&b, deref);
         }
      }
   }

   return progress;
}

static bool
lower_alu_deref_srcs(nir_builder *b, nir_alu_instr *alu)
{
   const nir_op_info *info = &nir_op_infos[alu->op];
   bool progress = false;

   b->cursor = nir_before_instr(&alu->instr);

   for (unsigned i = 0; i < info->num_inputs; i++) {
      nir_deref_instr *deref = nir_src_as_deref(alu->src[i].src);

      if (!deref)
         continue;

      nir_deref_path path;
      nir_deref_path_init(&path, deref, NULL);
      nir_deref_instr *root_deref = path.path[0];
      nir_deref_path_finish(&path);

      if (root_deref->deref_type != nir_deref_type_cast)
         continue;

      nir_ssa_def *ptr =
         nir_iadd(b, root_deref->parent.ssa,
                     nir_build_deref_offset(b, deref, cl_type_size_align));
      nir_instr_rewrite_src(&alu->instr, &alu->src[i].src, nir_src_for_ssa(ptr));
      progress = true;
   }

   return progress;
}

bool
dxil_nir_opt_alu_deref_srcs(nir_shader *nir)
{
   bool progress = false;

   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_builder b;
      nir_builder_init(&b, func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_alu)
               continue;

            nir_alu_instr *alu = nir_instr_as_alu(instr);
            progress |= lower_alu_deref_srcs(&b, alu);
         }
      }
   }

   return progress;
}

static nir_ssa_def *
memcpy_load_deref_elem(nir_builder *b, nir_deref_instr *parent,
                       nir_ssa_def *index)
{
   nir_deref_instr *deref;

   index = nir_i2i(b, index, nir_dest_bit_size(parent->dest));
   assert(parent->deref_type == nir_deref_type_cast);
   deref = nir_build_deref_ptr_as_array(b, parent, index);

   return nir_load_deref(b, deref);
}

static void
memcpy_store_deref_elem(nir_builder *b, nir_deref_instr *parent,
                        nir_ssa_def *index, nir_ssa_def *value)
{
   nir_deref_instr *deref;

   index = nir_i2i(b, index, nir_dest_bit_size(parent->dest));
   assert(parent->deref_type == nir_deref_type_cast);
   deref = nir_build_deref_ptr_as_array(b, parent, index);
   nir_store_deref(b, deref, value, 1);
}

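/* The lowering below emits, roughly (illustrative only):
 *
 *    uint loop_index = 0;
 *    while (true) {
 *       if (loop_index >= num_bytes)
 *          break;
 *       dst[loop_index] = src[loop_index];   // byte-sized deref copy
 *       loop_index++;
 *    }
 */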
static bool
lower_memcpy_deref(nir_builder *b, nir_intrinsic_instr *intr)
{
   nir_deref_instr *dst_deref = nir_src_as_deref(intr->src[0]);
   nir_deref_instr *src_deref = nir_src_as_deref(intr->src[1]);
   assert(intr->src[2].is_ssa);
   nir_ssa_def *num_bytes = intr->src[2].ssa;

   assert(dst_deref && src_deref);

   b->cursor = nir_after_instr(&intr->instr);

   dst_deref = nir_build_deref_cast(b, &dst_deref->dest.ssa, dst_deref->modes,
                                       glsl_uint8_t_type(), 1);
   src_deref = nir_build_deref_cast(b, &src_deref->dest.ssa, src_deref->modes,
                                       glsl_uint8_t_type(), 1);

   /*
    * We want to avoid 64b instructions, so let's assume we'll always be
    * passed a value that fits in a 32b type and truncate the 64b value.
    */
   num_bytes = nir_u2u32(b, num_bytes);

   nir_variable *loop_index_var =
     nir_local_variable_create(b->impl, glsl_uint_type(), "loop_index");
   nir_deref_instr *loop_index_deref = nir_build_deref_var(b, loop_index_var);
   nir_store_deref(b, loop_index_deref, nir_imm_int(b, 0), 1);

   nir_loop *loop = nir_push_loop(b);
   nir_ssa_def *loop_index = nir_load_deref(b, loop_index_deref);
   nir_ssa_def *cmp = nir_ige(b, loop_index, num_bytes);
   nir_if *loop_check = nir_push_if(b, cmp);
   nir_jump(b, nir_jump_break);
   nir_pop_if(b, loop_check);
   nir_ssa_def *val = memcpy_load_deref_elem(b, src_deref, loop_index);
   memcpy_store_deref_elem(b, dst_deref, loop_index, val);
   nir_store_deref(b, loop_index_deref, nir_iadd_imm(b, loop_index, 1), 1);
   nir_pop_loop(b, loop);
   nir_instr_remove(&intr->instr);
   return true;
}

bool
dxil_nir_lower_memcpy_deref(nir_shader *nir)
{
   bool progress = false;

   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_builder b;
      nir_builder_init(&b, func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

            if (intr->intrinsic == nir_intrinsic_memcpy_deref)
               progress |= lower_memcpy_deref(&b, intr);
         }
      }
   }

   return progress;
}

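/* Worked example (illustrative only): upcasting a 16-bit phi to 32 bits
 * turns
 *
 *    x = phi(a, b)                   // 16-bit
 *
 * into
 *
 *    x32 = phi(u2u32(a), u2u32(b))   // 32-bit, casts inserted per source
 *    x   = u2u16(x32)                // placed after the other phis
 */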
static void
cast_phi(nir_builder *b, nir_phi_instr *phi, unsigned new_bit_size)
{
   nir_phi_instr *lowered = nir_phi_instr_create(b->shader);
   int num_components = 0;
   int old_bit_size = phi->dest.ssa.bit_size;

   nir_op upcast_op = nir_type_conversion_op(nir_type_uint | old_bit_size,
                                             nir_type_uint | new_bit_size,
                                             nir_rounding_mode_undef);
   nir_op downcast_op = nir_type_conversion_op(nir_type_uint | new_bit_size,
                                               nir_type_uint | old_bit_size,
                                               nir_rounding_mode_undef);

   nir_foreach_phi_src(src, phi) {
      assert(num_components == 0 || num_components == src->src.ssa->num_components);
      num_components = src->src.ssa->num_components;

      b->cursor = nir_after_instr_and_phis(src->src.ssa->parent_instr);

      nir_ssa_def *cast = nir_build_alu(b, upcast_op, src->src.ssa, NULL, NULL, NULL);
      nir_phi_instr_add_src(lowered, src->pred, nir_src_for_ssa(cast));
   }

   nir_ssa_dest_init(&lowered->instr, &lowered->dest,
                     num_components, new_bit_size, NULL);

   b->cursor = nir_before_instr(&phi->instr);
   nir_builder_instr_insert(b, &lowered->instr);

   b->cursor = nir_after_phis(nir_cursor_current_block(b->cursor));
   nir_ssa_def *result = nir_build_alu(b, downcast_op, &lowered->dest.ssa, NULL, NULL, NULL);

   nir_ssa_def_rewrite_uses(&phi->dest.ssa, result);
   nir_instr_remove(&phi->instr);
}

static bool
upcast_phi_impl(nir_function_impl *impl, unsigned min_bit_size)
{
   nir_builder b;
   nir_builder_init(&b, impl);
   bool progress = false;

   nir_foreach_block_reverse(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_phi)
            continue;

         nir_phi_instr *phi = nir_instr_as_phi(instr);
         assert(phi->dest.is_ssa);

         if (phi->dest.ssa.bit_size == 1 ||
             phi->dest.ssa.bit_size >= min_bit_size)
            continue;

         cast_phi(&b, phi, min_bit_size);
         progress = true;
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return progress;
}

bool
dxil_nir_lower_upcast_phis(nir_shader *shader, unsigned min_bit_size)
{
   bool progress = false;

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= upcast_phi_impl(function->impl, min_bit_size);
   }

   return progress;
}

struct dxil_nir_split_clip_cull_distance_params {
   nir_variable *new_var;
   nir_shader *shader;
};

/* In GLSL and SPIR-V, clip and cull distance are arrays of floats (with a limit of 8).
 * In DXIL, clip and cull distances are up to 2 float4s combined.
 * Coming from GLSL, we can request this 2 float4 format, but coming from SPIR-V,
 * we can't, and have to accept a "compact" array of scalar floats.
 *
 * To help emit a valid input signature for this case, split the variables so that
 * they match what we need to put in the signature
 * (e.g. { float clip[4]; float clip1; float cull[3]; })
 */
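/* Worked example (illustrative only): a compact float[5] clip-distance
 * array at VARYING_SLOT_CLIP_DIST0 with location_frac == 0 is split into
 *
 *    float clip[4];    // stays at VARYING_SLOT_CLIP_DIST0
 *    float clip1[1];   // new variable at VARYING_SLOT_CLIP_DIST1
 *
 * and array accesses with index >= 4 are redirected to the new variable.
 */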
static bool
dxil_nir_split_clip_cull_distance_instr(nir_builder *b,
                                        nir_instr *instr,
                                        void *cb_data)
{
   struct dxil_nir_split_clip_cull_distance_params *params = cb_data;
   nir_variable *new_var = params->new_var;

   if (instr->type != nir_instr_type_deref)
      return false;

   nir_deref_instr *deref = nir_instr_as_deref(instr);
   nir_variable *var = nir_deref_instr_get_variable(deref);
   if (!var ||
       var->data.location < VARYING_SLOT_CLIP_DIST0 ||
       var->data.location > VARYING_SLOT_CULL_DIST1 ||
       !var->data.compact)
      return false;

   /* The location should only be inside clip distance, because clip
    * and cull should've been merged by nir_lower_clip_cull_distance_arrays()
    */
   assert(var->data.location == VARYING_SLOT_CLIP_DIST0 ||
          var->data.location == VARYING_SLOT_CLIP_DIST1);

   /* The deref chain to the clip/cull variables should be simple, just the
    * var and an array with a constant index, otherwise more lowering/optimization
    * might be needed before this pass, e.g. copy prop, lower_io_to_temporaries,
    * split_var_copies, and/or lower_var_copies
    */
   assert(deref->deref_type == nir_deref_type_var ||
          deref->deref_type == nir_deref_type_array);

   b->cursor = nir_before_instr(instr);
   if (!new_var) {
      /* Update lengths for new and old vars */
      int old_length = glsl_array_size(var->type);
      int new_length = (old_length + var->data.location_frac) - 4;
      old_length -= new_length;

      /* The existing variable fits in the float4 */
      if (new_length <= 0)
         return false;

      new_var = nir_variable_clone(var, params->shader);
      nir_shader_add_variable(params->shader, new_var);
      assert(glsl_get_base_type(glsl_get_array_element(var->type)) == GLSL_TYPE_FLOAT);
      var->type = glsl_array_type(glsl_float_type(), old_length, 0);
      new_var->type = glsl_array_type(glsl_float_type(), new_length, 0);
      new_var->data.location++;
      new_var->data.location_frac = 0;
      params->new_var = new_var;
   }

   /* Update the type for derefs of the old var */
   if (deref->deref_type == nir_deref_type_var) {
      deref->type = var->type;
      return false;
   }

   nir_const_value *index = nir_src_as_const_value(deref->arr.index);
   assert(index);

   /* Treat this array as a vector starting at the component index in location_frac,
    * so if location_frac is 1 and index is 0, then it's accessing the 'y' component
    * of the vector. If index + location_frac is >= 4, there's no component there,
    * so we need to add a new variable and adjust the index.
    */
   unsigned total_index = index->u32 + var->data.location_frac;
   if (total_index < 4)
      return false;

   nir_deref_instr *new_var_deref = nir_build_deref_var(b, new_var);
   nir_deref_instr *new_array_deref = nir_build_deref_array(b, new_var_deref, nir_imm_int(b, total_index % 4));
   nir_ssa_def_rewrite_uses(&deref->dest.ssa, &new_array_deref->dest.ssa);
   return true;
}

bool
dxil_nir_split_clip_cull_distance(nir_shader *shader)
{
   struct dxil_nir_split_clip_cull_distance_params params = {
      .new_var = NULL,
      .shader = shader,
   };
   nir_shader_instructions_pass(shader,
                                dxil_nir_split_clip_cull_distance_instr,
                                nir_metadata_block_index |
                                nir_metadata_dominance |
                                nir_metadata_loop_analysis,
                                &params);
   return params.new_var != NULL;
}

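/* Worked example (illustrative only): for a 64-bit float ALU source, the
 * lowering below emits
 *
 *    src' = pack_double_2x32_dxil(unpack_64_2x32(src))
 *
 * and for a 64-bit float destination
 *
 *    dest' = pack_64_2x32(unpack_double_2x32_dxil(dest))
 *
 * making the double <-> 2x32 conversions explicit for the backend.
 */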
static bool
dxil_nir_lower_double_math_instr(nir_builder *b,
                                 nir_instr *instr,
                                 UNUSED void *cb_data)
{
   if (instr->type != nir_instr_type_alu)
      return false;

   nir_alu_instr *alu = nir_instr_as_alu(instr);

   /* TODO: See if we can apply this explicitly to packs/unpacks that are then
    * used as a double. As-is, if we had an app explicitly do a 64bit integer op,
    * then try to bitcast to double (not expressible in HLSL, but it is in other
    * source languages), this would unpack the integer and repack as a double, when
    * we probably want to just send the bitcast through to the backend.
    */

   b->cursor = nir_before_instr(&alu->instr);

   bool progress = false;
   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; ++i) {
      if (nir_alu_type_get_base_type(nir_op_infos[alu->op].input_types[i]) == nir_type_float &&
          alu->src[i].src.ssa->bit_size == 64) {
         nir_ssa_def *packed_double = nir_channel(b, alu->src[i].src.ssa, alu->src[i].swizzle[0]);
         nir_ssa_def *unpacked_double = nir_unpack_64_2x32(b, packed_double);
         nir_ssa_def *repacked_double = nir_pack_double_2x32_dxil(b, unpacked_double);
         nir_instr_rewrite_src_ssa(instr, &alu->src[i].src, repacked_double);
         memset(alu->src[i].swizzle, 0, ARRAY_SIZE(alu->src[i].swizzle));
         progress = true;
      }
   }

   if (nir_alu_type_get_base_type(nir_op_infos[alu->op].output_type) == nir_type_float &&
       alu->dest.dest.ssa.bit_size == 64) {
      b->cursor = nir_after_instr(&alu->instr);
      nir_ssa_def *packed_double = &alu->dest.dest.ssa;
      nir_ssa_def *unpacked_double = nir_unpack_double_2x32_dxil(b, packed_double);
      nir_ssa_def *repacked_double = nir_pack_64_2x32(b, unpacked_double);
      nir_ssa_def_rewrite_uses_after(packed_double, repacked_double, unpacked_double->parent_instr);
      progress = true;
   }

   return progress;
}

bool
dxil_nir_lower_double_math(nir_shader *shader)
{
   return nir_shader_instructions_pass(shader,
                                       dxil_nir_lower_double_math_instr,
                                       nir_metadata_block_index |
                                       nir_metadata_dominance |
                                       nir_metadata_loop_analysis,
                                       NULL);
}

typedef struct {
   gl_system_value *values;
   uint32_t count;
} zero_system_values_state;

static bool
lower_system_value_to_zero_filter(const nir_instr* instr, const void* cb_state)
{
   if (instr->type != nir_instr_type_intrinsic) {
      return false;
   }

   nir_intrinsic_instr* intrin = nir_instr_as_intrinsic(instr);

   /* All the intrinsics we care about are loads */
   if (!nir_intrinsic_infos[intrin->intrinsic].has_dest)
      return false;

   assert(intrin->dest.is_ssa);

   zero_system_values_state* state = (zero_system_values_state*)cb_state;
   for (uint32_t i = 0; i < state->count; ++i) {
      gl_system_value value = state->values[i];
      nir_intrinsic_op value_op = nir_intrinsic_from_system_value(value);

      if (intrin->intrinsic == value_op) {
         return true;
      } else if (intrin->intrinsic == nir_intrinsic_load_deref) {
         nir_deref_instr* deref = nir_src_as_deref(intrin->src[0]);
         if (!nir_deref_mode_is(deref, nir_var_system_value))
            return false;

         nir_variable* var = deref->var;
         if (var->data.location == value) {
            return true;
         }
      }
   }

   return false;
}

static nir_ssa_def*
lower_system_value_to_zero_instr(nir_builder* b, nir_instr* instr, void* _state)
{
   return nir_imm_int(b, 0);
}

bool
dxil_nir_lower_system_values_to_zero(nir_shader* shader,
                                     gl_system_value* system_values,
                                     uint32_t count)
{
   zero_system_values_state state = { system_values, count };
   return nir_shader_lower_instructions(shader,
      lower_system_value_to_zero_filter,
      lower_system_value_to_zero_instr,
      &state);
}

static const struct glsl_type *
get_bare_samplers_for_type(const struct glsl_type *type)
{
   if (glsl_type_is_sampler(type)) {
      if (glsl_sampler_type_is_shadow(type))
         return glsl_bare_shadow_sampler_type();
      else
         return glsl_bare_sampler_type();
   } else if (glsl_type_is_array(type)) {
      return glsl_array_type(
         get_bare_samplers_for_type(glsl_get_array_element(type)),
         glsl_get_length(type),
         0 /* explicit size */);
   }
   assert(!"Unexpected type");
   return NULL;
}

static bool
redirect_sampler_derefs(struct nir_builder *b, nir_instr *instr, void *data)
{
   if (instr->type != nir_instr_type_tex)
      return false;

   nir_tex_instr *tex = nir_instr_as_tex(instr);
   if (!nir_tex_instr_need_sampler(tex))
      return false;

   int sampler_idx = nir_tex_instr_src_index(tex, nir_tex_src_sampler_deref);
   if (sampler_idx == -1) {
      /* No derefs, must be using indices */
      nir_variable *bare_sampler = _mesa_hash_table_u64_search(data, tex->sampler_index);

      /* Already have a bare sampler here */
      if (bare_sampler)
         return false;

      nir_variable *typed_sampler = NULL;
      nir_foreach_variable_with_modes(var, b->shader, nir_var_uniform) {
         if (var->data.binding <= tex->sampler_index &&
             var->data.binding + glsl_type_get_sampler_count(var->type) > tex->sampler_index) {
            /* Already have a bare sampler for this binding, add it to the table */
            if (glsl_get_sampler_result_type(glsl_without_array(var->type)) == GLSL_TYPE_VOID) {
               _mesa_hash_table_u64_insert(data, tex->sampler_index, var);
               return false;
            }

            typed_sampler = var;
         }
      }

      /* Clone the typed sampler to a bare sampler and we're done */
      assert(typed_sampler);
      bare_sampler = nir_variable_clone(typed_sampler, b->shader);
      bare_sampler->type = get_bare_samplers_for_type(typed_sampler->type);
      nir_shader_add_variable(b->shader, bare_sampler);
      _mesa_hash_table_u64_insert(data, tex->sampler_index, bare_sampler);
      return true;
   }

   /* Using derefs means we have to rewrite the deref chain in addition to cloning */
   nir_deref_instr *final_deref = nir_src_as_deref(tex->src[sampler_idx].src);
   nir_deref_path path;
   nir_deref_path_init(&path, final_deref, NULL);

   nir_deref_instr *old_tail = path.path[0];
   assert(old_tail->deref_type == nir_deref_type_var);
   nir_variable *old_var = old_tail->var;
   if (glsl_get_sampler_result_type(glsl_without_array(old_var->type)) == GLSL_TYPE_VOID) {
      nir_deref_path_finish(&path);
      return false;
   }

   nir_variable *new_var = _mesa_hash_table_u64_search(data, old_var->data.binding);
   if (!new_var) {
      new_var = nir_variable_clone(old_var, b->shader);
      new_var->type = get_bare_samplers_for_type(old_var->type);
      nir_shader_add_variable(b->shader, new_var);
      _mesa_hash_table_u64_insert(data, old_var->data.binding, new_var);
   }

   b->cursor = nir_after_instr(&old_tail->instr);
   nir_deref_instr *new_tail = nir_build_deref_var(b, new_var);

   for (unsigned i = 1; path.path[i]; ++i) {
      b->cursor = nir_after_instr(&path.path[i]->instr);
      new_tail = nir_build_deref_follower(b, new_tail, path.path[i]);
   }

   nir_deref_path_finish(&path);
   nir_instr_rewrite_src_ssa(&tex->instr, &tex->src[sampler_idx].src, &new_tail->dest.ssa);

   return true;
}

bool
dxil_nir_create_bare_samplers(nir_shader *nir)
{
   struct hash_table_u64 *sampler_to_bare = _mesa_hash_table_u64_create(NULL);

   bool progress = nir_shader_instructions_pass(nir, redirect_sampler_derefs,
      nir_metadata_block_index | nir_metadata_dominance | nir_metadata_loop_analysis, sampler_to_bare);

   _mesa_hash_table_u64_destroy(sampler_to_bare);
   return progress;
}

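/* Worked example (illustrative only): a boolean shader input load
 *
 *    b = load_deref(bool_var)   // 1-bit
 *
 * is rewritten so the variable type becomes uint, the load produces 32
 * bits, and the boolean is recovered with
 *
 *    b = i2b1(load_deref(uint_var))
 */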
static bool
lower_bool_input_filter(const nir_instr *instr,
                        UNUSED const void *_options)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
   if (intr->intrinsic == nir_intrinsic_load_front_face)
      return true;

   if (intr->intrinsic == nir_intrinsic_load_deref) {
      nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
      nir_variable *var = nir_deref_instr_get_variable(deref);
      return var->data.mode == nir_var_shader_in &&
             glsl_get_base_type(var->type) == GLSL_TYPE_BOOL;
   }

   return false;
}

static nir_ssa_def *
lower_bool_input_impl(nir_builder *b, nir_instr *instr,
                      UNUSED void *_options)
{
   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   if (intr->intrinsic == nir_intrinsic_load_deref) {
      nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
      nir_variable *var = nir_deref_instr_get_variable(deref);

      /* rewrite var->type */
      var->type = glsl_vector_type(GLSL_TYPE_UINT,
                                   glsl_get_vector_elements(var->type));
      deref->type = var->type;
   }

   intr->dest.ssa.bit_size = 32;
   return nir_i2b1(b, &intr->dest.ssa);
}

bool
dxil_nir_lower_bool_input(struct nir_shader *s)
{
   return nir_shader_lower_instructions(s, lower_bool_input_filter,
                                        lower_bool_input_impl, NULL);
}

/* Comparison function to sort IO values so that normal varyings come first,
 * then system values, and then system-generated values.
 */
static int
variable_location_cmp(const nir_variable* a, const nir_variable* b)
{
   // Sort by driver_location, location, then index
   return a->data.driver_location != b->data.driver_location ?
            a->data.driver_location - b->data.driver_location :
            a->data.location != b->data.location ?
               a->data.location - b->data.location :
               a->data.index - b->data.index;
}

/* Order varyings according to driver location */
uint64_t
dxil_sort_by_driver_location(nir_shader* s, nir_variable_mode modes)
{
   nir_sort_variables_with_modes(s, variable_location_cmp, modes);

   uint64_t result = 0;
   nir_foreach_variable_with_modes(var, s, modes) {
      result |= 1ull << var->data.location;
   }
   return result;
}

/* Sort PS outputs so that color outputs come first */
void
dxil_sort_ps_outputs(nir_shader* s)
{
   nir_foreach_variable_with_modes_safe(var, s, nir_var_shader_out) {
      /* We use driver_location here to avoid introducing a new struct or
       * member variable. The true, updated driver location will be written
       * below, after sorting.
       */
      switch (var->data.location) {
      case FRAG_RESULT_DEPTH:
         var->data.driver_location = 1;
         break;
      case FRAG_RESULT_STENCIL:
         var->data.driver_location = 2;
         break;
      case FRAG_RESULT_SAMPLE_MASK:
         var->data.driver_location = 3;
         break;
      default:
         var->data.driver_location = 0;
      }
   }

   nir_sort_variables_with_modes(s, variable_location_cmp,
                                 nir_var_shader_out);

   unsigned driver_loc = 0;
   nir_foreach_variable_with_modes(var, s, nir_var_shader_out) {
      var->data.driver_location = driver_loc++;
   }
}

/* Order inter-stage values so that normal varyings come first, then system
 * values, and then system-generated values.
 */
uint64_t
dxil_reassign_driver_locations(nir_shader* s, nir_variable_mode modes,
   uint64_t other_stage_mask)
{
   nir_foreach_variable_with_modes_safe(var, s, modes) {
      /* We use driver_location here to avoid introducing a new struct or
       * member variable. The true, updated driver location will be written
       * below, after sorting.
       */
      var->data.driver_location = nir_var_to_dxil_sysvalue_type(var, other_stage_mask);
   }

   nir_sort_variables_with_modes(s, variable_location_cmp, modes);

   uint64_t result = 0;
   unsigned driver_loc = 0;
   nir_foreach_variable_with_modes(var, s, modes) {
      result |= 1ull << var->data.location;
      var->data.driver_location = driver_loc++;
   }
   return result;
}