1b8e80941Smrg/*
2b8e80941Smrg * Copyright © 2016 Intel Corporation
3b8e80941Smrg *
4b8e80941Smrg * Permission is hereby granted, free of charge, to any person obtaining a
5b8e80941Smrg * copy of this software and associated documentation files (the "Software"),
6b8e80941Smrg * to deal in the Software without restriction, including without limitation
7b8e80941Smrg * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8b8e80941Smrg * and/or sell copies of the Software, and to permit persons to whom the
9b8e80941Smrg * Software is furnished to do so, subject to the following conditions:
10b8e80941Smrg *
11b8e80941Smrg * The above copyright notice and this permission notice (including the next
12b8e80941Smrg * paragraph) shall be included in all copies or substantial portions of the
13b8e80941Smrg * Software.
14b8e80941Smrg *
15b8e80941Smrg * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16b8e80941Smrg * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17b8e80941Smrg * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18b8e80941Smrg * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19b8e80941Smrg * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20b8e80941Smrg * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21b8e80941Smrg * IN THE SOFTWARE.
22b8e80941Smrg */
23b8e80941Smrg
24b8e80941Smrg#include "vtn_private.h"
25b8e80941Smrg
26b8e80941Smrgstatic void
27b8e80941Smrgvtn_build_subgroup_instr(struct vtn_builder *b,
28b8e80941Smrg                         nir_intrinsic_op nir_op,
29b8e80941Smrg                         struct vtn_ssa_value *dst,
30b8e80941Smrg                         struct vtn_ssa_value *src0,
31b8e80941Smrg                         nir_ssa_def *index,
32b8e80941Smrg                         unsigned const_idx0,
33b8e80941Smrg                         unsigned const_idx1)
34b8e80941Smrg{
35b8e80941Smrg   /* Some of the subgroup operations take an index.  SPIR-V allows this to be
36b8e80941Smrg    * any integer type.  To make things simpler for drivers, we only support
37b8e80941Smrg    * 32-bit indices.
38b8e80941Smrg    */
39b8e80941Smrg   if (index && index->bit_size != 32)
40b8e80941Smrg      index = nir_u2u32(&b->nb, index);
41b8e80941Smrg
42b8e80941Smrg   vtn_assert(dst->type == src0->type);
43b8e80941Smrg   if (!glsl_type_is_vector_or_scalar(dst->type)) {
44b8e80941Smrg      for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
45b8e80941Smrg         vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
46b8e80941Smrg                                  src0->elems[i], index,
47b8e80941Smrg                                  const_idx0, const_idx1);
48b8e80941Smrg      }
49b8e80941Smrg      return;
50b8e80941Smrg   }
51b8e80941Smrg
52b8e80941Smrg   nir_intrinsic_instr *intrin =
53b8e80941Smrg      nir_intrinsic_instr_create(b->nb.shader, nir_op);
54b8e80941Smrg   nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
55b8e80941Smrg                              dst->type, NULL);
56b8e80941Smrg   intrin->num_components = intrin->dest.ssa.num_components;
57b8e80941Smrg
58b8e80941Smrg   intrin->src[0] = nir_src_for_ssa(src0->def);
59b8e80941Smrg   if (index)
60b8e80941Smrg      intrin->src[1] = nir_src_for_ssa(index);
61b8e80941Smrg
62b8e80941Smrg   intrin->const_index[0] = const_idx0;
63b8e80941Smrg   intrin->const_index[1] = const_idx1;
64b8e80941Smrg
65b8e80941Smrg   nir_builder_instr_insert(&b->nb, &intrin->instr);
66b8e80941Smrg
67b8e80941Smrg   dst->def = &intrin->dest.ssa;
68b8e80941Smrg}
69b8e80941Smrg
70b8e80941Smrgvoid
71b8e80941Smrgvtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
72b8e80941Smrg                    const uint32_t *w, unsigned count)
73b8e80941Smrg{
74b8e80941Smrg   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
75b8e80941Smrg
76b8e80941Smrg   val->ssa = vtn_create_ssa_value(b, val->type->type);
77b8e80941Smrg
78b8e80941Smrg   switch (opcode) {
79b8e80941Smrg   case SpvOpGroupNonUniformElect: {
80b8e80941Smrg      vtn_fail_if(val->type->type != glsl_bool_type(),
81b8e80941Smrg                  "OpGroupNonUniformElect must return a Bool");
82b8e80941Smrg      nir_intrinsic_instr *elect =
83b8e80941Smrg         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
84b8e80941Smrg      nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
85b8e80941Smrg                                 val->type->type, NULL);
86b8e80941Smrg      nir_builder_instr_insert(&b->nb, &elect->instr);
87b8e80941Smrg      val->ssa->def = &elect->dest.ssa;
88b8e80941Smrg      break;
89b8e80941Smrg   }
90b8e80941Smrg
91b8e80941Smrg   case SpvOpGroupNonUniformBallot: {
92b8e80941Smrg      vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
93b8e80941Smrg                  "OpGroupNonUniformBallot must return a uvec4");
94b8e80941Smrg      nir_intrinsic_instr *ballot =
95b8e80941Smrg         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
96b8e80941Smrg      ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
97b8e80941Smrg      nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
98b8e80941Smrg      ballot->num_components = 4;
99b8e80941Smrg      nir_builder_instr_insert(&b->nb, &ballot->instr);
100b8e80941Smrg      val->ssa->def = &ballot->dest.ssa;
101b8e80941Smrg      break;
102b8e80941Smrg   }
103b8e80941Smrg
104b8e80941Smrg   case SpvOpGroupNonUniformInverseBallot: {
105b8e80941Smrg      /* This one is just a BallotBitfieldExtract with subgroup invocation.
106b8e80941Smrg       * We could add a NIR intrinsic but it's easier to just lower it on the
107b8e80941Smrg       * spot.
108b8e80941Smrg       */
109b8e80941Smrg      nir_intrinsic_instr *intrin =
110b8e80941Smrg         nir_intrinsic_instr_create(b->nb.shader,
111b8e80941Smrg                                    nir_intrinsic_ballot_bitfield_extract);
112b8e80941Smrg
113b8e80941Smrg      intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
114b8e80941Smrg      intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));
115b8e80941Smrg
116b8e80941Smrg      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
117b8e80941Smrg                                 val->type->type, NULL);
118b8e80941Smrg      nir_builder_instr_insert(&b->nb, &intrin->instr);
119b8e80941Smrg
120b8e80941Smrg      val->ssa->def = &intrin->dest.ssa;
121b8e80941Smrg      break;
122b8e80941Smrg   }
123b8e80941Smrg
124b8e80941Smrg   case SpvOpGroupNonUniformBallotBitExtract:
125b8e80941Smrg   case SpvOpGroupNonUniformBallotBitCount:
126b8e80941Smrg   case SpvOpGroupNonUniformBallotFindLSB:
127b8e80941Smrg   case SpvOpGroupNonUniformBallotFindMSB: {
128b8e80941Smrg      nir_ssa_def *src0, *src1 = NULL;
129b8e80941Smrg      nir_intrinsic_op op;
130b8e80941Smrg      switch (opcode) {
131b8e80941Smrg      case SpvOpGroupNonUniformBallotBitExtract:
132b8e80941Smrg         op = nir_intrinsic_ballot_bitfield_extract;
133b8e80941Smrg         src0 = vtn_ssa_value(b, w[4])->def;
134b8e80941Smrg         src1 = vtn_ssa_value(b, w[5])->def;
135b8e80941Smrg         break;
136b8e80941Smrg      case SpvOpGroupNonUniformBallotBitCount:
137b8e80941Smrg         switch ((SpvGroupOperation)w[4]) {
138b8e80941Smrg         case SpvGroupOperationReduce:
139b8e80941Smrg            op = nir_intrinsic_ballot_bit_count_reduce;
140b8e80941Smrg            break;
141b8e80941Smrg         case SpvGroupOperationInclusiveScan:
142b8e80941Smrg            op = nir_intrinsic_ballot_bit_count_inclusive;
143b8e80941Smrg            break;
144b8e80941Smrg         case SpvGroupOperationExclusiveScan:
145b8e80941Smrg            op = nir_intrinsic_ballot_bit_count_exclusive;
146b8e80941Smrg            break;
147b8e80941Smrg         default:
148b8e80941Smrg            unreachable("Invalid group operation");
149b8e80941Smrg         }
150b8e80941Smrg         src0 = vtn_ssa_value(b, w[5])->def;
151b8e80941Smrg         break;
152b8e80941Smrg      case SpvOpGroupNonUniformBallotFindLSB:
153b8e80941Smrg         op = nir_intrinsic_ballot_find_lsb;
154b8e80941Smrg         src0 = vtn_ssa_value(b, w[4])->def;
155b8e80941Smrg         break;
156b8e80941Smrg      case SpvOpGroupNonUniformBallotFindMSB:
157b8e80941Smrg         op = nir_intrinsic_ballot_find_msb;
158b8e80941Smrg         src0 = vtn_ssa_value(b, w[4])->def;
159b8e80941Smrg         break;
160b8e80941Smrg      default:
161b8e80941Smrg         unreachable("Unhandled opcode");
162b8e80941Smrg      }
163b8e80941Smrg
164b8e80941Smrg      nir_intrinsic_instr *intrin =
165b8e80941Smrg         nir_intrinsic_instr_create(b->nb.shader, op);
166b8e80941Smrg
167b8e80941Smrg      intrin->src[0] = nir_src_for_ssa(src0);
168b8e80941Smrg      if (src1)
169b8e80941Smrg         intrin->src[1] = nir_src_for_ssa(src1);
170b8e80941Smrg
171b8e80941Smrg      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
172b8e80941Smrg                                 val->type->type, NULL);
173b8e80941Smrg      nir_builder_instr_insert(&b->nb, &intrin->instr);
174b8e80941Smrg
175b8e80941Smrg      val->ssa->def = &intrin->dest.ssa;
176b8e80941Smrg      break;
177b8e80941Smrg   }
178b8e80941Smrg
179b8e80941Smrg   case SpvOpGroupNonUniformBroadcastFirst:
180b8e80941Smrg      vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
181b8e80941Smrg                               val->ssa, vtn_ssa_value(b, w[4]), NULL, 0, 0);
182b8e80941Smrg      break;
183b8e80941Smrg
184b8e80941Smrg   case SpvOpGroupNonUniformBroadcast:
185b8e80941Smrg      vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
186b8e80941Smrg                               val->ssa, vtn_ssa_value(b, w[4]),
187b8e80941Smrg                               vtn_ssa_value(b, w[5])->def, 0, 0);
188b8e80941Smrg      break;
189b8e80941Smrg
190b8e80941Smrg   case SpvOpGroupNonUniformAll:
191b8e80941Smrg   case SpvOpGroupNonUniformAny:
192b8e80941Smrg   case SpvOpGroupNonUniformAllEqual: {
193b8e80941Smrg      vtn_fail_if(val->type->type != glsl_bool_type(),
194b8e80941Smrg                  "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
195b8e80941Smrg      nir_intrinsic_op op;
196b8e80941Smrg      switch (opcode) {
197b8e80941Smrg      case SpvOpGroupNonUniformAll:
198b8e80941Smrg         op = nir_intrinsic_vote_all;
199b8e80941Smrg         break;
200b8e80941Smrg      case SpvOpGroupNonUniformAny:
201b8e80941Smrg         op = nir_intrinsic_vote_any;
202b8e80941Smrg         break;
203b8e80941Smrg      case SpvOpGroupNonUniformAllEqual: {
204b8e80941Smrg         switch (glsl_get_base_type(val->type->type)) {
205b8e80941Smrg         case GLSL_TYPE_FLOAT:
206b8e80941Smrg         case GLSL_TYPE_DOUBLE:
207b8e80941Smrg            op = nir_intrinsic_vote_feq;
208b8e80941Smrg            break;
209b8e80941Smrg         case GLSL_TYPE_UINT:
210b8e80941Smrg         case GLSL_TYPE_INT:
211b8e80941Smrg         case GLSL_TYPE_UINT64:
212b8e80941Smrg         case GLSL_TYPE_INT64:
213b8e80941Smrg         case GLSL_TYPE_BOOL:
214b8e80941Smrg            op = nir_intrinsic_vote_ieq;
215b8e80941Smrg            break;
216b8e80941Smrg         default:
217b8e80941Smrg            unreachable("Unhandled type");
218b8e80941Smrg         }
219b8e80941Smrg         break;
220b8e80941Smrg      }
221b8e80941Smrg      default:
222b8e80941Smrg         unreachable("Unhandled opcode");
223b8e80941Smrg      }
224b8e80941Smrg
225b8e80941Smrg      nir_ssa_def *src0 = vtn_ssa_value(b, w[4])->def;
226b8e80941Smrg
227b8e80941Smrg      nir_intrinsic_instr *intrin =
228b8e80941Smrg         nir_intrinsic_instr_create(b->nb.shader, op);
229b8e80941Smrg      intrin->num_components = src0->num_components;
230b8e80941Smrg      intrin->src[0] = nir_src_for_ssa(src0);
231b8e80941Smrg      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
232b8e80941Smrg                                 val->type->type, NULL);
233b8e80941Smrg      nir_builder_instr_insert(&b->nb, &intrin->instr);
234b8e80941Smrg
235b8e80941Smrg      val->ssa->def = &intrin->dest.ssa;
236b8e80941Smrg      break;
237b8e80941Smrg   }
238b8e80941Smrg
239b8e80941Smrg   case SpvOpGroupNonUniformShuffle:
240b8e80941Smrg   case SpvOpGroupNonUniformShuffleXor:
241b8e80941Smrg   case SpvOpGroupNonUniformShuffleUp:
242b8e80941Smrg   case SpvOpGroupNonUniformShuffleDown: {
243b8e80941Smrg      nir_intrinsic_op op;
244b8e80941Smrg      switch (opcode) {
245b8e80941Smrg      case SpvOpGroupNonUniformShuffle:
246b8e80941Smrg         op = nir_intrinsic_shuffle;
247b8e80941Smrg         break;
248b8e80941Smrg      case SpvOpGroupNonUniformShuffleXor:
249b8e80941Smrg         op = nir_intrinsic_shuffle_xor;
250b8e80941Smrg         break;
251b8e80941Smrg      case SpvOpGroupNonUniformShuffleUp:
252b8e80941Smrg         op = nir_intrinsic_shuffle_up;
253b8e80941Smrg         break;
254b8e80941Smrg      case SpvOpGroupNonUniformShuffleDown:
255b8e80941Smrg         op = nir_intrinsic_shuffle_down;
256b8e80941Smrg         break;
257b8e80941Smrg      default:
258b8e80941Smrg         unreachable("Invalid opcode");
259b8e80941Smrg      }
260b8e80941Smrg      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
261b8e80941Smrg                               vtn_ssa_value(b, w[5])->def, 0, 0);
262b8e80941Smrg      break;
263b8e80941Smrg   }
264b8e80941Smrg
265b8e80941Smrg   case SpvOpGroupNonUniformQuadBroadcast:
266b8e80941Smrg      vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
267b8e80941Smrg                               val->ssa, vtn_ssa_value(b, w[4]),
268b8e80941Smrg                               vtn_ssa_value(b, w[5])->def, 0, 0);
269b8e80941Smrg      break;
270b8e80941Smrg
271b8e80941Smrg   case SpvOpGroupNonUniformQuadSwap: {
272b8e80941Smrg      unsigned direction = vtn_constant_uint(b, w[5]);
273b8e80941Smrg      nir_intrinsic_op op;
274b8e80941Smrg      switch (direction) {
275b8e80941Smrg      case 0:
276b8e80941Smrg         op = nir_intrinsic_quad_swap_horizontal;
277b8e80941Smrg         break;
278b8e80941Smrg      case 1:
279b8e80941Smrg         op = nir_intrinsic_quad_swap_vertical;
280b8e80941Smrg         break;
281b8e80941Smrg      case 2:
282b8e80941Smrg         op = nir_intrinsic_quad_swap_diagonal;
283b8e80941Smrg         break;
284b8e80941Smrg      default:
285b8e80941Smrg         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
286b8e80941Smrg      }
287b8e80941Smrg      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
288b8e80941Smrg                               NULL, 0, 0);
289b8e80941Smrg      break;
290b8e80941Smrg   }
291b8e80941Smrg
292b8e80941Smrg   case SpvOpGroupNonUniformIAdd:
293b8e80941Smrg   case SpvOpGroupNonUniformFAdd:
294b8e80941Smrg   case SpvOpGroupNonUniformIMul:
295b8e80941Smrg   case SpvOpGroupNonUniformFMul:
296b8e80941Smrg   case SpvOpGroupNonUniformSMin:
297b8e80941Smrg   case SpvOpGroupNonUniformUMin:
298b8e80941Smrg   case SpvOpGroupNonUniformFMin:
299b8e80941Smrg   case SpvOpGroupNonUniformSMax:
300b8e80941Smrg   case SpvOpGroupNonUniformUMax:
301b8e80941Smrg   case SpvOpGroupNonUniformFMax:
302b8e80941Smrg   case SpvOpGroupNonUniformBitwiseAnd:
303b8e80941Smrg   case SpvOpGroupNonUniformBitwiseOr:
304b8e80941Smrg   case SpvOpGroupNonUniformBitwiseXor:
305b8e80941Smrg   case SpvOpGroupNonUniformLogicalAnd:
306b8e80941Smrg   case SpvOpGroupNonUniformLogicalOr:
307b8e80941Smrg   case SpvOpGroupNonUniformLogicalXor: {
308b8e80941Smrg      nir_op reduction_op;
309b8e80941Smrg      switch (opcode) {
310b8e80941Smrg      case SpvOpGroupNonUniformIAdd:
311b8e80941Smrg         reduction_op = nir_op_iadd;
312b8e80941Smrg         break;
313b8e80941Smrg      case SpvOpGroupNonUniformFAdd:
314b8e80941Smrg         reduction_op = nir_op_fadd;
315b8e80941Smrg         break;
316b8e80941Smrg      case SpvOpGroupNonUniformIMul:
317b8e80941Smrg         reduction_op = nir_op_imul;
318b8e80941Smrg         break;
319b8e80941Smrg      case SpvOpGroupNonUniformFMul:
320b8e80941Smrg         reduction_op = nir_op_fmul;
321b8e80941Smrg         break;
322b8e80941Smrg      case SpvOpGroupNonUniformSMin:
323b8e80941Smrg         reduction_op = nir_op_imin;
324b8e80941Smrg         break;
325b8e80941Smrg      case SpvOpGroupNonUniformUMin:
326b8e80941Smrg         reduction_op = nir_op_umin;
327b8e80941Smrg         break;
328b8e80941Smrg      case SpvOpGroupNonUniformFMin:
329b8e80941Smrg         reduction_op = nir_op_fmin;
330b8e80941Smrg         break;
331b8e80941Smrg      case SpvOpGroupNonUniformSMax:
332b8e80941Smrg         reduction_op = nir_op_imax;
333b8e80941Smrg         break;
334b8e80941Smrg      case SpvOpGroupNonUniformUMax:
335b8e80941Smrg         reduction_op = nir_op_umax;
336b8e80941Smrg         break;
337b8e80941Smrg      case SpvOpGroupNonUniformFMax:
338b8e80941Smrg         reduction_op = nir_op_fmax;
339b8e80941Smrg         break;
340b8e80941Smrg      case SpvOpGroupNonUniformBitwiseAnd:
341b8e80941Smrg      case SpvOpGroupNonUniformLogicalAnd:
342b8e80941Smrg         reduction_op = nir_op_iand;
343b8e80941Smrg         break;
344b8e80941Smrg      case SpvOpGroupNonUniformBitwiseOr:
345b8e80941Smrg      case SpvOpGroupNonUniformLogicalOr:
346b8e80941Smrg         reduction_op = nir_op_ior;
347b8e80941Smrg         break;
348b8e80941Smrg      case SpvOpGroupNonUniformBitwiseXor:
349b8e80941Smrg      case SpvOpGroupNonUniformLogicalXor:
350b8e80941Smrg         reduction_op = nir_op_ixor;
351b8e80941Smrg         break;
352b8e80941Smrg      default:
353b8e80941Smrg         unreachable("Invalid reduction operation");
354b8e80941Smrg      }
355b8e80941Smrg
356b8e80941Smrg      nir_intrinsic_op op;
357b8e80941Smrg      unsigned cluster_size = 0;
358b8e80941Smrg      switch ((SpvGroupOperation)w[4]) {
359b8e80941Smrg      case SpvGroupOperationReduce:
360b8e80941Smrg         op = nir_intrinsic_reduce;
361b8e80941Smrg         break;
362b8e80941Smrg      case SpvGroupOperationInclusiveScan:
363b8e80941Smrg         op = nir_intrinsic_inclusive_scan;
364b8e80941Smrg         break;
365b8e80941Smrg      case SpvGroupOperationExclusiveScan:
366b8e80941Smrg         op = nir_intrinsic_exclusive_scan;
367b8e80941Smrg         break;
368b8e80941Smrg      case SpvGroupOperationClusteredReduce:
369b8e80941Smrg         op = nir_intrinsic_reduce;
370b8e80941Smrg         assert(count == 7);
371b8e80941Smrg         cluster_size = vtn_constant_uint(b, w[6]);
372b8e80941Smrg         break;
373b8e80941Smrg      default:
374b8e80941Smrg         unreachable("Invalid group operation");
375b8e80941Smrg      }
376b8e80941Smrg
377b8e80941Smrg      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
378b8e80941Smrg                               NULL, reduction_op, cluster_size);
379b8e80941Smrg      break;
380b8e80941Smrg   }
381b8e80941Smrg
382b8e80941Smrg   default:
383b8e80941Smrg      unreachable("Invalid SPIR-V opcode");
384b8e80941Smrg   }
385b8e80941Smrg}
386