1/*
2 * Copyright © Microsoft Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24#include "d3d12_compiler.h"
25#include "d3d12_context.h"
26#include "d3d12_debug.h"
27#include "d3d12_screen.h"
28#include "nir_to_dxil.h"
29
30#include "nir.h"
31#include "compiler/nir/nir_builder.h"
32#include "compiler/nir/nir_builtin_builder.h"
33
34#include "util/u_memory.h"
35#include "util/u_simple_shaders.h"
36
37static nir_ssa_def *
38nir_cull_face(nir_builder *b, nir_variable *vertices, bool ccw)
39{
40   nir_ssa_def *v0 =
41       nir_load_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, vertices), nir_imm_int(b, 0)));
42   nir_ssa_def *v1 =
43       nir_load_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, vertices), nir_imm_int(b, 1)));
44   nir_ssa_def *v2 =
45       nir_load_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, vertices), nir_imm_int(b, 2)));
46
47   nir_ssa_def *dir = nir_fdot(b, nir_cross4(b, nir_fsub(b, v1, v0),
48                                               nir_fsub(b, v2, v0)),
49                                   nir_imm_vec4(b, 0.0, 0.0, -1.0, 0.0));
50   if (ccw)
51       return nir_fge(b, nir_imm_int(b, 0), dir);
52   else
53       return nir_flt(b, nir_imm_int(b, 0), dir);
54}
55
56static d3d12_shader_selector*
57d3d12_make_passthrough_gs(struct d3d12_context *ctx, struct d3d12_gs_variant_key *key)
58{
59   struct d3d12_shader_selector *gs;
60   uint64_t varyings = key->varyings.mask;
61   nir_shader *nir;
62   struct pipe_shader_state templ;
63
64   nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_GEOMETRY,
65                                                  dxil_get_nir_compiler_options(),
66                                                  "passthrough");
67
68   nir = b.shader;
69   nir->info.inputs_read = varyings;
70   nir->info.outputs_written = varyings;
71   nir->info.gs.input_primitive = GL_POINTS;
72   nir->info.gs.output_primitive = GL_POINTS;
73   nir->info.gs.vertices_in = 1;
74   nir->info.gs.vertices_out = 1;
75   nir->info.gs.invocations = 1;
76   nir->info.gs.active_stream_mask = 1;
77
78   /* Copy inputs to outputs. */
79   while (varyings) {
80      nir_variable *in, *out;
81      char tmp[100];
82      const int i = u_bit_scan64(&varyings);
83
84      snprintf(tmp, ARRAY_SIZE(tmp), "in_%d", key->varyings.vars[i].driver_location);
85      in = nir_variable_create(nir,
86                               nir_var_shader_in,
87                               glsl_array_type(key->varyings.vars[i].type, 1, false),
88                               tmp);
89      in->data.location = i;
90      in->data.driver_location = key->varyings.vars[i].driver_location;
91      in->data.interpolation = key->varyings.vars[i].interpolation;
92
93      snprintf(tmp, ARRAY_SIZE(tmp), "out_%d", key->varyings.vars[i].driver_location);
94      out = nir_variable_create(nir,
95                                nir_var_shader_out,
96                                key->varyings.vars[i].type,
97                                tmp);
98      out->data.location = i;
99      out->data.driver_location = key->varyings.vars[i].driver_location;
100      out->data.interpolation = key->varyings.vars[i].interpolation;
101
102      nir_deref_instr *in_value = nir_build_deref_array(&b, nir_build_deref_var(&b, in),
103                                                            nir_imm_int(&b, 0));
104      nir_copy_deref(&b, nir_build_deref_var(&b, out), in_value);
105   }
106
107   nir_emit_vertex(&b, 0);
108   nir_end_primitive(&b, 0);
109
110   NIR_PASS_V(nir, nir_lower_var_copies);
111   nir_validate_shader(nir, "in d3d12_create_passthrough_gs");
112
113   templ.type = PIPE_SHADER_IR_NIR;
114   templ.ir.nir = nir;
115   templ.stream_output.num_outputs = 0;
116
117   gs = d3d12_create_shader(ctx, PIPE_SHADER_GEOMETRY, &templ);
118
119   return gs;
120}
121
122struct emit_primitives_context
123{
124   struct d3d12_context *ctx;
125   nir_builder b;
126
127   unsigned num_vars;
128   nir_variable *in[MAX_VARYING];
129   nir_variable *out[MAX_VARYING];
130   nir_variable *front_facing_var;
131
132   nir_loop *loop;
133   nir_deref_instr *loop_index_deref;
134   nir_ssa_def *loop_index;
135   nir_ssa_def *edgeflag_cmp;
136   nir_ssa_def *front_facing;
137};
138
139static bool
140d3d12_begin_emit_primitives_gs(struct emit_primitives_context *emit_ctx,
141                               struct d3d12_context *ctx,
142                               struct d3d12_gs_variant_key *key,
143                               uint16_t output_primitive,
144                               unsigned vertices_out)
145{
146   nir_builder *b = &emit_ctx->b;
147   nir_variable *edgeflag_var = NULL;
148   nir_variable *pos_var = NULL;
149   uint64_t varyings = key->varyings.mask;
150
151   emit_ctx->ctx = ctx;
152
153   emit_ctx->b = nir_builder_init_simple_shader(MESA_SHADER_GEOMETRY,
154                                                dxil_get_nir_compiler_options(),
155                                                "edgeflags");
156
157   nir_shader *nir = b->shader;
158   nir->info.inputs_read = varyings;
159   nir->info.outputs_written = varyings;
160   nir->info.gs.input_primitive = GL_TRIANGLES;
161   nir->info.gs.output_primitive = output_primitive;
162   nir->info.gs.vertices_in = 3;
163   nir->info.gs.vertices_out = vertices_out;
164   nir->info.gs.invocations = 1;
165   nir->info.gs.active_stream_mask = 1;
166
167   while (varyings) {
168      char tmp[100];
169      const int i = u_bit_scan64(&varyings);
170
171      snprintf(tmp, ARRAY_SIZE(tmp), "in_%d", emit_ctx->num_vars);
172      emit_ctx->in[emit_ctx->num_vars] = nir_variable_create(nir,
173                                                             nir_var_shader_in,
174                                                             glsl_array_type(key->varyings.vars[i].type, 3, 0),
175                                                             tmp);
176      emit_ctx->in[emit_ctx->num_vars]->data.location = i;
177      emit_ctx->in[emit_ctx->num_vars]->data.driver_location = key->varyings.vars[i].driver_location;
178      emit_ctx->in[emit_ctx->num_vars]->data.interpolation = key->varyings.vars[i].interpolation;
179
180      /* Don't create an output for the edge flag variable */
181      if (i == VARYING_SLOT_EDGE) {
182         edgeflag_var = emit_ctx->in[emit_ctx->num_vars];
183         continue;
184      } else if (i == VARYING_SLOT_POS) {
185          pos_var = emit_ctx->in[emit_ctx->num_vars];
186      }
187
188      snprintf(tmp, ARRAY_SIZE(tmp), "out_%d", emit_ctx->num_vars);
189      emit_ctx->out[emit_ctx->num_vars] = nir_variable_create(nir,
190                                                              nir_var_shader_out,
191                                                              key->varyings.vars[i].type,
192                                                              tmp);
193      emit_ctx->out[emit_ctx->num_vars]->data.location = i;
194      emit_ctx->out[emit_ctx->num_vars]->data.driver_location = key->varyings.vars[i].driver_location;
195      emit_ctx->out[emit_ctx->num_vars]->data.interpolation = key->varyings.vars[i].interpolation;
196
197      emit_ctx->num_vars++;
198   }
199
200   if (key->has_front_face) {
201      emit_ctx->front_facing_var = nir_variable_create(nir,
202                                                       nir_var_shader_out,
203                                                       glsl_uint_type(),
204                                                       "gl_FrontFacing");
205      emit_ctx->front_facing_var->data.location = VARYING_SLOT_VAR12;
206      emit_ctx->front_facing_var->data.driver_location = emit_ctx->num_vars;
207      emit_ctx->front_facing_var->data.interpolation = INTERP_MODE_FLAT;
208   }
209
210   /* Temporary variable "loop_index" to loop over input vertices */
211   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
212   nir_variable *loop_index_var =
213      nir_local_variable_create(impl, glsl_uint_type(), "loop_index");
214   emit_ctx->loop_index_deref = nir_build_deref_var(b, loop_index_var);
215   nir_store_deref(b, emit_ctx->loop_index_deref, nir_imm_int(b, 0), 1);
216
217   nir_ssa_def *diagonal_vertex = NULL;
218   if (key->edge_flag_fix) {
219      nir_ssa_def *prim_id = nir_load_primitive_id(b);
220      nir_ssa_def *odd = nir_build_alu(b, nir_op_imod,
221                                       prim_id,
222                                       nir_imm_int(b, 2),
223                                       NULL, NULL);
224      diagonal_vertex = nir_bcsel(b, nir_i2b(b, odd),
225                                  nir_imm_int(b, 2),
226                                  nir_imm_int(b, 1));
227   }
228
229   if (key->cull_mode != PIPE_FACE_NONE || key->has_front_face) {
230      if (key->cull_mode == PIPE_FACE_BACK)
231         emit_ctx->edgeflag_cmp = nir_cull_face(b, pos_var, key->front_ccw);
232      else if (key->cull_mode == PIPE_FACE_FRONT)
233         emit_ctx->edgeflag_cmp = nir_cull_face(b, pos_var, !key->front_ccw);
234
235      if (key->has_front_face) {
236         if (key->cull_mode == PIPE_FACE_BACK)
237            emit_ctx->front_facing = emit_ctx->edgeflag_cmp;
238         else
239            emit_ctx->front_facing = nir_cull_face(b, pos_var, key->front_ccw);
240         emit_ctx->front_facing = nir_i2i32(b, emit_ctx->front_facing);
241      }
242   }
243
244   /**
245    *  while {
246    *     if (loop_index >= 3)
247    *        break;
248    */
249   emit_ctx->loop = nir_push_loop(b);
250
251   emit_ctx->loop_index = nir_load_deref(b, emit_ctx->loop_index_deref);
252   nir_ssa_def *cmp = nir_ige(b, emit_ctx->loop_index,
253                              nir_imm_int(b, 3));
254   nir_if *loop_check = nir_push_if(b, cmp);
255   nir_jump(b, nir_jump_break);
256   nir_pop_if(b, loop_check);
257
258   if (edgeflag_var) {
259      nir_ssa_def *edge_flag =
260         nir_load_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, edgeflag_var), emit_ctx->loop_index));
261      nir_ssa_def *is_edge = nir_feq(b, nir_channel(b, edge_flag, 0), nir_imm_float(b, 1.0));
262      if (emit_ctx->edgeflag_cmp)
263         emit_ctx->edgeflag_cmp = nir_iand(b, emit_ctx->edgeflag_cmp, is_edge);
264      else
265         emit_ctx->edgeflag_cmp = is_edge;
266   }
267
268   if (key->edge_flag_fix) {
269      nir_ssa_def *is_edge = nir_ine(b, emit_ctx->loop_index, diagonal_vertex);
270      if (emit_ctx->edgeflag_cmp)
271         emit_ctx->edgeflag_cmp = nir_iand(b, emit_ctx->edgeflag_cmp, is_edge);
272      else
273         emit_ctx->edgeflag_cmp = is_edge;
274   }
275
276   return true;
277}
278
279static struct d3d12_shader_selector *
280d3d12_finish_emit_primitives_gs(struct emit_primitives_context *emit_ctx, bool end_primitive)
281{
282   struct pipe_shader_state templ;
283   nir_builder *b = &emit_ctx->b;
284   nir_shader *nir = b->shader;
285
286   /**
287    *     loop_index++;
288    *  }
289    */
290   nir_store_deref(b, emit_ctx->loop_index_deref, nir_iadd_imm(b, emit_ctx->loop_index, 1), 1);
291   nir_pop_loop(b, emit_ctx->loop);
292
293   if (end_primitive)
294      nir_end_primitive(b, 0);
295
296   nir_validate_shader(nir, "in d3d12_lower_edge_flags");
297
298   NIR_PASS_V(nir, nir_lower_var_copies);
299
300   templ.type = PIPE_SHADER_IR_NIR;
301   templ.ir.nir = nir;
302   templ.stream_output.num_outputs = 0;
303
304   return d3d12_create_shader(emit_ctx->ctx, PIPE_SHADER_GEOMETRY, &templ);
305}
306
307static d3d12_shader_selector*
308d3d12_emit_points(struct d3d12_context *ctx, struct d3d12_gs_variant_key *key)
309{
310   struct emit_primitives_context emit_ctx = {0};
311   nir_builder *b = &emit_ctx.b;
312
313   d3d12_begin_emit_primitives_gs(&emit_ctx, ctx, key, GL_POINTS, 3);
314
315   /**
316    *  if (edge_flag)
317    *     out_position = in_position;
318    *  else
319    *     out_position = vec4(-2.0, -2.0, 0.0, 1.0); // Invalid position
320    *
321    *  [...] // Copy other variables
322    *
323    *  EmitVertex();
324    */
325   for (unsigned i = 0; i < emit_ctx.num_vars; ++i) {
326      nir_ssa_def *index = (key->flat_varyings & (1ull << emit_ctx.in[i]->data.location))  ?
327                              nir_imm_int(b, (key->flatshade_first ? 0 : 2)) : emit_ctx.loop_index;
328      nir_deref_instr *in_value = nir_build_deref_array(b, nir_build_deref_var(b, emit_ctx.in[i]), index);
329      if (emit_ctx.in[i]->data.location == VARYING_SLOT_POS && emit_ctx.edgeflag_cmp) {
330         nir_if *edge_check = nir_push_if(b, emit_ctx.edgeflag_cmp);
331         nir_copy_deref(b, nir_build_deref_var(b, emit_ctx.out[i]), in_value);
332         nir_if *edge_else = nir_push_else(b, edge_check);
333         nir_store_deref(b, nir_build_deref_var(b, emit_ctx.out[i]),
334                         nir_imm_vec4(b, -2.0, -2.0, 0.0, 1.0), 0xf);
335         nir_pop_if(b, edge_else);
336      } else {
337         nir_copy_deref(b, nir_build_deref_var(b, emit_ctx.out[i]), in_value);
338      }
339   }
340   if (key->has_front_face)
341       nir_store_var(b, emit_ctx.front_facing_var, emit_ctx.front_facing, 0x1);
342   nir_emit_vertex(b, 0);
343
344   return d3d12_finish_emit_primitives_gs(&emit_ctx, false);
345}
346
347static d3d12_shader_selector*
348d3d12_emit_lines(struct d3d12_context *ctx, struct d3d12_gs_variant_key *key)
349{
350   struct emit_primitives_context emit_ctx = {0};
351   nir_builder *b = &emit_ctx.b;
352
353   d3d12_begin_emit_primitives_gs(&emit_ctx, ctx, key, GL_LINE_STRIP, 6);
354
355   nir_ssa_def *next_index = nir_imod(b, nir_iadd_imm(b, emit_ctx.loop_index, 1), nir_imm_int(b, 3));
356
357   /* First vertex */
358   for (unsigned i = 0; i < emit_ctx.num_vars; ++i) {
359      nir_ssa_def *index = (key->flat_varyings & (1ull << emit_ctx.in[i]->data.location)) ?
360                              nir_imm_int(b, (key->flatshade_first ? 0 : 2)) : emit_ctx.loop_index;
361      nir_deref_instr *in_value = nir_build_deref_array(b, nir_build_deref_var(b, emit_ctx.in[i]), index);
362      nir_copy_deref(b, nir_build_deref_var(b, emit_ctx.out[i]), in_value);
363   }
364   if (key->has_front_face)
365       nir_store_var(b, emit_ctx.front_facing_var, emit_ctx.front_facing, 0x1);
366   nir_emit_vertex(b, 0);
367
368   /* Second vertex. If not an edge, use same position as first vertex */
369   for (unsigned i = 0; i < emit_ctx.num_vars; ++i) {
370      nir_ssa_def *index = next_index;
371      if (emit_ctx.in[i]->data.location == VARYING_SLOT_POS)
372         index = nir_bcsel(b, emit_ctx.edgeflag_cmp, next_index, emit_ctx.loop_index);
373      else if (key->flat_varyings & (1ull << emit_ctx.in[i]->data.location))
374         index = nir_imm_int(b, 2);
375      nir_copy_deref(b, nir_build_deref_var(b, emit_ctx.out[i]),
376                     nir_build_deref_array(b, nir_build_deref_var(b, emit_ctx.in[i]), index));
377   }
378   if (key->has_front_face)
379       nir_store_var(b, emit_ctx.front_facing_var, emit_ctx.front_facing, 0x1);
380   nir_emit_vertex(b, 0);
381
382   nir_end_primitive(b, 0);
383
384   return d3d12_finish_emit_primitives_gs(&emit_ctx, false);
385}
386
387static d3d12_shader_selector*
388d3d12_emit_triangles(struct d3d12_context *ctx, struct d3d12_gs_variant_key *key)
389{
390   struct emit_primitives_context emit_ctx = {0};
391   nir_builder *b = &emit_ctx.b;
392
393   d3d12_begin_emit_primitives_gs(&emit_ctx, ctx, key, GL_TRIANGLE_STRIP, 3);
394
395   /**
396    *  [...] // Copy variables
397    *
398    *  EmitVertex();
399    */
400
401   nir_ssa_def *incr = NULL;
402
403   if (key->provoking_vertex > 0)
404      incr = nir_imm_int(b, key->provoking_vertex);
405   else
406      incr = nir_imm_int(b, 3);
407
408   if (key->alternate_tri) {
409      nir_ssa_def *odd = nir_imod(b, nir_load_primitive_id(b), nir_imm_int(b, 2));
410      incr = nir_isub(b, incr, odd);
411   }
412
413   assert(incr != NULL);
414   nir_ssa_def *index = nir_imod(b, nir_iadd(b, emit_ctx.loop_index, incr), nir_imm_int(b, 3));
415   for (unsigned i = 0; i < emit_ctx.num_vars; ++i) {
416      nir_deref_instr *in_value = nir_build_deref_array(b, nir_build_deref_var(b, emit_ctx.in[i]), index);
417      nir_copy_deref(b, nir_build_deref_var(b, emit_ctx.out[i]), in_value);
418   }
419   nir_emit_vertex(b, 0);
420
421   return d3d12_finish_emit_primitives_gs(&emit_ctx, true);
422}
423
424static uint32_t
425hash_gs_variant_key(const void *key)
426{
427   return _mesa_hash_data(key, sizeof(struct d3d12_gs_variant_key));
428}
429
430static bool
431equals_gs_variant_key(const void *a, const void *b)
432{
433   return memcmp(a, b, sizeof(struct d3d12_gs_variant_key)) == 0;
434}
435
436void
437d3d12_gs_variant_cache_init(struct d3d12_context *ctx)
438{
439   ctx->gs_variant_cache = _mesa_hash_table_create(NULL, NULL, equals_gs_variant_key);
440}
441
442static void
443delete_entry(struct hash_entry *entry)
444{
445   d3d12_shader_free((d3d12_shader_selector *)entry->data);
446}
447
448void
449d3d12_gs_variant_cache_destroy(struct d3d12_context *ctx)
450{
451   _mesa_hash_table_destroy(ctx->gs_variant_cache, delete_entry);
452}
453
454static struct d3d12_shader_selector *
455create_geometry_shader_variant(struct d3d12_context *ctx, struct d3d12_gs_variant_key *key)
456{
457   d3d12_shader_selector *gs = NULL;
458
459   if (key->passthrough)
460      gs = d3d12_make_passthrough_gs(ctx, key);
461   else if (key->provoking_vertex > 0 || key->alternate_tri)
462      gs = d3d12_emit_triangles(ctx, key);
463   else if (key->fill_mode == PIPE_POLYGON_MODE_POINT)
464      gs = d3d12_emit_points(ctx, key);
465   else if (key->fill_mode == PIPE_POLYGON_MODE_LINE)
466      gs = d3d12_emit_lines(ctx, key);
467
468   if (gs) {
469      gs->is_gs_variant = true;
470      gs->gs_key = *key;
471   }
472
473   return gs;
474}
475
476d3d12_shader_selector *
477d3d12_get_gs_variant(struct d3d12_context *ctx, struct d3d12_gs_variant_key *key)
478{
479   uint32_t hash = hash_gs_variant_key(key);
480   struct hash_entry *entry = _mesa_hash_table_search_pre_hashed(ctx->gs_variant_cache,
481                                                                 hash, key);
482   if (!entry) {
483      d3d12_shader_selector *gs = create_geometry_shader_variant(ctx, key);
484      entry = _mesa_hash_table_insert_pre_hashed(ctx->gs_variant_cache,
485                                                 hash, &gs->gs_key, gs);
486      assert(entry);
487   }
488
489   return (d3d12_shader_selector *)entry->data;
490}
491