/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */
23b8e80941Smrg
24b8e80941Smrg#include <math.h>
25b8e80941Smrg#include "vtn_private.h"
26b8e80941Smrg#include "spirv_info.h"
27b8e80941Smrg
/*
 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
 * definition. But for matrix multiplies, we want to do one routine for
 * multiplying a matrix by a matrix and then pretend that vectors are matrices
 * with one column. So we "wrap" these things, and unwrap the result before we
 * send it off.
 */
35b8e80941Smrg
36b8e80941Smrgstatic struct vtn_ssa_value *
37b8e80941Smrgwrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
38b8e80941Smrg{
39b8e80941Smrg   if (val == NULL)
40b8e80941Smrg      return NULL;
41b8e80941Smrg
42b8e80941Smrg   if (glsl_type_is_matrix(val->type))
43b8e80941Smrg      return val;
44b8e80941Smrg
45b8e80941Smrg   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
46b8e80941Smrg   dest->type = val->type;
47b8e80941Smrg   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
48b8e80941Smrg   dest->elems[0] = val;
49b8e80941Smrg
50b8e80941Smrg   return dest;
51b8e80941Smrg}
52b8e80941Smrg
53b8e80941Smrgstatic struct vtn_ssa_value *
54b8e80941Smrgunwrap_matrix(struct vtn_ssa_value *val)
55b8e80941Smrg{
56b8e80941Smrg   if (glsl_type_is_matrix(val->type))
57b8e80941Smrg         return val;
58b8e80941Smrg
59b8e80941Smrg   return val->elems[0];
60b8e80941Smrg}
61b8e80941Smrg
/* Multiply two matrices, where either operand may really be a vector (it is
 * wrapped here as a one-column matrix via wrap_matrix() and the product is
 * unwrapped before returning).  If both operands carry a cached transpose,
 * the identity transpose(A) * transpose(B) = transpose(B * A) is used to
 * avoid emitting explicit transposes.
 */
static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder *b,
                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{

   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
   /* Cached transposes, if the sources were produced by OpTranspose;
    * wrap_matrix() passes NULL through when there is none.
    */
   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);

   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   /* The result has src0's row count and src1's column count; one result
    * column means the product is really a vector.
    */
   const struct glsl_type *dest_type;
   if (src1_columns > 1) {
      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
                                   src0_rows, src1_columns);
   } else {
      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
   }
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);

   dest = wrap_matrix(b, dest);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   if (src0_transpose && !src1_transpose &&
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
      /* We already have the rows of src0 and the columns of src1 available,
       * so we can just take the dot product of each row with each column to
       * get the result.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         nir_ssa_def *vec_src[4];
         for (unsigned j = 0; j < src0_rows; j++) {
            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
                                          src1->elems[i]->def);
         }
         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
      }
   } else {
      /* We don't handle the case where src1 is transposed but not src0, since
       * the general case only uses individual components of src1 so the
       * optimizer should chew through the transpose we emitted for src1.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
         dest->elems[i]->def =
            nir_fmul(&b->nb, src0->elems[0]->def,
                     nir_channel(&b->nb, src1->elems[i]->def, 0));
         for (unsigned j = 1; j < src0_columns; j++) {
            dest->elems[i]->def =
               nir_fadd(&b->nb, dest->elems[i]->def,
                        nir_fmul(&b->nb, src0->elems[j]->def,
                                 nir_channel(&b->nb, src1->elems[i]->def, j)));
         }
      }
   }

   dest = unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_ssa_transpose(b, dest);

   return dest;
}
139b8e80941Smrg
140b8e80941Smrgstatic struct vtn_ssa_value *
141b8e80941Smrgmat_times_scalar(struct vtn_builder *b,
142b8e80941Smrg                 struct vtn_ssa_value *mat,
143b8e80941Smrg                 nir_ssa_def *scalar)
144b8e80941Smrg{
145b8e80941Smrg   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
146b8e80941Smrg   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
147b8e80941Smrg      if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
148b8e80941Smrg         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
149b8e80941Smrg      else
150b8e80941Smrg         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
151b8e80941Smrg   }
152b8e80941Smrg
153b8e80941Smrg   return dest;
154b8e80941Smrg}
155b8e80941Smrg
/* Handle the ALU opcodes that can take matrix operands.  Component-wise ops
 * (FNegate/FAdd/FSub) are applied column by column; the multiply variants
 * funnel into matrix_multiply(), with OpVectorTimesMatrix rewritten as
 * transpose(src1) * src0 so only matrix-times-vector needs implementing.
 * The result is stored into dest->ssa.
 */
static void
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                      struct vtn_value *dest,
                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
{
   switch (opcode) {
   case SpvOpFNegate: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
      break;
   }

   case SpvOpFAdd: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def =
            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      break;
   }

   case SpvOpFSub: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def =
            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      break;
   }

   case SpvOpTranspose:
      dest->ssa = vtn_ssa_transpose(b, src0);
      break;

   case SpvOpMatrixTimesScalar:
      /* If src0 is itself a lazy transpose, scale the cached transpose and
       * re-transpose the result instead of materializing src0 first.
       */
      if (src0->transposed) {
         dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
                                                           src1->def));
      } else {
         dest->ssa = mat_times_scalar(b, src0, src1->def);
      }
      break;

   case SpvOpVectorTimesMatrix:
   case SpvOpMatrixTimesVector:
   case SpvOpMatrixTimesMatrix:
      if (opcode == SpvOpVectorTimesMatrix) {
         /* v * M == transpose(M) * v */
         dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
      } else {
         dest->ssa = matrix_multiply(b, src0, src1);
      }
      break;

   default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
   }
}
214b8e80941Smrg
/* Map a SPIR-V ALU opcode onto the corresponding nir_op.
 *
 * *swap is set to true when NIR only has the mirrored comparison (e.g.
 * greater-than is implemented as less-than with swapped operands), in which
 * case the caller must exchange the first two sources.  src_bit_size and
 * dst_bit_size are only used for the conversion opcodes, where the nir_op
 * depends on both operand widths.  Unknown opcodes hit vtn_fail().
 */
nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
                                SpvOp opcode, bool *swap,
                                unsigned src_bit_size, unsigned dst_bit_size)
{
   /* Indicates that the first two arguments should be swapped.  This is
    * used for implementing greater-than and less-than-or-equal.
    */
   *swap = false;

   switch (opcode) {
   case SpvOpSNegate:            return nir_op_ineg;
   case SpvOpFNegate:            return nir_op_fneg;
   case SpvOpNot:                return nir_op_inot;
   case SpvOpIAdd:               return nir_op_iadd;
   case SpvOpFAdd:               return nir_op_fadd;
   case SpvOpISub:               return nir_op_isub;
   case SpvOpFSub:               return nir_op_fsub;
   case SpvOpIMul:               return nir_op_imul;
   case SpvOpFMul:               return nir_op_fmul;
   case SpvOpUDiv:               return nir_op_udiv;
   case SpvOpSDiv:               return nir_op_idiv;
   case SpvOpFDiv:               return nir_op_fdiv;
   case SpvOpUMod:               return nir_op_umod;
   case SpvOpSMod:               return nir_op_imod;
   case SpvOpFMod:               return nir_op_fmod;
   case SpvOpSRem:               return nir_op_irem;
   case SpvOpFRem:               return nir_op_frem;

   case SpvOpShiftRightLogical:     return nir_op_ushr;
   case SpvOpShiftRightArithmetic:  return nir_op_ishr;
   case SpvOpShiftLeftLogical:      return nir_op_ishl;
   case SpvOpLogicalOr:             return nir_op_ior;
   case SpvOpLogicalEqual:          return nir_op_ieq;
   case SpvOpLogicalNotEqual:       return nir_op_ine;
   case SpvOpLogicalAnd:            return nir_op_iand;
   case SpvOpLogicalNot:            return nir_op_inot;
   case SpvOpBitwiseOr:             return nir_op_ior;
   case SpvOpBitwiseXor:            return nir_op_ixor;
   case SpvOpBitwiseAnd:            return nir_op_iand;
   case SpvOpSelect:                return nir_op_bcsel;
   case SpvOpIEqual:                return nir_op_ieq;

   case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
   case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
   case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
   case SpvOpBitReverse:            return nir_op_bitfield_reverse;
   case SpvOpBitCount:              return nir_op_bit_count;

   /* The ordered / unordered operators need special implementation besides
    * the logical operator to use since they also need to check if operands are
    * ordered.
    */
   case SpvOpFOrdEqual:                            return nir_op_feq;
   case SpvOpFUnordEqual:                          return nir_op_feq;
   case SpvOpINotEqual:                            return nir_op_ine;
   case SpvOpFOrdNotEqual:                         return nir_op_fne;
   case SpvOpFUnordNotEqual:                       return nir_op_fne;
   case SpvOpULessThan:                            return nir_op_ult;
   case SpvOpSLessThan:                            return nir_op_ilt;
   case SpvOpFOrdLessThan:                         return nir_op_flt;
   case SpvOpFUnordLessThan:                       return nir_op_flt;
   case SpvOpUGreaterThan:          *swap = true;  return nir_op_ult;
   case SpvOpSGreaterThan:          *swap = true;  return nir_op_ilt;
   case SpvOpFOrdGreaterThan:       *swap = true;  return nir_op_flt;
   case SpvOpFUnordGreaterThan:     *swap = true;  return nir_op_flt;
   case SpvOpULessThanEqual:        *swap = true;  return nir_op_uge;
   case SpvOpSLessThanEqual:        *swap = true;  return nir_op_ige;
   case SpvOpFOrdLessThanEqual:     *swap = true;  return nir_op_fge;
   case SpvOpFUnordLessThanEqual:   *swap = true;  return nir_op_fge;
   case SpvOpUGreaterThanEqual:                    return nir_op_uge;
   case SpvOpSGreaterThanEqual:                    return nir_op_ige;
   case SpvOpFOrdGreaterThanEqual:                 return nir_op_fge;
   case SpvOpFUnordGreaterThanEqual:               return nir_op_fge;

   /* Conversions: */
   case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert: {
      nir_alu_type src_type;
      nir_alu_type dst_type;

      /* Pick the base (unsized) nir_alu_types for source and destination. */
      switch (opcode) {
      case SpvOpConvertFToS:
         src_type = nir_type_float;
         dst_type = nir_type_int;
         break;
      case SpvOpConvertFToU:
         src_type = nir_type_float;
         dst_type = nir_type_uint;
         break;
      case SpvOpFConvert:
         src_type = dst_type = nir_type_float;
         break;
      case SpvOpConvertSToF:
         src_type = nir_type_int;
         dst_type = nir_type_float;
         break;
      case SpvOpSConvert:
         src_type = dst_type = nir_type_int;
         break;
      case SpvOpConvertUToF:
         src_type = nir_type_uint;
         dst_type = nir_type_float;
         break;
      case SpvOpUConvert:
         src_type = dst_type = nir_type_uint;
         break;
      default:
         unreachable("Invalid opcode");
      }
      /* nir_alu_type encodes the bit size in the low bits, so OR it in. */
      src_type |= src_bit_size;
      dst_type |= dst_bit_size;
      return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
   }
   /* Derivatives: */
   case SpvOpDPdx:         return nir_op_fddx;
   case SpvOpDPdy:         return nir_op_fddy;
   case SpvOpDPdxFine:     return nir_op_fddx_fine;
   case SpvOpDPdyFine:     return nir_op_fddy_fine;
   case SpvOpDPdxCoarse:   return nir_op_fddx_coarse;
   case SpvOpDPdyCoarse:   return nir_op_fddy_coarse;

   default:
      vtn_fail("No NIR equivalent: %u", opcode);
   }
}
347b8e80941Smrg
348b8e80941Smrgstatic void
349b8e80941Smrghandle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
350b8e80941Smrg                      const struct vtn_decoration *dec, void *_void)
351b8e80941Smrg{
352b8e80941Smrg   vtn_assert(dec->scope == VTN_DEC_DECORATION);
353b8e80941Smrg   if (dec->decoration != SpvDecorationNoContraction)
354b8e80941Smrg      return;
355b8e80941Smrg
356b8e80941Smrg   b->nb.exact = true;
357b8e80941Smrg}
358b8e80941Smrg
359b8e80941Smrgstatic void
360b8e80941Smrghandle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int member,
361b8e80941Smrg                     const struct vtn_decoration *dec, void *_out_rounding_mode)
362b8e80941Smrg{
363b8e80941Smrg   nir_rounding_mode *out_rounding_mode = _out_rounding_mode;
364b8e80941Smrg   assert(dec->scope == VTN_DEC_DECORATION);
365b8e80941Smrg   if (dec->decoration != SpvDecorationFPRoundingMode)
366b8e80941Smrg      return;
367b8e80941Smrg   switch (dec->operands[0]) {
368b8e80941Smrg   case SpvFPRoundingModeRTE:
369b8e80941Smrg      *out_rounding_mode = nir_rounding_mode_rtne;
370b8e80941Smrg      break;
371b8e80941Smrg   case SpvFPRoundingModeRTZ:
372b8e80941Smrg      *out_rounding_mode = nir_rounding_mode_rtz;
373b8e80941Smrg      break;
374b8e80941Smrg   default:
375b8e80941Smrg      unreachable("Not supported rounding mode");
376b8e80941Smrg      break;
377b8e80941Smrg   }
378b8e80941Smrg}
379b8e80941Smrg
/* Top-level handler for SPIR-V ALU instructions.  w[1] is the result type id,
 * w[2] the result id, and w[3..] the operand ids.  Matrix operands are routed
 * to vtn_handle_matrix_alu(); opcodes with no direct NIR equivalent (Any/All,
 * extended multiplies, Fwidth, unordered comparisons, ...) are open-coded
 * here; everything else falls through to the generic nir_op mapping.  The
 * builder's "exact" flag is restored from b->exact before returning, since
 * handle_no_contraction may have set it for this instruction only.
 */
void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
               const uint32_t *w, unsigned count)
{
   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
   const struct glsl_type *type =
      vtn_value(b, w[1], vtn_value_type_type)->type->type;

   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);

   /* Collect the various SSA sources */
   const unsigned num_inputs = count - 3;
   struct vtn_ssa_value *vtn_src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++)
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);

   if (glsl_type_is_matrix(vtn_src[0]->type) ||
       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
      vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
      b->nb.exact = b->exact;
      return;
   }

   val->ssa = vtn_create_ssa_value(b, type);
   nir_ssa_def *src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
      src[i] = vtn_src[i]->def;
   }

   switch (opcode) {
   case SpvOpAny:
      /* Scalar "any" is the identity; vectors reduce via bany_inequalN
       * against false (i.e. "is any component != 0").
       */
      if (src[0]->num_components == 1) {
         val->ssa->def = nir_imov(&b->nb, src[0]);
      } else {
         nir_op op;
         switch (src[0]->num_components) {
         case 2:  op = nir_op_bany_inequal2; break;
         case 3:  op = nir_op_bany_inequal3; break;
         case 4:  op = nir_op_bany_inequal4; break;
         default: vtn_fail("invalid number of components");
         }
         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                       nir_imm_false(&b->nb),
                                       NULL, NULL);
      }
      break;

   case SpvOpAll:
      /* Scalar "all" is the identity; vectors reduce via ball_iequalN
       * against true (i.e. "are all components == true").
       */
      if (src[0]->num_components == 1) {
         val->ssa->def = nir_imov(&b->nb, src[0]);
      } else {
         nir_op op;
         switch (src[0]->num_components) {
         case 2:  op = nir_op_ball_iequal2;  break;
         case 3:  op = nir_op_ball_iequal3;  break;
         case 4:  op = nir_op_ball_iequal4;  break;
         default: vtn_fail("invalid number of components");
         }
         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                       nir_imm_true(&b->nb),
                                       NULL, NULL);
      }
      break;

   case SpvOpOuterProduct: {
      /* Column i of the result is src[0] scaled by component i of src[1]. */
      for (unsigned i = 0; i < src[1]->num_components; i++) {
         val->ssa->elems[i]->def =
            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
      }
      break;
   }

   case SpvOpDot:
      val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
      break;

   case SpvOpIAddCarry:
      /* Result is a two-member struct: { sum, carry }. */
      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
      val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
      break;

   case SpvOpISubBorrow:
      /* Result is a two-member struct: { difference, borrow }. */
      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
      val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
      break;

   case SpvOpUMulExtended: {
      /* 32x32->64 unsigned multiply, then split into { low, high } halves. */
      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
      nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
      val->ssa->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
      val->ssa->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
      break;
   }

   case SpvOpSMulExtended: {
      /* 32x32->64 signed multiply, then split into { low, high } halves. */
      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
      nir_ssa_def *smul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
      val->ssa->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, smul);
      val->ssa->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, smul);
      break;
   }

   /* Fwidth = |ddx| + |ddy|, at each derivative precision level. */
   case SpvOpFwidth:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
      break;
   case SpvOpFwidthFine:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
      break;
   case SpvOpFwidthCoarse:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
      break;

   case SpvOpVectorTimesScalar:
      /* The builder will take care of splatting for us. */
      val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
      break;

   case SpvOpIsNan:
      /* NaN is the only value not equal to itself. */
      val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
      break;

   case SpvOpIsInf: {
      /* |x| == +inf, compared bitwise via ieq. */
      nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
      val->ssa->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
      break;
   }

   case SpvOpFUnordEqual:
   case SpvOpFUnordNotEqual:
   case SpvOpFUnordLessThan:
   case SpvOpFUnordGreaterThan:
   case SpvOpFUnordLessThanEqual:
   case SpvOpFUnordGreaterThanEqual: {
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      /* Unordered comparisons are true when either operand is NaN, so OR
       * the ordered comparison with per-operand NaN checks.
       */
      val->ssa->def =
         nir_ior(&b->nb,
                 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                 nir_ior(&b->nb,
                         nir_fne(&b->nb, src[0], src[0]),
                         nir_fne(&b->nb, src[1], src[1])));
      break;
   }

   case SpvOpFOrdNotEqual: {
      /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
       * from the ALU will probably already be false if the operands are not
       * ordered so we don’t need to handle it specially.
       */
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      assert(!swap);

      /* Ordered not-equal: fne AND both operands are not NaN. */
      val->ssa->def =
         nir_iand(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                  nir_iand(&b->nb,
                          nir_feq(&b->nb, src[0], src[0]),
                          nir_feq(&b->nb, src[1], src[1])));
      break;
   }

   case SpvOpFConvert: {
      /* Float-to-float conversion honors an optional FPRoundingMode
       * decoration on the result value.
       */
      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
      nir_rounding_mode rounding_mode = nir_rounding_mode_undef;

      vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
      nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type, rounding_mode);

      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
      break;
   }

   case SpvOpBitFieldInsert:
   case SpvOpBitFieldSExtract:
   case SpvOpBitFieldUExtract:
   case SpvOpShiftLeftLogical:
   case SpvOpShiftRightArithmetic:
   case SpvOpShiftRightLogical: {
      bool swap;
      unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src0_bit_size, dst_bit_size);

      assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
              op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
              op == nir_op_ibitfield_extract);

      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
         unsigned src_bit_size =
            nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
         if (src_bit_size == 0)
            continue;
         if (src_bit_size != src[i]->bit_size) {
            assert(src_bit_size == 32);
            /* Convert the Shift, Offset and Count  operands to 32 bits, which is the bitsize
             * supported by the NIR instructions. See discussion here:
             *
             * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
             */
            src[i] = nir_u2u32(&b->nb, src[i]);
         }
      }
      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   }

   case SpvOpSignBitSet: {
      /* Shift the sign bit down to bit 0 (logical shift for scalars,
       * arithmetic for vectors — i2b only cares about zero/non-zero), then
       * booleanize.
       */
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      if (src[0]->num_components == 1)
         val->ssa->def =
            nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src_bit_size - 1));
      else
         val->ssa->def =
            nir_ishr(&b->nb, src[0], nir_imm_int(&b->nb, src_bit_size - 1));

      val->ssa->def = nir_i2b(&b->nb, val->ssa->def);
      break;
   }

   default: {
      /* Generic path: map the opcode straight to a nir_op. */
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      /* NIR shift counts are always 32-bit regardless of the value width. */
      switch (op) {
      case nir_op_ishl:
      case nir_op_ishr:
      case nir_op_ushr:
         if (src[1]->bit_size != 32)
            src[1] = nir_u2u32(&b->nb, src[1]);
         break;
      default:
         break;
      }

      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   } /* default */
   }

   b->nb.exact = b->exact;
}
656b8e80941Smrg
/* Handle OpBitcast: reinterpret the bits of a vector/scalar as another
 * vector/scalar type with the same total bit count.  w[1] is the result type
 * id, w[2] the result id, and w[3] the operand id.
 */
void
vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
{
   vtn_assert(count == 4);
   /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
    *
    *    "If Result Type has the same number of components as Operand, they
    *    must also have the same component width, and results are computed per
    *    component.
    *
    *    If Result Type has a different number of components than Operand, the
    *    total number of bits in Result Type must equal the total number of
    *    bits in Operand. Let L be the type, either Result Type or Operand’s
    *    type, that has the larger number of components. Let S be the other
    *    type, with the smaller number of components. The number of components
    *    in L must be an integer multiple of the number of components in S.
    *    The first component (that is, the only or lowest-numbered component)
    *    of S maps to the first components of L, and so on, up to the last
    *    component of S mapping to the last components of L. Within this
    *    mapping, any single component of S (mapping to multiple components of
    *    L) maps its lower-ordered bits to the lower-numbered components of L."
    */

   struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
   struct vtn_ssa_value *vtn_src = vtn_ssa_value(b, w[3]);
   struct nir_ssa_def *src = vtn_src->def;
   struct vtn_ssa_value *val = vtn_create_ssa_value(b, type->type);

   vtn_assert(glsl_type_is_vector_or_scalar(vtn_src->type));

   /* Reject casts that would change the total bit count (invalid SPIR-V). */
   vtn_fail_if(src->num_components * src->bit_size !=
               glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
               "Source and destination of OpBitcast must have the same "
               "total number of bits");
   val->def = nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
   vtn_push_ssa(b, w[2], type, val);
}
694