/*
 * Copyright © 2019 Google, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "compiler/nir/nir_builder.h"
#include "ir3_compiler.h"
#include "ir3_nir.h"

struct state {
   uint32_t topology;

   struct primitive_map {
      unsigned loc[32 + 4]; /* +POSITION +PSIZE +CLIP_DIST0 +CLIP_DIST1 */
      unsigned stride;
   } map;

   nir_ssa_def *header;

   nir_variable *vertex_count_var;
   nir_variable *emitted_vertex_var;
   nir_variable *vertex_flags_out;

   struct exec_list old_outputs;
   struct exec_list new_outputs;
   struct exec_list emit_outputs;

   /* tess ctrl shader on a650 gets the local primitive id at different bits: */
   unsigned local_primitive_id_start;
};

static nir_ssa_def *
bitfield_extract(nir_builder *b, nir_ssa_def *v, uint32_t start, uint32_t mask)
{
   return nir_iand(b, nir_ushr(b, v, nir_imm_int(b, start)),
                   nir_imm_int(b, mask));
}

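/* The tcs/gs "header" sysval packs several small IDs into one value. The bit
 * ranges below are just a summary of the bitfield_extract() calls in this
 * file (a reading aid, not an authoritative hardware description):
 *
 *   bits  6..10  vertex id                 (build_vertex_id)
 *   bits 11..15  invocation id             (build_invocation_id)
 *   bits  0..5   local primitive id, or bits 16..21 when
 *                local_primitive_id_start is 16 (build_local_primitive_id)
 *   bits 16..25  local thread id in the GS header (local_thread_id)
 */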
static nir_ssa_def *
build_invocation_id(nir_builder *b, struct state *state)
{
   return bitfield_extract(b, state->header, 11, 31);
}

static nir_ssa_def *
build_vertex_id(nir_builder *b, struct state *state)
{
   return bitfield_extract(b, state->header, 6, 31);
}

static nir_ssa_def *
build_local_primitive_id(nir_builder *b, struct state *state)
{
   return bitfield_extract(b, state->header, state->local_primitive_id_start,
                           63);
}

static bool
is_tess_levels(gl_varying_slot slot)
{
   return (slot == VARYING_SLOT_PRIMITIVE_ID ||
           slot == VARYING_SLOT_TESS_LEVEL_OUTER ||
           slot == VARYING_SLOT_TESS_LEVEL_INNER);
}

/* Return a deterministic index for varyings. We can't rely on driver_location
 * to be correct without linking the different stages first, so we create
 * "primitive maps" where the producer decides on the location of each varying
 * slot and then exports a per-slot array to the consumer. This compacts the
 * gl_varying_slot space down a bit so that the primitive maps aren't too
 * large.
 *
 * Note: per-patch varyings are currently handled separately, without any
 * compacting.
 *
 * TODO: We could probably use the driver_locations directly in the non-SSO
 * (Vulkan) case.
 */

static unsigned
shader_io_get_unique_index(gl_varying_slot slot)
{
   if (slot == VARYING_SLOT_POS)
      return 0;
   if (slot == VARYING_SLOT_PSIZ)
      return 1;
   if (slot == VARYING_SLOT_CLIP_DIST0)
      return 2;
   if (slot == VARYING_SLOT_CLIP_DIST1)
      return 3;
   if (slot >= VARYING_SLOT_VAR0 && slot <= VARYING_SLOT_VAR31)
      return 4 + (slot - VARYING_SLOT_VAR0);
   unreachable("illegal slot in get unique index\n");
}

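/* Compute a byte offset into the local (shared) storage used for stage
 * linking. Roughly (a summary of the expression built below, under the same
 * assumptions the code makes about strides):
 *
 *   local_primitive_id * primitive_stride +
 *   vertex * vertex_stride +
 *   attr_offset +              (per-slot location plus 4 * component)
 *   (offset << 4)              (indirect offset, in vec4 units)
 */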
static nir_ssa_def *
build_local_offset(nir_builder *b, struct state *state, nir_ssa_def *vertex,
                   uint32_t location, uint32_t comp, nir_ssa_def *offset)
{
   nir_ssa_def *primitive_stride = nir_load_vs_primitive_stride_ir3(b);
   nir_ssa_def *primitive_offset =
      nir_imul24(b, build_local_primitive_id(b, state), primitive_stride);
   nir_ssa_def *attr_offset;
   nir_ssa_def *vertex_stride;
   unsigned index = shader_io_get_unique_index(location);

   switch (b->shader->info.stage) {
   case MESA_SHADER_VERTEX:
   case MESA_SHADER_TESS_EVAL:
      vertex_stride = nir_imm_int(b, state->map.stride * 4);
      attr_offset = nir_imm_int(b, state->map.loc[index] + 4 * comp);
      break;
   case MESA_SHADER_TESS_CTRL:
   case MESA_SHADER_GEOMETRY:
      vertex_stride = nir_load_vs_vertex_stride_ir3(b);
      attr_offset = nir_iadd(b, nir_load_primitive_location_ir3(b, index),
                             nir_imm_int(b, comp * 4));
      break;
   default:
      unreachable("bad shader stage");
   }

   nir_ssa_def *vertex_offset = nir_imul24(b, vertex, vertex_stride);

   return nir_iadd(
      b, nir_iadd(b, primitive_offset, vertex_offset),
      nir_iadd(b, attr_offset, nir_ishl(b, offset, nir_imm_int(b, 4))));
}

static nir_intrinsic_instr *
replace_intrinsic(nir_builder *b, nir_intrinsic_instr *intr,
                  nir_intrinsic_op op, nir_ssa_def *src0, nir_ssa_def *src1,
                  nir_ssa_def *src2)
{
   nir_intrinsic_instr *new_intr = nir_intrinsic_instr_create(b->shader, op);

   new_intr->src[0] = nir_src_for_ssa(src0);
   if (src1)
      new_intr->src[1] = nir_src_for_ssa(src1);
   if (src2)
      new_intr->src[2] = nir_src_for_ssa(src2);

   new_intr->num_components = intr->num_components;

   if (nir_intrinsic_infos[op].has_dest)
      nir_ssa_dest_init(&new_intr->instr, &new_intr->dest, intr->num_components,
                        32, NULL);

   nir_builder_instr_insert(b, &new_intr->instr);

   if (nir_intrinsic_infos[op].has_dest)
      nir_ssa_def_rewrite_uses(&intr->dest.ssa, &new_intr->dest.ssa);

   nir_instr_remove(&intr->instr);

   return new_intr;
}

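/* Assign a location to every written output slot, packed in unique-index
 * order. As an illustrative example (not from a real shader): a VS writing
 * POS, PSIZ and VAR0 would get loc[0] = 0, loc[1] = 16, loc[4] = 32 (bytes)
 * and stride = 48 / 4 = 12 dwords.
 */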
static void
build_primitive_map(nir_shader *shader, struct primitive_map *map)
{
   /* All interfaces except the TCS <-> TES interface use ldlw, which takes
    * an offset in bytes, so each vec4 slot is 16 bytes. TCS <-> TES uses
    * ldg, which takes an offset in dwords, but each per-vertex slot has
    * space for every vertex, and there's space at the beginning for
    * per-patch varyings.
    */
   unsigned slot_size = 16, start = 0;
   if (shader->info.stage == MESA_SHADER_TESS_CTRL) {
      slot_size = shader->info.tess.tcs_vertices_out * 4;
      start = util_last_bit(shader->info.patch_outputs_written) * 4;
   }

   uint64_t mask = shader->info.outputs_written;
   unsigned loc = start;
   while (mask) {
      int location = u_bit_scan64(&mask);
      if (is_tess_levels(location))
         continue;

      unsigned index = shader_io_get_unique_index(location);
      map->loc[index] = loc;
      loc += slot_size;
   }

   map->stride = loc;
   /* Use units of dwords for the stride. */
   if (shader->info.stage != MESA_SHADER_TESS_CTRL)
      map->stride /= 4;
}

/* For shader stages that receive a primitive map, calculate how big it should
 * be.
 */

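/* For example, a stage whose only inputs are VARYING_SLOT_POS (unique index
 * 0) and VARYING_SLOT_VAR3 (unique index 7) would need a map size of 8.
 */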
static unsigned
calc_primitive_map_size(nir_shader *shader)
{
   uint64_t mask = shader->info.inputs_read;
   unsigned max_index = 0;
   while (mask) {
      int location = u_bit_scan64(&mask);

      if (is_tess_levels(location))
         continue;

      unsigned index = shader_io_get_unique_index(location);
      max_index = MAX2(max_index, index + 1);
   }

   return max_index;
}

static void
lower_block_to_explicit_output(nir_block *block, nir_builder *b,
                               struct state *state)
{
   nir_foreach_instr_safe (instr, block) {
      if (instr->type != nir_instr_type_intrinsic)
         continue;

      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

      switch (intr->intrinsic) {
      case nir_intrinsic_store_output: {
         // src[] = { value, offset }.

         /* nir_lower_io_to_temporaries replaces all access to output
          * variables with temp variables and then emits a nir_copy_var at
          * the end of the shader.  Thus, we should always get a full wrmask
          * here.
          */
         assert(
            util_is_power_of_two_nonzero(nir_intrinsic_write_mask(intr) + 1));

         b->cursor = nir_instr_remove(&intr->instr);

         nir_ssa_def *vertex_id = build_vertex_id(b, state);
         nir_ssa_def *offset = build_local_offset(
            b, state, vertex_id, nir_intrinsic_io_semantics(intr).location,
            nir_intrinsic_component(intr), intr->src[1].ssa);

         nir_store_shared_ir3(b, intr->src[0].ssa, offset);
         break;
      }

      default:
         break;
      }
   }
}

static nir_ssa_def *
local_thread_id(nir_builder *b)
{
   return bitfield_extract(b, nir_load_gs_header_ir3(b), 16, 1023);
}

void
ir3_nir_lower_to_explicit_output(nir_shader *shader,
                                 struct ir3_shader_variant *v,
                                 unsigned topology)
{
   struct state state = {};

   build_primitive_map(shader, &state.map);
   memcpy(v->output_loc, state.map.loc, sizeof(v->output_loc));

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

   nir_builder b;
   nir_builder_init(&b, impl);
   b.cursor = nir_before_cf_list(&impl->body);

   if (v->type == MESA_SHADER_VERTEX && topology != IR3_TESS_NONE)
      state.header = nir_load_tcs_header_ir3(&b);
   else
      state.header = nir_load_gs_header_ir3(&b);

   nir_foreach_block_safe (block, impl)
      lower_block_to_explicit_output(block, &b, &state);

   nir_metadata_preserve(impl,
                         nir_metadata_block_index | nir_metadata_dominance);

   v->output_size = state.map.stride;
}

static void
lower_block_to_explicit_input(nir_block *block, nir_builder *b,
                              struct state *state)
{
   nir_foreach_instr_safe (instr, block) {
      if (instr->type != nir_instr_type_intrinsic)
         continue;

      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

      switch (intr->intrinsic) {
      case nir_intrinsic_load_per_vertex_input: {
         // src[] = { vertex, offset }.

         b->cursor = nir_before_instr(&intr->instr);

         nir_ssa_def *offset = build_local_offset(
            b, state,
            intr->src[0].ssa, // this is typically gl_InvocationID
            nir_intrinsic_io_semantics(intr).location,
            nir_intrinsic_component(intr), intr->src[1].ssa);

         replace_intrinsic(b, intr, nir_intrinsic_load_shared_ir3, offset, NULL,
                           NULL);
         break;
      }

      case nir_intrinsic_load_invocation_id: {
         b->cursor = nir_before_instr(&intr->instr);

         nir_ssa_def *iid = build_invocation_id(b, state);
         nir_ssa_def_rewrite_uses(&intr->dest.ssa, iid);
         nir_instr_remove(&intr->instr);
         break;
      }

      default:
         break;
      }
   }
}

void
ir3_nir_lower_to_explicit_input(nir_shader *shader,
                                struct ir3_shader_variant *v)
{
   struct state state = {};

   /* when using stl/ldl (instead of stlw/ldlw) for linking VS and HS,
    * HS uses a different primitive id, which starts at bit 16 in the header
    */
   if (shader->info.stage == MESA_SHADER_TESS_CTRL &&
       v->shader->compiler->tess_use_shared)
      state.local_primitive_id_start = 16;

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

   nir_builder b;
   nir_builder_init(&b, impl);
   b.cursor = nir_before_cf_list(&impl->body);

   if (shader->info.stage == MESA_SHADER_GEOMETRY)
      state.header = nir_load_gs_header_ir3(&b);
   else
      state.header = nir_load_tcs_header_ir3(&b);

   nir_foreach_block_safe (block, impl)
      lower_block_to_explicit_input(block, &b, &state);

   v->input_size = calc_primitive_map_size(shader);
}

static nir_ssa_def *
build_tcs_out_vertices(nir_builder *b)
{
   if (b->shader->info.stage == MESA_SHADER_TESS_CTRL)
      return nir_imm_int(b, b->shader->info.tess.tcs_vertices_out);
   else
      return nir_load_patch_vertices_in(b);
}

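/* Compute a dword offset into the tess param BO for a per-vertex slot (or a
 * per-patch slot when vertex == NULL). Roughly, the expression built below is
 *
 *   rel_patch_id * patch_stride + attr_offset + vertex * 4
 *
 * where attr_offset folds in the slot location, the component and any
 * indirect (per-vec4) offset scaled by the number of TCS output vertices.
 */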
static nir_ssa_def *
build_per_vertex_offset(nir_builder *b, struct state *state,
                        nir_ssa_def *vertex, uint32_t location, uint32_t comp,
                        nir_ssa_def *offset)
{
   nir_ssa_def *patch_id = nir_load_rel_patch_id_ir3(b);
   nir_ssa_def *patch_stride = nir_load_hs_patch_stride_ir3(b);
   nir_ssa_def *patch_offset = nir_imul24(b, patch_id, patch_stride);
   nir_ssa_def *attr_offset;

   if (nir_src_is_const(nir_src_for_ssa(offset))) {
      location += nir_src_as_uint(nir_src_for_ssa(offset));
      offset = nir_imm_int(b, 0);
   } else {
      /* The offset is in vec4s, but we need it in units of components for
       * the load/store_global_ir3 offset.
       */
      offset = nir_ishl(b, offset, nir_imm_int(b, 2));
   }

   nir_ssa_def *vertex_offset;
   if (vertex) {
      unsigned index = shader_io_get_unique_index(location);
      switch (b->shader->info.stage) {
      case MESA_SHADER_TESS_CTRL:
         attr_offset = nir_imm_int(b, state->map.loc[index] + comp);
         break;
      case MESA_SHADER_TESS_EVAL:
         attr_offset = nir_iadd(b, nir_load_primitive_location_ir3(b, index),
                                nir_imm_int(b, comp));
         break;
      default:
         unreachable("bad shader stage");
      }

      attr_offset = nir_iadd(b, attr_offset,
                             nir_imul24(b, offset, build_tcs_out_vertices(b)));
      vertex_offset = nir_ishl(b, vertex, nir_imm_int(b, 2));
   } else {
      assert(location >= VARYING_SLOT_PATCH0 &&
             location <= VARYING_SLOT_TESS_MAX);
      unsigned index = location - VARYING_SLOT_PATCH0;
      attr_offset = nir_iadd(b, nir_imm_int(b, index * 4 + comp), offset);
      vertex_offset = nir_imm_int(b, 0);
   }

   return nir_iadd(b, nir_iadd(b, patch_offset, attr_offset), vertex_offset);
}

static nir_ssa_def *
build_patch_offset(nir_builder *b, struct state *state, uint32_t base,
                   uint32_t comp, nir_ssa_def *offset)
{
   return build_per_vertex_offset(b, state, NULL, base, comp, offset);
}

static void
tess_level_components(struct state *state, uint32_t *inner, uint32_t *outer)
{
   switch (state->topology) {
   case IR3_TESS_TRIANGLES:
      *inner = 1;
      *outer = 3;
      break;
   case IR3_TESS_QUADS:
      *inner = 2;
      *outer = 4;
      break;
   case IR3_TESS_ISOLINES:
      *inner = 0;
      *outer = 2;
      break;
   default:
      unreachable("bad");
   }
}

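/* The tess factor BO is laid out as one record per patch. Judging from the
 * offsets used below, each record is
 *
 *   [ primitive id | outer levels | inner levels ]
 *
 * e.g. 1 + 3 + 1 = 5 dwords per patch for triangles.
 */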
static nir_ssa_def *
build_tessfactor_base(nir_builder *b, gl_varying_slot slot, struct state *state)
{
   uint32_t inner_levels, outer_levels;
   tess_level_components(state, &inner_levels, &outer_levels);

   const uint32_t patch_stride = 1 + inner_levels + outer_levels;

   nir_ssa_def *patch_id = nir_load_rel_patch_id_ir3(b);

   nir_ssa_def *patch_offset =
      nir_imul24(b, patch_id, nir_imm_int(b, patch_stride));

   uint32_t offset;
   switch (slot) {
   case VARYING_SLOT_PRIMITIVE_ID:
      offset = 0;
      break;
   case VARYING_SLOT_TESS_LEVEL_OUTER:
      offset = 1;
      break;
   case VARYING_SLOT_TESS_LEVEL_INNER:
      offset = 1 + outer_levels;
      break;
   default:
      unreachable("bad");
   }

   return nir_iadd(b, patch_offset, nir_imm_int(b, offset));
}

static void
lower_tess_ctrl_block(nir_block *block, nir_builder *b, struct state *state)
{
   nir_foreach_instr_safe (instr, block) {
      if (instr->type != nir_instr_type_intrinsic)
         continue;

      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

      switch (intr->intrinsic) {
      case nir_intrinsic_load_per_vertex_output: {
         // src[] = { vertex, offset }.

         b->cursor = nir_before_instr(&intr->instr);

         nir_ssa_def *address = nir_load_tess_param_base_ir3(b);
         nir_ssa_def *offset = build_per_vertex_offset(
            b, state, intr->src[0].ssa,
            nir_intrinsic_io_semantics(intr).location,
            nir_intrinsic_component(intr), intr->src[1].ssa);

         replace_intrinsic(b, intr, nir_intrinsic_load_global_ir3, address,
                           offset, NULL);
         break;
      }

      case nir_intrinsic_store_per_vertex_output: {
         // src[] = { value, vertex, offset }.

         b->cursor = nir_before_instr(&intr->instr);

         /* sparse writemask not supported */
         assert(
            util_is_power_of_two_nonzero(nir_intrinsic_write_mask(intr) + 1));

         nir_ssa_def *value = intr->src[0].ssa;
         nir_ssa_def *address = nir_load_tess_param_base_ir3(b);
         nir_ssa_def *offset = build_per_vertex_offset(
            b, state, intr->src[1].ssa,
            nir_intrinsic_io_semantics(intr).location,
            nir_intrinsic_component(intr), intr->src[2].ssa);

         replace_intrinsic(b, intr, nir_intrinsic_store_global_ir3, value,
                           address, offset);

         break;
      }

      case nir_intrinsic_load_output: {
         // src[] = { offset }.

         b->cursor = nir_before_instr(&intr->instr);

         nir_ssa_def *address, *offset;

         /* note if vectorization of the tess level loads ever happens:
          * "ldg" across 16-byte boundaries can behave incorrectly if results
          * are never used. most likely some issue with (sy) not properly
          * syncing with values coming from a second memory transaction.
          */
         gl_varying_slot location = nir_intrinsic_io_semantics(intr).location;
         if (is_tess_levels(location)) {
            assert(intr->dest.ssa.num_components == 1);
            address = nir_load_tess_factor_base_ir3(b);
            offset = build_tessfactor_base(b, location, state);
         } else {
            address = nir_load_tess_param_base_ir3(b);
            offset = build_patch_offset(b, state, location,
                                        nir_intrinsic_component(intr),
                                        intr->src[0].ssa);
         }

         replace_intrinsic(b, intr, nir_intrinsic_load_global_ir3, address,
                           offset, NULL);
         break;
      }

      case nir_intrinsic_store_output: {
         // src[] = { value, offset }.

         /* write patch output to bo */

         b->cursor = nir_before_instr(&intr->instr);

         /* sparse writemask not supported */
         assert(
            util_is_power_of_two_nonzero(nir_intrinsic_write_mask(intr) + 1));

         gl_varying_slot location = nir_intrinsic_io_semantics(intr).location;
         if (is_tess_levels(location)) {
            uint32_t inner_levels, outer_levels, levels;
            tess_level_components(state, &inner_levels, &outer_levels);

            assert(intr->src[0].ssa->num_components == 1);

            nir_ssa_def *offset =
               nir_iadd_imm(b, intr->src[1].ssa, nir_intrinsic_component(intr));

            nir_if *nif = NULL;
            if (location != VARYING_SLOT_PRIMITIVE_ID) {
               /* The tess levels are defined as float[4] and float[2], but
                * the tess factor BO has smaller sizes for tris/isolines, so
                * we have to discard any writes beyond the number of
                * components for the inner/outer levels.
                */
               if (location == VARYING_SLOT_TESS_LEVEL_OUTER)
                  levels = outer_levels;
               else
                  levels = inner_levels;

               nif = nir_push_if(b, nir_ult(b, offset, nir_imm_int(b, levels)));
            }

            replace_intrinsic(
               b, intr, nir_intrinsic_store_global_ir3, intr->src[0].ssa,
               nir_load_tess_factor_base_ir3(b),
               nir_iadd(b, offset, build_tessfactor_base(b, location, state)));

            if (location != VARYING_SLOT_PRIMITIVE_ID) {
               nir_pop_if(b, nif);
            }
         } else {
            nir_ssa_def *address = nir_load_tess_param_base_ir3(b);
            nir_ssa_def *offset = build_patch_offset(
               b, state, location, nir_intrinsic_component(intr),
               intr->src[1].ssa);

            replace_intrinsic(b, intr, nir_intrinsic_store_global_ir3,
                              intr->src[0].ssa, address, offset);
         }
         break;
      }

      default:
         break;
      }
   }
}

static void
emit_tess_epilogue(nir_builder *b, struct state *state)
{
   /* Insert endpatch instruction:
    *
    * TODO we should re-work this to use normal flow control.
    */

   nir_end_patch_ir3(b);
}

void
ir3_nir_lower_tess_ctrl(nir_shader *shader, struct ir3_shader_variant *v,
                        unsigned topology)
{
   struct state state = {.topology = topology};

   if (shader_debug_enabled(shader->info.stage)) {
      mesa_logi("NIR (before tess lowering) for %s shader:",
                _mesa_shader_stage_to_string(shader->info.stage));
      nir_log_shaderi(shader);
   }

   build_primitive_map(shader, &state.map);
   memcpy(v->output_loc, state.map.loc, sizeof(v->output_loc));
   v->output_size = state.map.stride;

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

   nir_builder b;
   nir_builder_init(&b, impl);
   b.cursor = nir_before_cf_list(&impl->body);

   state.header = nir_load_tcs_header_ir3(&b);

   /* If required, store gl_PrimitiveID. */
   if (v->key.tcs_store_primid) {
      b.cursor = nir_after_cf_list(&impl->body);

      nir_store_output(&b, nir_load_primitive_id(&b), nir_imm_int(&b, 0),
                       .io_semantics = {
                           .location = VARYING_SLOT_PRIMITIVE_ID,
                           .num_slots = 1
                        });

      b.cursor = nir_before_cf_list(&impl->body);
   }

   nir_foreach_block_safe (block, impl)
      lower_tess_ctrl_block(block, &b, &state);

   /* Now move the body of the TCS into a conditional:
    *
    *   if (gl_InvocationID < num_vertices)
    *     // body
    *
    */

   nir_cf_list body;
   nir_cf_extract(&body, nir_before_cf_list(&impl->body),
                  nir_after_cf_list(&impl->body));

   b.cursor = nir_after_cf_list(&impl->body);

   /* Re-emit the header, since the old one got moved into the if branch */
   state.header = nir_load_tcs_header_ir3(&b);
   nir_ssa_def *iid = build_invocation_id(&b, &state);

   const uint32_t nvertices = shader->info.tess.tcs_vertices_out;
   nir_ssa_def *cond = nir_ult(&b, iid, nir_imm_int(&b, nvertices));

   nir_if *nif = nir_push_if(&b, cond);

   nir_cf_reinsert(&body, b.cursor);

   b.cursor = nir_after_cf_list(&nif->then_list);

   /* Insert conditional exit for threads with invocation id != 0 */
   nir_ssa_def *iid0_cond = nir_ieq_imm(&b, iid, 0);
   nir_cond_end_ir3(&b, iid0_cond);

   emit_tess_epilogue(&b, &state);

   nir_pop_if(&b, nif);

   nir_metadata_preserve(impl, nir_metadata_none);
}

static void
lower_tess_eval_block(nir_block *block, nir_builder *b, struct state *state)
{
   nir_foreach_instr_safe (instr, block) {
      if (instr->type != nir_instr_type_intrinsic)
         continue;

      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

      switch (intr->intrinsic) {
      case nir_intrinsic_load_tess_coord: {
         b->cursor = nir_after_instr(&intr->instr);
         nir_ssa_def *x = nir_channel(b, &intr->dest.ssa, 0);
         nir_ssa_def *y = nir_channel(b, &intr->dest.ssa, 1);
         nir_ssa_def *z;

         if (state->topology == IR3_TESS_TRIANGLES)
            z = nir_fsub(b, nir_fsub(b, nir_imm_float(b, 1.0f), y), x);
         else
            z = nir_imm_float(b, 0.0f);

         nir_ssa_def *coord = nir_vec3(b, x, y, z);

         nir_ssa_def_rewrite_uses_after(&intr->dest.ssa, coord,
                                        b->cursor.instr);
         break;
      }

      case nir_intrinsic_load_per_vertex_input: {
         // src[] = { vertex, offset }.

         b->cursor = nir_before_instr(&intr->instr);

         nir_ssa_def *address = nir_load_tess_param_base_ir3(b);
         nir_ssa_def *offset = build_per_vertex_offset(
            b, state, intr->src[0].ssa,
            nir_intrinsic_io_semantics(intr).location,
            nir_intrinsic_component(intr), intr->src[1].ssa);

         replace_intrinsic(b, intr, nir_intrinsic_load_global_ir3, address,
                           offset, NULL);
         break;
      }

      case nir_intrinsic_load_input: {
         // src[] = { offset }.

         b->cursor = nir_before_instr(&intr->instr);

         nir_ssa_def *address, *offset;

         /* note if vectorization of the tess level loads ever happens:
          * "ldg" across 16-byte boundaries can behave incorrectly if results
          * are never used. most likely some issue with (sy) not properly
          * syncing with values coming from a second memory transaction.
          */
         gl_varying_slot location = nir_intrinsic_io_semantics(intr).location;
         if (is_tess_levels(location)) {
            assert(intr->dest.ssa.num_components == 1);
            address = nir_load_tess_factor_base_ir3(b);
            offset = build_tessfactor_base(b, location, state);
         } else {
            address = nir_load_tess_param_base_ir3(b);
            offset = build_patch_offset(b, state, location,
                                        nir_intrinsic_component(intr),
                                        intr->src[0].ssa);
         }

         offset =
            nir_iadd(b, offset, nir_imm_int(b, nir_intrinsic_component(intr)));

         replace_intrinsic(b, intr, nir_intrinsic_load_global_ir3, address,
                           offset, NULL);
         break;
      }

      default:
         break;
      }
   }
}

void
ir3_nir_lower_tess_eval(nir_shader *shader, struct ir3_shader_variant *v,
                        unsigned topology)
{
   struct state state = {.topology = topology};

   if (shader_debug_enabled(shader->info.stage)) {
      mesa_logi("NIR (before tess lowering) for %s shader:",
                _mesa_shader_stage_to_string(shader->info.stage));
      nir_log_shaderi(shader);
   }

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block_safe (block, impl)
      lower_tess_eval_block(block, &b, &state);

   v->input_size = calc_primitive_map_size(shader);

   nir_metadata_preserve(impl, nir_metadata_none);
}

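/* Lower EmitVertex()/EndPrimitive(). As the predicate below suggests, each
 * hardware thread only really stores the output vertex whose index matches
 * its local thread id: every EmitVertex() is wrapped in
 * "if (vertex_count == local_thread_id)" and the shadowed outputs are copied
 * into the emit shadows at that point. vertex_flags carries the stream id in
 * its low bits, and the value 4 (as stored for EndPrimitive and as the
 * initial value) marks a primitive cut.
 */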
static void
lower_gs_block(nir_block *block, nir_builder *b, struct state *state)
{
   nir_foreach_instr_safe (instr, block) {
      if (instr->type != nir_instr_type_intrinsic)
         continue;

      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

      switch (intr->intrinsic) {
      case nir_intrinsic_end_primitive: {
         /* Note: This ignores the stream, which seems to match the blob
          * behavior. I'm guessing the HW ignores any extraneous cut
          * signals from an EndPrimitive() that doesn't correspond to the
          * rasterized stream.
          */
         b->cursor = nir_before_instr(&intr->instr);
         nir_store_var(b, state->vertex_flags_out, nir_imm_int(b, 4), 0x1);
         nir_instr_remove(&intr->instr);
         break;
      }

      case nir_intrinsic_emit_vertex: {
         /* Load the vertex count */
         b->cursor = nir_before_instr(&intr->instr);
         nir_ssa_def *count = nir_load_var(b, state->vertex_count_var);

         nir_push_if(b, nir_ieq(b, count, local_thread_id(b)));

         unsigned stream = nir_intrinsic_stream_id(intr);
         /* vertex_flags_out |= stream */
         nir_store_var(b, state->vertex_flags_out,
                       nir_ior(b, nir_load_var(b, state->vertex_flags_out),
                               nir_imm_int(b, stream)),
                       0x1 /* .x */);

         foreach_two_lists (dest_node, &state->emit_outputs, src_node,
                            &state->old_outputs) {
            nir_variable *dest = exec_node_data(nir_variable, dest_node, node);
            nir_variable *src = exec_node_data(nir_variable, src_node, node);
            nir_copy_var(b, dest, src);
         }

         nir_instr_remove(&intr->instr);

         nir_store_var(b, state->emitted_vertex_var,
                       nir_iadd(b, nir_load_var(b, state->emitted_vertex_var),
                                nir_imm_int(b, 1)),
                       0x1);

         nir_pop_if(b, NULL);

         /* Increment the vertex count by 1 */
         nir_store_var(b, state->vertex_count_var,
                       nir_iadd(b, count, nir_imm_int(b, 1)), 0x1); /* .x */
         nir_store_var(b, state->vertex_flags_out, nir_imm_int(b, 0), 0x1);

         break;
      }

      default:
         break;
      }
   }
}

void
ir3_nir_lower_gs(nir_shader *shader)
{
   struct state state = {};

   if (shader_debug_enabled(shader->info.stage)) {
      mesa_logi("NIR (before gs lowering):");
      nir_log_shaderi(shader);
   }

   /* Create an output var for vertex_flags. This will be shadowed below, the
    * same way regular outputs get shadowed, and this variable will become a
    * temporary.
    */
   state.vertex_flags_out = nir_variable_create(
      shader, nir_var_shader_out, glsl_uint_type(), "vertex_flags");
   state.vertex_flags_out->data.driver_location = shader->num_outputs++;
   state.vertex_flags_out->data.location = VARYING_SLOT_GS_VERTEX_FLAGS_IR3;
   state.vertex_flags_out->data.interpolation = INTERP_MODE_NONE;

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

   nir_builder b;
   nir_builder_init(&b, impl);
   b.cursor = nir_before_cf_list(&impl->body);

   state.header = nir_load_gs_header_ir3(&b);

   /* Generate two sets of shadow vars for the output variables.  The first
    * set replaces the real outputs and the second set (emit_outputs) we'll
    * assign in the emit_vertex conditionals.  Then at the end of the shader
    * we copy the emit_outputs to the real outputs, so that we get
    * store_output in uniform control flow.
    */
   exec_list_make_empty(&state.old_outputs);
   nir_foreach_shader_out_variable_safe (var, shader) {
      exec_node_remove(&var->node);
      exec_list_push_tail(&state.old_outputs, &var->node);
   }
   exec_list_make_empty(&state.new_outputs);
   exec_list_make_empty(&state.emit_outputs);
   nir_foreach_variable_in_list (var, &state.old_outputs) {
      /* Create a new output var by cloning the original output var and
       * stealing the name.
       */
      nir_variable *output = nir_variable_clone(var, shader);
      exec_list_push_tail(&state.new_outputs, &output->node);

      /* Rewrite the original output to be a shadow variable. */
      var->name = ralloc_asprintf(var, "%s@gs-temp", output->name);
      var->data.mode = nir_var_shader_temp;

      /* Clone the shadow variable to create the emit shadow variable that
       * we'll assign in the emit conditionals.
       */
      nir_variable *emit_output = nir_variable_clone(var, shader);
      emit_output->name = ralloc_asprintf(var, "%s@emit-temp", output->name);
      exec_list_push_tail(&state.emit_outputs, &emit_output->node);
   }
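
   /* So an output such as gl_Position ends up as three variables: the
    * original (renamed "gl_Position@gs-temp", now a shader temp), a fresh
    * real output, and an "@emit-temp" copy written inside the EmitVertex
    * conditional.
    */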

   /* During the shader we'll keep track of which vertex we're currently
    * emitting for the EmitVertex test and how many vertices we emitted, so we
    * know to discard if we didn't emit any.  In most simple shaders, this can
    * all be statically determined and gets optimized away.
    */
   state.vertex_count_var =
      nir_local_variable_create(impl, glsl_uint_type(), "vertex_count");
   state.emitted_vertex_var =
      nir_local_variable_create(impl, glsl_uint_type(), "emitted_vertex");

   /* Initialize the vertex counters to 0 and vertex_flags to 4. */
   b.cursor = nir_before_cf_list(&impl->body);
   nir_store_var(&b, state.vertex_count_var, nir_imm_int(&b, 0), 0x1);
   nir_store_var(&b, state.emitted_vertex_var, nir_imm_int(&b, 0), 0x1);
   nir_store_var(&b, state.vertex_flags_out, nir_imm_int(&b, 4), 0x1);

   nir_foreach_block_safe (block, impl)
      lower_gs_block(block, &b, &state);

   set_foreach (impl->end_block->predecessors, block_entry) {
      struct nir_block *block = (void *)block_entry->key;
      b.cursor = nir_after_block_before_jump(block);

      nir_ssa_def *cond =
         nir_ieq_imm(&b, nir_load_var(&b, state.emitted_vertex_var), 0);

      nir_discard_if(&b, cond);

      foreach_two_lists (dest_node, &state.new_outputs, src_node,
                         &state.emit_outputs) {
         nir_variable *dest = exec_node_data(nir_variable, dest_node, node);
         nir_variable *src = exec_node_data(nir_variable, src_node, node);
         nir_copy_var(&b, dest, src);
      }
   }

   exec_list_append(&shader->variables, &state.old_outputs);
   exec_list_append(&shader->variables, &state.emit_outputs);
   exec_list_append(&shader->variables, &state.new_outputs);

   nir_metadata_preserve(impl, nir_metadata_none);

   nir_lower_global_vars_to_local(shader);
   nir_split_var_copies(shader);
   nir_lower_var_copies(shader);

   nir_fixup_deref_modes(shader);

   if (shader_debug_enabled(shader->info.stage)) {
      mesa_logi("NIR (after gs lowering):");
      nir_log_shaderi(shader);
   }
}