1b8e80941Smrg/* 2b8e80941Smrg * Copyright © 2016 Intel Corporation 3b8e80941Smrg * 4b8e80941Smrg * Permission is hereby granted, free of charge, to any person obtaining a 5b8e80941Smrg * copy of this software and associated documentation files (the "Software"), 6b8e80941Smrg * to deal in the Software without restriction, including without limitation 7b8e80941Smrg * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8b8e80941Smrg * and/or sell copies of the Software, and to permit persons to whom the 9b8e80941Smrg * Software is furnished to do so, subject to the following conditions: 10b8e80941Smrg * 11b8e80941Smrg * The above copyright notice and this permission notice (including the next 12b8e80941Smrg * paragraph) shall be included in all copies or substantial portions of the 13b8e80941Smrg * Software. 14b8e80941Smrg * 15b8e80941Smrg * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16b8e80941Smrg * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17b8e80941Smrg * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18b8e80941Smrg * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19b8e80941Smrg * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20b8e80941Smrg * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 21b8e80941Smrg * IN THE SOFTWARE. 22b8e80941Smrg */ 23b8e80941Smrg 24b8e80941Smrg#include <math.h> 25b8e80941Smrg#include "vtn_private.h" 26b8e80941Smrg#include "spirv_info.h" 27b8e80941Smrg 28b8e80941Smrg/* 29b8e80941Smrg * Normally, column vectors in SPIR-V correspond to a single NIR SSA 30b8e80941Smrg * definition. But for matrix multiplies, we want to do one routine for 31b8e80941Smrg * multiplying a matrix by a matrix and then pretend that vectors are matrices 32b8e80941Smrg * with one column. So we "wrap" these things, and unwrap the result before we 33b8e80941Smrg * send it off. 
 */

/* Wrap a column vector in a single-column "matrix" so the matrix routines
 * below can treat vectors and matrices uniformly.  NULL passes through
 * unchanged, which lets callers wrap optional values (e.g. a missing
 * ->transposed) without checking first.
 */
static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
{
   if (val == NULL)
      return NULL;

   if (glsl_type_is_matrix(val->type))
      return val;

   /* Note: dest->type stays the vector type; only ->elems gains the extra
    * level of indirection so dest->elems[0] is the original vector.
    */
   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
   dest->type = val->type;
   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
   dest->elems[0] = val;

   return dest;
}

/* Undo wrap_matrix(): a real matrix is returned as-is, a wrapped vector is
 * unwrapped back to its single column.
 */
static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value *val)
{
   if (glsl_type_is_matrix(val->type))
      return val;

   return val->elems[0];
}

/* Multiply two matrix-or-vector values.  Vectors are treated as one-column
 * matrices via wrap_matrix(), and the result is unwrapped again before
 * returning.  Lazily-transposed operands (->transposed) are exploited:
 * if both are transposed we use transpose(A) * transpose(B) =
 * transpose(B * A); if only src0 is transposed (float case) the product is
 * computed with per-row dot products instead of column accumulation.
 */
static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder *b,
                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{

   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);

   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   /* The result has src0's rows and src1's columns; a single-column result
    * is a plain vector, not a 1-column matrix.
    */
   const struct glsl_type *dest_type;
   if (src1_columns > 1) {
      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
                                   src0_rows, src1_columns);
   } else {
      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
   }
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);

   dest = wrap_matrix(b, dest);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   if (src0_transpose && !src1_transpose &&
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
      /* We already have the rows of src0 and the columns of src1 available,
       * so we can just take the dot product of each row with each column to
       * get the result.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         nir_ssa_def *vec_src[4];
         for (unsigned j = 0; j < src0_rows; j++) {
            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
                                          src1->elems[i]->def);
         }
         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
      }
   } else {
      /* We don't handle the case where src1 is transposed but not src0, since
       * the general case only uses individual components of src1 so the
       * optimizer should chew through the transpose we emitted for src1.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
         dest->elems[i]->def =
            nir_fmul(&b->nb, src0->elems[0]->def,
                     nir_channel(&b->nb, src1->elems[i]->def, 0));
         for (unsigned j = 1; j < src0_columns; j++) {
            dest->elems[i]->def =
               nir_fadd(&b->nb, dest->elems[i]->def,
                        nir_fmul(&b->nb, src0->elems[j]->def,
                                 nir_channel(&b->nb, src1->elems[i]->def, j)));
         }
      }
   }

   dest = unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_ssa_transpose(b, dest);

   return dest;
}

/* Multiply each column of a matrix by a scalar.  Uses imul for integer
 * base types and fmul otherwise.
 */
static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder *b,
                 struct vtn_ssa_value *mat,
                 nir_ssa_def *scalar)
{
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
      if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
      else
         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
   }

   return dest;
}

/* Handle the ALU opcodes that involve a matrix operand, writing the result
 * into dest->ssa.  src1 may be NULL for unary opcodes (FNegate, Transpose) —
 * see the caller in vtn_handle_alu().
 */
static void
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                      struct vtn_value *dest,
                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
{
   switch (opcode) {
   case SpvOpFNegate: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);

      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
      break;
   }

   case SpvOpFAdd: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def =
            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      break;
   }

   case SpvOpFSub: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def =
            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      break;
   }

   case SpvOpTranspose:
      dest->ssa = vtn_ssa_transpose(b, src0);
      break;

   case SpvOpMatrixTimesScalar:
      /* If src0 is lazily transposed, scale the untransposed form and
       * re-transpose the result rather than materializing src0 first.
       */
      if (src0->transposed) {
         dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
                                                           src1->def));
      } else {
         dest->ssa = mat_times_scalar(b, src0, src1->def);
      }
      break;

   case SpvOpVectorTimesMatrix:
   case SpvOpMatrixTimesVector:
   case SpvOpMatrixTimesMatrix:
      /* v * M == transpose(M) * v, which matrix_multiply() can handle. */
      if (opcode == SpvOpVectorTimesMatrix) {
         dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
      } else {
         dest->ssa = matrix_multiply(b, src0, src1);
      }
      break;

   default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
   }
}

/* Map a "simple" SPIR-V ALU opcode to the corresponding nir_op.  For
 * conversion opcodes the src/dst bit sizes select the exact NIR conversion
 * op.  *swap is set when the caller must exchange the first two operands
 * (greater-than is implemented as swapped less-than, etc.).  Opcodes with
 * no NIR equivalent fail SPIR-V translation.
 */
nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
                                SpvOp opcode, bool *swap,
                                unsigned src_bit_size, unsigned dst_bit_size)
{
   /* Indicates that the first two arguments should be swapped.  This is
    * used for implementing greater-than and less-than-or-equal.
    */
   *swap = false;

   switch (opcode) {
   case SpvOpSNegate:            return nir_op_ineg;
   case SpvOpFNegate:            return nir_op_fneg;
   case SpvOpNot:                return nir_op_inot;
   case SpvOpIAdd:               return nir_op_iadd;
   case SpvOpFAdd:               return nir_op_fadd;
   case SpvOpISub:               return nir_op_isub;
   case SpvOpFSub:               return nir_op_fsub;
   case SpvOpIMul:               return nir_op_imul;
   case SpvOpFMul:               return nir_op_fmul;
   case SpvOpUDiv:               return nir_op_udiv;
   case SpvOpSDiv:               return nir_op_idiv;
   case SpvOpFDiv:               return nir_op_fdiv;
   case SpvOpUMod:               return nir_op_umod;
   case SpvOpSMod:               return nir_op_imod;
   case SpvOpFMod:               return nir_op_fmod;
   case SpvOpSRem:               return nir_op_irem;
   case SpvOpFRem:               return nir_op_frem;

   case SpvOpShiftRightLogical:     return nir_op_ushr;
   case SpvOpShiftRightArithmetic:  return nir_op_ishr;
   case SpvOpShiftLeftLogical:      return nir_op_ishl;
   case SpvOpLogicalOr:             return nir_op_ior;
   case SpvOpLogicalEqual:          return nir_op_ieq;
   case SpvOpLogicalNotEqual:       return nir_op_ine;
   case SpvOpLogicalAnd:            return nir_op_iand;
   case SpvOpLogicalNot:            return nir_op_inot;
   case SpvOpBitwiseOr:             return nir_op_ior;
   case SpvOpBitwiseXor:            return nir_op_ixor;
   case SpvOpBitwiseAnd:            return nir_op_iand;
   case SpvOpSelect:                return nir_op_bcsel;
   case SpvOpIEqual:                return nir_op_ieq;

   case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
   case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
   case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
   case SpvOpBitReverse:            return nir_op_bitfield_reverse;
   case SpvOpBitCount:              return nir_op_bit_count;

   /* The ordered / unordered operators need special implementation besides
    * the logical operator to use since they also need to check if operands
    * are ordered.
    */
   case SpvOpFOrdEqual:                            return nir_op_feq;
   case SpvOpFUnordEqual:                          return nir_op_feq;
   case SpvOpINotEqual:                            return nir_op_ine;
   case SpvOpFOrdNotEqual:                         return nir_op_fne;
   case SpvOpFUnordNotEqual:                       return nir_op_fne;
   case SpvOpULessThan:                            return nir_op_ult;
   case SpvOpSLessThan:                            return nir_op_ilt;
   case SpvOpFOrdLessThan:                         return nir_op_flt;
   case SpvOpFUnordLessThan:                       return nir_op_flt;
   case SpvOpUGreaterThan:          *swap = true;  return nir_op_ult;
   case SpvOpSGreaterThan:          *swap = true;  return nir_op_ilt;
   case SpvOpFOrdGreaterThan:       *swap = true;  return nir_op_flt;
   case SpvOpFUnordGreaterThan:     *swap = true;  return nir_op_flt;
   case SpvOpULessThanEqual:        *swap = true;  return nir_op_uge;
   case SpvOpSLessThanEqual:        *swap = true;  return nir_op_ige;
   case SpvOpFOrdLessThanEqual:     *swap = true;  return nir_op_fge;
   case SpvOpFUnordLessThanEqual:   *swap = true;  return nir_op_fge;
   case SpvOpUGreaterThanEqual:                    return nir_op_uge;
   case SpvOpSGreaterThanEqual:                    return nir_op_ige;
   case SpvOpFOrdGreaterThanEqual:                 return nir_op_fge;
   case SpvOpFUnordGreaterThanEqual:               return nir_op_fge;

   /* Conversions: */
   case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert: {
      nir_alu_type src_type;
      nir_alu_type dst_type;

      switch (opcode) {
      case SpvOpConvertFToS:
         src_type = nir_type_float;
         dst_type = nir_type_int;
         break;
      case SpvOpConvertFToU:
         src_type = nir_type_float;
         dst_type = nir_type_uint;
         break;
      case SpvOpFConvert:
         src_type = dst_type = nir_type_float;
         break;
      case SpvOpConvertSToF:
         src_type = nir_type_int;
         dst_type = nir_type_float;
         break;
      case SpvOpSConvert:
         src_type = dst_type = nir_type_int;
         break;
      case SpvOpConvertUToF:
         src_type = nir_type_uint;
         dst_type = nir_type_float;
         break;
      case SpvOpUConvert:
         src_type = dst_type = nir_type_uint;
         break;
      default:
         unreachable("Invalid opcode");
      }
      /* nir_alu_type encodes the bit size in its low bits, so OR it in to
       * pick the exact conversion op.
       */
      src_type |= src_bit_size;
      dst_type |= dst_bit_size;
      return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
   }
   /* Derivatives: */
   case SpvOpDPdx:         return nir_op_fddx;
   case SpvOpDPdy:         return nir_op_fddy;
   case SpvOpDPdxFine:     return nir_op_fddx_fine;
   case SpvOpDPdyFine:     return nir_op_fddy_fine;
   case SpvOpDPdxCoarse:   return nir_op_fddx_coarse;
   case SpvOpDPdyCoarse:   return nir_op_fddy_coarse;

   default:
      vtn_fail("No NIR equivalent: %u", opcode);
   }
}

/* Decoration callback: if the value is decorated NoContraction, set the NIR
 * builder's "exact" flag so the ALU ops built for it are marked exact.  The
 * flag is restored at the end of vtn_handle_alu().
 */
static void
handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
                      const struct vtn_decoration *dec, void *_void)
{
   vtn_assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationNoContraction)
      return;

   b->nb.exact = true;
}

/* Decoration callback: translate an FPRoundingMode decoration into the
 * corresponding nir_rounding_mode.  Only RTE and RTZ are supported.
 */
static void
handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int member,
                     const struct vtn_decoration *dec, void *_out_rounding_mode)
{
   nir_rounding_mode *out_rounding_mode = _out_rounding_mode;
   assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationFPRoundingMode)
      return;
   switch (dec->operands[0]) {
   case SpvFPRoundingModeRTE:
      *out_rounding_mode = nir_rounding_mode_rtne;
      break;
   case SpvFPRoundingModeRTZ:
      *out_rounding_mode = nir_rounding_mode_rtz;
      break;
   default:
      unreachable("Not supported rounding mode");
      break;
   }
}

/* Main entry point for SPIR-V ALU instructions.  w[1] is the result type id,
 * w[2] the result id, and w[3..] the operand ids.  Matrix operands are
 * dispatched to vtn_handle_matrix_alu(); opcodes needing special lowering
 * are handled case-by-case; everything else goes through
 * vtn_nir_alu_op_for_spirv_opcode() and the generic NIR builder path.
 */
void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
               const uint32_t *w, unsigned count)
{
   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
   const struct glsl_type *type =
      vtn_value(b, w[1], vtn_value_type_type)->type->type;

   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);

   /* Collect the various SSA sources */
   const unsigned num_inputs = count - 3;
   struct vtn_ssa_value *vtn_src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++)
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);

   if (glsl_type_is_matrix(vtn_src[0]->type) ||
       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
      vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
      /* Restore the builder's exact flag in case NoContraction set it. */
      b->nb.exact = b->exact;
      return;
   }

   val->ssa = vtn_create_ssa_value(b, type);
   nir_ssa_def *src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
      src[i] = vtn_src[i]->def;
   }

   switch (opcode) {
   case SpvOpAny:
      /* Scalar input: OpAny is a no-op; otherwise reduce the boolean vector
       * with bany_inequal(v, false).
       */
      if (src[0]->num_components == 1) {
         val->ssa->def = nir_imov(&b->nb, src[0]);
      } else {
         nir_op op;
         switch (src[0]->num_components) {
         case 2:  op = nir_op_bany_inequal2; break;
         case 3:  op = nir_op_bany_inequal3; break;
         case 4:  op = nir_op_bany_inequal4; break;
         default: vtn_fail("invalid number of components");
         }
         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                       nir_imm_false(&b->nb),
                                       NULL, NULL);
      }
      break;

   case SpvOpAll:
      /* Scalar input: OpAll is a no-op; otherwise reduce the boolean vector
       * with ball_iequal(v, true).
       */
      if (src[0]->num_components == 1) {
         val->ssa->def = nir_imov(&b->nb, src[0]);
      } else {
         nir_op op;
         switch (src[0]->num_components) {
         case 2:  op = nir_op_ball_iequal2;  break;
         case 3:  op = nir_op_ball_iequal3;  break;
         case 4:  op = nir_op_ball_iequal4;  break;
         default: vtn_fail("invalid number of components");
         }
         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                       nir_imm_true(&b->nb),
                                       NULL, NULL);
      }
      break;

   case SpvOpOuterProduct: {
      /* Column i of the result is src[0] scaled by component i of src[1]. */
      for (unsigned i = 0; i < src[1]->num_components; i++) {
         val->ssa->elems[i]->def =
            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
      }
      break;
   }

   case SpvOpDot:
      val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
      break;

   case SpvOpIAddCarry:
      /* Result is a struct { sum, carry }. */
      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
      val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
      break;

   case SpvOpISubBorrow:
      /* Result is a struct { difference, borrow }. */
      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
      val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
      break;

   case SpvOpUMulExtended: {
      /* 32x32->64 multiply, then split the halves into the result struct. */
      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
      nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
      val->ssa->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
      val->ssa->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
      break;
   }

   case SpvOpSMulExtended: {
      /* Signed 32x32->64 multiply, then split the halves. */
      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
      nir_ssa_def *smul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
      val->ssa->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, smul);
      val->ssa->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, smul);
      break;
   }

   /* Fwidth = |ddx| + |ddy|, at the matching derivative precision. */
   case SpvOpFwidth:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
      break;
   case SpvOpFwidthFine:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
      break;
   case SpvOpFwidthCoarse:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
      break;

   case SpvOpVectorTimesScalar:
      /* The builder will take care of splatting for us. */
      val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
      break;

   case SpvOpIsNan:
      /* NaN is the only value that compares unequal to itself. */
      val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
      break;

   case SpvOpIsInf: {
      /* |x| == +inf catches both infinities. */
      nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
      val->ssa->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
      break;
   }

   case SpvOpFUnordEqual:
   case SpvOpFUnordNotEqual:
   case SpvOpFUnordLessThan:
   case SpvOpFUnordGreaterThan:
   case SpvOpFUnordLessThanEqual:
   case SpvOpFUnordGreaterThanEqual: {
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      /* Unordered comparison: true if the ordered comparison holds OR either
       * operand is NaN (x != x).
       */
      val->ssa->def =
         nir_ior(&b->nb,
                 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                 nir_ior(&b->nb,
                         nir_fne(&b->nb, src[0], src[0]),
                         nir_fne(&b->nb, src[1], src[1])));
      break;
   }

   case SpvOpFOrdNotEqual: {
      /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
       * from the ALU will probably already be false if the operands are not
       * ordered so we don't need to handle it specially.
       */
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      assert(!swap);

      /* Ordered not-equal: both operands must be non-NaN (x == x) AND the
       * raw fne must hold.
       */
      val->ssa->def =
         nir_iand(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                  nir_iand(&b->nb,
                           nir_feq(&b->nb, src[0], src[0]),
                           nir_feq(&b->nb, src[1], src[1])));
      break;
   }

   case SpvOpFConvert: {
      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
      nir_rounding_mode rounding_mode = nir_rounding_mode_undef;

      /* An FPRoundingMode decoration on the result selects the rounding. */
      vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
      nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type, rounding_mode);

      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
      break;
   }

   case SpvOpBitFieldInsert:
   case SpvOpBitFieldSExtract:
   case SpvOpBitFieldUExtract:
   case SpvOpShiftLeftLogical:
   case SpvOpShiftRightArithmetic:
   case SpvOpShiftRightLogical: {
      bool swap;
      unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src0_bit_size, dst_bit_size);

      assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
              op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
              op == nir_op_ibitfield_extract);

      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
         unsigned src_bit_size =
            nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
         if (src_bit_size == 0)
            continue;
         if (src_bit_size != src[i]->bit_size) {
            assert(src_bit_size == 32);
            /* Convert the Shift, Offset and Count operands to 32 bits, which
             * is the bitsize supported by the NIR instructions.  See
             * discussion here:
             *
             * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
             */
            src[i] = nir_u2u32(&b->nb, src[i]);
         }
      }
      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   }

   case SpvOpSignBitSet: {
      /* Shift the sign bit down to bit 0, then convert to bool.  The scalar
       * case uses a logical shift; the vector case uses an arithmetic shift
       * (the low bit is the same either way, and i2b only tests non-zero).
       */
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      if (src[0]->num_components == 1)
         val->ssa->def =
            nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src_bit_size - 1));
      else
         val->ssa->def =
            nir_ishr(&b->nb, src[0], nir_imm_int(&b->nb, src_bit_size - 1));

      val->ssa->def = nir_i2b(&b->nb, val->ssa->def);
      break;
   }

   default: {
      /* Generic path: map the opcode straight to a NIR op. */
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      /* NIR shift counts are always 32-bit. */
      switch (op) {
      case nir_op_ishl:
      case nir_op_ishr:
      case nir_op_ushr:
         if (src[1]->bit_size != 32)
            src[1] = nir_u2u32(&b->nb, src[1]);
         break;
      default:
         break;
      }

      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   } /* default */
   }

   /* Restore the builder's exact flag in case NoContraction set it. */
   b->nb.exact = b->exact;
}

/* Handle OpBitcast: reinterpret the bits of a vector/scalar as another
 * vector/scalar type with the same total bit count.
 */
void
vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
{
   vtn_assert(count == 4);
   /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
    *
    *    "If Result Type has the same number of components as Operand, they
    *    must also have the same component width, and results are computed per
    *    component.
    *
    *    If Result Type has a different number of components than Operand, the
    *    total number of bits in Result Type must equal the total number of
    *    bits in Operand.  Let L be the type, either Result Type or Operand's
    *    type, that has the larger number of components.  Let S be the other
    *    type, with the smaller number of components.  The number of components
    *    in L must be an integer multiple of the number of components in S.
    *    The first component (that is, the only or lowest-numbered component)
    *    of S maps to the first components of L, and so on, up to the last
    *    component of S mapping to the last components of L.  Within this
    *    mapping, any single component of S (mapping to multiple components of
    *    L) maps its lower-ordered bits to the lower-numbered components of L."
    */

   struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
   struct vtn_ssa_value *vtn_src = vtn_ssa_value(b, w[3]);
   struct nir_ssa_def *src = vtn_src->def;
   struct vtn_ssa_value *val = vtn_create_ssa_value(b, type->type);

   vtn_assert(glsl_type_is_vector_or_scalar(vtn_src->type));

   vtn_fail_if(src->num_components * src->bit_size !=
               glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
               "Source and destination of OpBitcast must have the same "
               "total number of bits");
   val->def = nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
   vtn_push_ssa(b, w[2], type, val);
}