1/*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24#include "vtn_private.h"
25
26static void
27vtn_build_subgroup_instr(struct vtn_builder *b,
28                         nir_intrinsic_op nir_op,
29                         struct vtn_ssa_value *dst,
30                         struct vtn_ssa_value *src0,
31                         nir_ssa_def *index,
32                         unsigned const_idx0,
33                         unsigned const_idx1)
34{
35   /* Some of the subgroup operations take an index.  SPIR-V allows this to be
36    * any integer type.  To make things simpler for drivers, we only support
37    * 32-bit indices.
38    */
39   if (index && index->bit_size != 32)
40      index = nir_u2u32(&b->nb, index);
41
42   vtn_assert(dst->type == src0->type);
43   if (!glsl_type_is_vector_or_scalar(dst->type)) {
44      for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
45         vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
46                                  src0->elems[i], index,
47                                  const_idx0, const_idx1);
48      }
49      return;
50   }
51
52   nir_intrinsic_instr *intrin =
53      nir_intrinsic_instr_create(b->nb.shader, nir_op);
54   nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
55                              dst->type, NULL);
56   intrin->num_components = intrin->dest.ssa.num_components;
57
58   intrin->src[0] = nir_src_for_ssa(src0->def);
59   if (index)
60      intrin->src[1] = nir_src_for_ssa(index);
61
62   intrin->const_index[0] = const_idx0;
63   intrin->const_index[1] = const_idx1;
64
65   nir_builder_instr_insert(&b->nb, &intrin->instr);
66
67   dst->def = &intrin->dest.ssa;
68}
69
/* Translate a SPIR-V OpGroupNonUniform* instruction into NIR subgroup
 * intrinsics and push the result as an SSA value.
 *
 * w[2] is the result id and w[4] onward are operand ids (w[1] is the result
 * type; w[3] is presumably the Execution scope, which is not read here --
 * NOTE(review): confirm word layout against the SPIR-V spec).
 */
void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                    const uint32_t *w, unsigned count)
{
   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);

   val->ssa = vtn_create_ssa_value(b, val->type->type);

   switch (opcode) {
   case SpvOpGroupNonUniformElect: {
      /* Single boolean result; no sources. */
      vtn_fail_if(val->type->type != glsl_bool_type(),
                  "OpGroupNonUniformElect must return a Bool");
      nir_intrinsic_instr *elect =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
      nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
                                 val->type->type, NULL);
      nir_builder_instr_insert(&b->nb, &elect->instr);
      val->ssa->def = &elect->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformBallot: {
      /* w[4] is the per-invocation predicate; the ballot result is a
       * fixed 4x32-bit mask.
       */
      vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
                  "OpGroupNonUniformBallot must return a uvec4");
      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
      ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
      nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
      ballot->num_components = 4;
      nir_builder_instr_insert(&b->nb, &ballot->instr);
      val->ssa->def = &ballot->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformInverseBallot: {
      /* This one is just a BallotBitfieldExtract with subgroup invocation.
       * We could add a NIR intrinsic but it's easier to just lower it on the
       * spot.
       */
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader,
                                    nir_intrinsic_ballot_bitfield_extract);

      intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
      intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 val->type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      val->ssa->def = &intrin->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformBallotBitExtract:
   case SpvOpGroupNonUniformBallotBitCount:
   case SpvOpGroupNonUniformBallotFindLSB:
   case SpvOpGroupNonUniformBallotFindMSB: {
      /* All four take a ballot value; only BitExtract has a second (index)
       * source, and only BitCount carries a group operation word.
       */
      nir_ssa_def *src0, *src1 = NULL;
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformBallotBitExtract:
         op = nir_intrinsic_ballot_bitfield_extract;
         src0 = vtn_ssa_value(b, w[4])->def;
         src1 = vtn_ssa_value(b, w[5])->def;
         break;
      case SpvOpGroupNonUniformBallotBitCount:
         /* w[4] is the SpvGroupOperation; the ballot operand moves to w[5]. */
         switch ((SpvGroupOperation)w[4]) {
         case SpvGroupOperationReduce:
            op = nir_intrinsic_ballot_bit_count_reduce;
            break;
         case SpvGroupOperationInclusiveScan:
            op = nir_intrinsic_ballot_bit_count_inclusive;
            break;
         case SpvGroupOperationExclusiveScan:
            op = nir_intrinsic_ballot_bit_count_exclusive;
            break;
         default:
            unreachable("Invalid group operation");
         }
         src0 = vtn_ssa_value(b, w[5])->def;
         break;
      case SpvOpGroupNonUniformBallotFindLSB:
         op = nir_intrinsic_ballot_find_lsb;
         src0 = vtn_ssa_value(b, w[4])->def;
         break;
      case SpvOpGroupNonUniformBallotFindMSB:
         op = nir_intrinsic_ballot_find_msb;
         src0 = vtn_ssa_value(b, w[4])->def;
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);

      intrin->src[0] = nir_src_for_ssa(src0);
      if (src1)
         intrin->src[1] = nir_src_for_ssa(src1);

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 val->type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      val->ssa->def = &intrin->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformBroadcastFirst:
      /* w[4] is the value; no index source. */
      vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
                               val->ssa, vtn_ssa_value(b, w[4]), NULL, 0, 0);
      break;

   case SpvOpGroupNonUniformBroadcast:
      /* w[4] is the value, w[5] the invocation id to read from. */
      vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                               val->ssa, vtn_ssa_value(b, w[4]),
                               vtn_ssa_value(b, w[5])->def, 0, 0);
      break;

   case SpvOpGroupNonUniformAll:
   case SpvOpGroupNonUniformAny:
   case SpvOpGroupNonUniformAllEqual: {
      vtn_fail_if(val->type->type != glsl_bool_type(),
                  "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformAll:
         op = nir_intrinsic_vote_all;
         break;
      case SpvOpGroupNonUniformAny:
         op = nir_intrinsic_vote_any;
         break;
      case SpvOpGroupNonUniformAllEqual: {
         /* Pick float vs. integer equality vote based on the operand's
          * base type.
          */
         switch (glsl_get_base_type(val->type->type)) {
         case GLSL_TYPE_FLOAT:
         case GLSL_TYPE_DOUBLE:
            op = nir_intrinsic_vote_feq;
            break;
         case GLSL_TYPE_UINT:
         case GLSL_TYPE_INT:
         case GLSL_TYPE_UINT64:
         case GLSL_TYPE_INT64:
         case GLSL_TYPE_BOOL:
            op = nir_intrinsic_vote_ieq;
            break;
         default:
            unreachable("Unhandled type");
         }
         break;
      }
      default:
         unreachable("Unhandled opcode");
      }

      nir_ssa_def *src0 = vtn_ssa_value(b, w[4])->def;

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);
      /* num_components follows the source, not the (scalar bool) dest. */
      intrin->num_components = src0->num_components;
      intrin->src[0] = nir_src_for_ssa(src0);
      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 val->type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      val->ssa->def = &intrin->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformShuffle:
   case SpvOpGroupNonUniformShuffleXor:
   case SpvOpGroupNonUniformShuffleUp:
   case SpvOpGroupNonUniformShuffleDown: {
      /* All four shuffles take a value (w[4]) and an index/mask/delta (w[5]);
       * only the intrinsic opcode differs.
       */
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformShuffle:
         op = nir_intrinsic_shuffle;
         break;
      case SpvOpGroupNonUniformShuffleXor:
         op = nir_intrinsic_shuffle_xor;
         break;
      case SpvOpGroupNonUniformShuffleUp:
         op = nir_intrinsic_shuffle_up;
         break;
      case SpvOpGroupNonUniformShuffleDown:
         op = nir_intrinsic_shuffle_down;
         break;
      default:
         unreachable("Invalid opcode");
      }
      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
                               vtn_ssa_value(b, w[5])->def, 0, 0);
      break;
   }

   case SpvOpGroupNonUniformQuadBroadcast:
      /* w[4] is the value, w[5] the quad lane index. */
      vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
                               val->ssa, vtn_ssa_value(b, w[4]),
                               vtn_ssa_value(b, w[5])->def, 0, 0);
      break;

   case SpvOpGroupNonUniformQuadSwap: {
      /* w[5] must be a constant direction: 0 = horizontal, 1 = vertical,
       * 2 = diagonal.
       */
      unsigned direction = vtn_constant_uint(b, w[5]);
      nir_intrinsic_op op;
      switch (direction) {
      case 0:
         op = nir_intrinsic_quad_swap_horizontal;
         break;
      case 1:
         op = nir_intrinsic_quad_swap_vertical;
         break;
      case 2:
         op = nir_intrinsic_quad_swap_diagonal;
         break;
      default:
         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
      }
      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
                               NULL, 0, 0);
      break;
   }

   case SpvOpGroupNonUniformIAdd:
   case SpvOpGroupNonUniformFAdd:
   case SpvOpGroupNonUniformIMul:
   case SpvOpGroupNonUniformFMul:
   case SpvOpGroupNonUniformSMin:
   case SpvOpGroupNonUniformUMin:
   case SpvOpGroupNonUniformFMin:
   case SpvOpGroupNonUniformSMax:
   case SpvOpGroupNonUniformUMax:
   case SpvOpGroupNonUniformFMax:
   case SpvOpGroupNonUniformBitwiseAnd:
   case SpvOpGroupNonUniformBitwiseOr:
   case SpvOpGroupNonUniformBitwiseXor:
   case SpvOpGroupNonUniformLogicalAnd:
   case SpvOpGroupNonUniformLogicalOr:
   case SpvOpGroupNonUniformLogicalXor: {
      /* Reductions/scans: map the SPIR-V opcode to the NIR ALU op that is
       * applied pairwise, then pick the scan/reduce intrinsic from the
       * group operation word.
       */
      nir_op reduction_op;
      switch (opcode) {
      case SpvOpGroupNonUniformIAdd:
         reduction_op = nir_op_iadd;
         break;
      case SpvOpGroupNonUniformFAdd:
         reduction_op = nir_op_fadd;
         break;
      case SpvOpGroupNonUniformIMul:
         reduction_op = nir_op_imul;
         break;
      case SpvOpGroupNonUniformFMul:
         reduction_op = nir_op_fmul;
         break;
      case SpvOpGroupNonUniformSMin:
         reduction_op = nir_op_imin;
         break;
      case SpvOpGroupNonUniformUMin:
         reduction_op = nir_op_umin;
         break;
      case SpvOpGroupNonUniformFMin:
         reduction_op = nir_op_fmin;
         break;
      case SpvOpGroupNonUniformSMax:
         reduction_op = nir_op_imax;
         break;
      case SpvOpGroupNonUniformUMax:
         reduction_op = nir_op_umax;
         break;
      case SpvOpGroupNonUniformFMax:
         reduction_op = nir_op_fmax;
         break;
      /* Logical* and Bitwise* share the same integer ALU ops. */
      case SpvOpGroupNonUniformBitwiseAnd:
      case SpvOpGroupNonUniformLogicalAnd:
         reduction_op = nir_op_iand;
         break;
      case SpvOpGroupNonUniformBitwiseOr:
      case SpvOpGroupNonUniformLogicalOr:
         reduction_op = nir_op_ior;
         break;
      case SpvOpGroupNonUniformBitwiseXor:
      case SpvOpGroupNonUniformLogicalXor:
         reduction_op = nir_op_ixor;
         break;
      default:
         unreachable("Invalid reduction operation");
      }

      /* w[4] is the SpvGroupOperation; w[5] is the value operand.  A cluster
       * size of 0 means "whole subgroup".
       */
      nir_intrinsic_op op;
      unsigned cluster_size = 0;
      switch ((SpvGroupOperation)w[4]) {
      case SpvGroupOperationReduce:
         op = nir_intrinsic_reduce;
         break;
      case SpvGroupOperationInclusiveScan:
         op = nir_intrinsic_inclusive_scan;
         break;
      case SpvGroupOperationExclusiveScan:
         op = nir_intrinsic_exclusive_scan;
         break;
      case SpvGroupOperationClusteredReduce:
         op = nir_intrinsic_reduce;
         /* Clustered reduce carries an extra constant word: the cluster
          * size at w[6].
          */
         assert(count == 7);
         cluster_size = vtn_constant_uint(b, w[6]);
         break;
      default:
         unreachable("Invalid group operation");
      }

      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
                               NULL, reduction_op, cluster_size);
      break;
   }

   default:
      unreachable("Invalid SPIR-V opcode");
   }
}
386