/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"

/* Emit one NIR subgroup intrinsic computing "dst" from "src0".
 *
 * nir_op:     the NIR intrinsic to emit (read_invocation, shuffle, reduce, ...).
 * dst, src0:  destination and source values; their GLSL types must match.
 *             For struct/array types the intrinsic is emitted recursively,
 *             once per member/element, reusing the same index and constants.
 * index:      optional second SSA source (invocation index, shuffle delta/mask);
 *             pass NULL for intrinsics that take only one source.
 * const_idx0,
 * const_idx1: raw values stored into the intrinsic's first two const_index
 *             slots (e.g. the reduction nir_op and cluster size for
 *             nir_intrinsic_reduce; 0/0 when unused).
 */
static void
vtn_build_subgroup_instr(struct vtn_builder *b,
                         nir_intrinsic_op nir_op,
                         struct vtn_ssa_value *dst,
                         struct vtn_ssa_value *src0,
                         nir_ssa_def *index,
                         unsigned const_idx0,
                         unsigned const_idx1)
{
   /* Some of the subgroup operations take an index. SPIR-V allows this to be
    * any integer type. To make things simpler for drivers, we only support
    * 32-bit indices.
    */
   if (index && index->bit_size != 32)
      index = nir_u2u32(&b->nb, index);

   vtn_assert(dst->type == src0->type);
   if (!glsl_type_is_vector_or_scalar(dst->type)) {
      /* Composite type: recurse member-by-member until we bottom out at
       * vectors/scalars, which is all a single NIR intrinsic can produce.
       */
      for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
         vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
                                  src0->elems[i], index,
                                  const_idx0, const_idx1);
      }
      return;
   }

   nir_intrinsic_instr *intrin =
      nir_intrinsic_instr_create(b->nb.shader, nir_op);
   nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                              dst->type, NULL);
   /* The intrinsic operates component-wise on the whole vector. */
   intrin->num_components = intrin->dest.ssa.num_components;

   intrin->src[0] = nir_src_for_ssa(src0->def);
   if (index)
      intrin->src[1] = nir_src_for_ssa(index);

   intrin->const_index[0] = const_idx0;
   intrin->const_index[1] = const_idx1;

   nir_builder_instr_insert(&b->nb, &intrin->instr);

   dst->def = &intrin->dest.ssa;
}

/* Translate one SPIR-V OpGroupNonUniform* instruction into NIR.
 *
 * As used here, w[2] is the result id (pushed as an SSA value below) and the
 * instruction-specific operands start at w[4].  NOTE(review): w[3] is
 * presumably the execution Scope operand and is never consulted — NIR's
 * subgroup intrinsics appear to be implicitly subgroup-scoped; confirm that
 * scope validation happens elsewhere.
 */
void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                    const uint32_t *w, unsigned count)
{
   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);

   val->ssa = vtn_create_ssa_value(b, val->type->type);

   switch (opcode) {
   case SpvOpGroupNonUniformElect: {
      /* No sources: the result is a scalar Bool, so build the intrinsic
       * by hand rather than going through vtn_build_subgroup_instr().
       */
      vtn_fail_if(val->type->type != glsl_bool_type(),
                  "OpGroupNonUniformElect must return a Bool");
      nir_intrinsic_instr *elect =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
      nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
                                 val->type->type, NULL);
      nir_builder_instr_insert(&b->nb, &elect->instr);
      val->ssa->def = &elect->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformBallot: {
      /* The ballot mask is always represented as a uvec4 (4 x 32 bits),
       * matching what the NIR ballot intrinsic produces here.
       */
      vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
                  "OpGroupNonUniformBallot must return a uvec4");
      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
      ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
      nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
      ballot->num_components = 4;
      nir_builder_instr_insert(&b->nb, &ballot->instr);
      val->ssa->def = &ballot->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformInverseBallot: {
      /* This one is just a BallotBitfieldExtract with subgroup invocation.
       * We could add a NIR intrinsic but it's easier to just lower it on the
       * spot.
       */
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader,
                                    nir_intrinsic_ballot_bitfield_extract);

      intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
      intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 val->type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      val->ssa->def = &intrin->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformBallotBitExtract:
   case SpvOpGroupNonUniformBallotBitCount:
   case SpvOpGroupNonUniformBallotFindLSB:
   case SpvOpGroupNonUniformBallotFindMSB: {
      /* Ballot-consuming queries.  Note that BitCount carries a group
       * operation in w[4], so its ballot operand shifts to w[5]; the others
       * take the ballot in w[4].
       */
      nir_ssa_def *src0, *src1 = NULL;
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformBallotBitExtract:
         op = nir_intrinsic_ballot_bitfield_extract;
         src0 = vtn_ssa_value(b, w[4])->def;
         src1 = vtn_ssa_value(b, w[5])->def; /* bit index */
         break;
      case SpvOpGroupNonUniformBallotBitCount:
         switch ((SpvGroupOperation)w[4]) {
         case SpvGroupOperationReduce:
            op = nir_intrinsic_ballot_bit_count_reduce;
            break;
         case SpvGroupOperationInclusiveScan:
            op = nir_intrinsic_ballot_bit_count_inclusive;
            break;
         case SpvGroupOperationExclusiveScan:
            op = nir_intrinsic_ballot_bit_count_exclusive;
            break;
         default:
            unreachable("Invalid group operation");
         }
         src0 = vtn_ssa_value(b, w[5])->def;
         break;
      case SpvOpGroupNonUniformBallotFindLSB:
         op = nir_intrinsic_ballot_find_lsb;
         src0 = vtn_ssa_value(b, w[4])->def;
         break;
      case SpvOpGroupNonUniformBallotFindMSB:
         op = nir_intrinsic_ballot_find_msb;
         src0 = vtn_ssa_value(b, w[4])->def;
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);

      intrin->src[0] = nir_src_for_ssa(src0);
      if (src1)
         intrin->src[1] = nir_src_for_ssa(src1);

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 val->type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      val->ssa->def = &intrin->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformBroadcastFirst:
      /* Value only; no invocation index source. */
      vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
                               val->ssa, vtn_ssa_value(b, w[4]), NULL, 0, 0);
      break;

   case SpvOpGroupNonUniformBroadcast:
      /* w[4] is the value, w[5] the invocation id to read from. */
      vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                               val->ssa, vtn_ssa_value(b, w[4]),
                               vtn_ssa_value(b, w[5])->def, 0, 0);
      break;

   case SpvOpGroupNonUniformAll:
   case SpvOpGroupNonUniformAny:
   case SpvOpGroupNonUniformAllEqual: {
      vtn_fail_if(val->type->type != glsl_bool_type(),
                  "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformAll:
         op = nir_intrinsic_vote_all;
         break;
      case SpvOpGroupNonUniformAny:
         op = nir_intrinsic_vote_any;
         break;
      case SpvOpGroupNonUniformAllEqual: {
         /* AllEqual compares the operand across invocations, so the vote
          * flavor depends on the operand's base type: float types use the
          * float-equality vote, everything else the integer one.
          */
         switch (glsl_get_base_type(val->type->type)) {
         case GLSL_TYPE_FLOAT:
         case GLSL_TYPE_DOUBLE:
            op = nir_intrinsic_vote_feq;
            break;
         case GLSL_TYPE_UINT:
         case GLSL_TYPE_INT:
         case GLSL_TYPE_UINT64:
         case GLSL_TYPE_INT64:
         case GLSL_TYPE_BOOL:
            op = nir_intrinsic_vote_ieq;
            break;
         default:
            unreachable("Unhandled type");
         }
         break;
      }
      default:
         unreachable("Unhandled opcode");
      }

      nir_ssa_def *src0 = vtn_ssa_value(b, w[4])->def;

      /* Built by hand (not via vtn_build_subgroup_instr) because the source
       * may be a vector while the destination is a single Bool.
       */
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);
      intrin->num_components = src0->num_components;
      intrin->src[0] = nir_src_for_ssa(src0);
      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 val->type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      val->ssa->def = &intrin->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformShuffle:
   case SpvOpGroupNonUniformShuffleXor:
   case SpvOpGroupNonUniformShuffleUp:
   case SpvOpGroupNonUniformShuffleDown: {
      /* All four shuffles share the same shape: w[4] is the value, w[5] the
       * per-variant id/mask/delta; only the NIR opcode differs.
       */
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformShuffle:
         op = nir_intrinsic_shuffle;
         break;
      case SpvOpGroupNonUniformShuffleXor:
         op = nir_intrinsic_shuffle_xor;
         break;
      case SpvOpGroupNonUniformShuffleUp:
         op = nir_intrinsic_shuffle_up;
         break;
      case SpvOpGroupNonUniformShuffleDown:
         op = nir_intrinsic_shuffle_down;
         break;
      default:
         unreachable("Invalid opcode");
      }
      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
                               vtn_ssa_value(b, w[5])->def, 0, 0);
      break;
   }

   case SpvOpGroupNonUniformQuadBroadcast:
      /* w[4] is the value, w[5] the quad lane index. */
      vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
                               val->ssa, vtn_ssa_value(b, w[4]),
                               vtn_ssa_value(b, w[5])->def, 0, 0);
      break;

   case SpvOpGroupNonUniformQuadSwap: {
      /* w[5] must be a constant direction: 0 = horizontal, 1 = vertical,
       * 2 = diagonal; each maps to its own NIR intrinsic.
       */
      unsigned direction = vtn_constant_uint(b, w[5]);
      nir_intrinsic_op op;
      switch (direction) {
      case 0:
         op = nir_intrinsic_quad_swap_horizontal;
         break;
      case 1:
         op = nir_intrinsic_quad_swap_vertical;
         break;
      case 2:
         op = nir_intrinsic_quad_swap_diagonal;
         break;
      default:
         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
      }
      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
                               NULL, 0, 0);
      break;
   }

   case SpvOpGroupNonUniformIAdd:
   case SpvOpGroupNonUniformFAdd:
   case SpvOpGroupNonUniformIMul:
   case SpvOpGroupNonUniformFMul:
   case SpvOpGroupNonUniformSMin:
   case SpvOpGroupNonUniformUMin:
   case SpvOpGroupNonUniformFMin:
   case SpvOpGroupNonUniformSMax:
   case SpvOpGroupNonUniformUMax:
   case SpvOpGroupNonUniformFMax:
   case SpvOpGroupNonUniformBitwiseAnd:
   case SpvOpGroupNonUniformBitwiseOr:
   case SpvOpGroupNonUniformBitwiseXor:
   case SpvOpGroupNonUniformLogicalAnd:
   case SpvOpGroupNonUniformLogicalOr:
   case SpvOpGroupNonUniformLogicalXor: {
      /* Subgroup arithmetic: pick the ALU op used to combine invocations,
       * then the reduce/scan intrinsic flavor from the GroupOperation in
       * w[4].  The ALU op lands in const_index[0] and the cluster size in
       * const_index[1] (0 when not clustered).  Bitwise and Logical variants
       * intentionally map to the same integer ALU op.
       */
      nir_op reduction_op;
      switch (opcode) {
      case SpvOpGroupNonUniformIAdd:
         reduction_op = nir_op_iadd;
         break;
      case SpvOpGroupNonUniformFAdd:
         reduction_op = nir_op_fadd;
         break;
      case SpvOpGroupNonUniformIMul:
         reduction_op = nir_op_imul;
         break;
      case SpvOpGroupNonUniformFMul:
         reduction_op = nir_op_fmul;
         break;
      case SpvOpGroupNonUniformSMin:
         reduction_op = nir_op_imin;
         break;
      case SpvOpGroupNonUniformUMin:
         reduction_op = nir_op_umin;
         break;
      case SpvOpGroupNonUniformFMin:
         reduction_op = nir_op_fmin;
         break;
      case SpvOpGroupNonUniformSMax:
         reduction_op = nir_op_imax;
         break;
      case SpvOpGroupNonUniformUMax:
         reduction_op = nir_op_umax;
         break;
      case SpvOpGroupNonUniformFMax:
         reduction_op = nir_op_fmax;
         break;
      case SpvOpGroupNonUniformBitwiseAnd:
      case SpvOpGroupNonUniformLogicalAnd:
         reduction_op = nir_op_iand;
         break;
      case SpvOpGroupNonUniformBitwiseOr:
      case SpvOpGroupNonUniformLogicalOr:
         reduction_op = nir_op_ior;
         break;
      case SpvOpGroupNonUniformBitwiseXor:
      case SpvOpGroupNonUniformLogicalXor:
         reduction_op = nir_op_ixor;
         break;
      default:
         unreachable("Invalid reduction operation");
      }

      nir_intrinsic_op op;
      unsigned cluster_size = 0;
      switch ((SpvGroupOperation)w[4]) {
      case SpvGroupOperationReduce:
         op = nir_intrinsic_reduce;
         break;
      case SpvGroupOperationInclusiveScan:
         op = nir_intrinsic_inclusive_scan;
         break;
      case SpvGroupOperationExclusiveScan:
         op = nir_intrinsic_exclusive_scan;
         break;
      case SpvGroupOperationClusteredReduce:
         op = nir_intrinsic_reduce;
         /* ClusteredReduce carries an extra constant ClusterSize operand,
          * hence the larger word count.
          */
         assert(count == 7);
         cluster_size = vtn_constant_uint(b, w[6]);
         break;
      default:
         unreachable("Invalid group operation");
      }

      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
                               NULL, reduction_op, cluster_size);
      break;
   }

   default:
      unreachable("Invalid SPIR-V opcode");
   }
}