/*
 * Copyright © 2017 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
22 */ 23 24#include "nir.h" 25#include "nir_builder.h" 26 27/** 28 * \file nir_opt_intrinsics.c 29 */ 30 31static nir_intrinsic_instr * 32lower_subgroups_64bit_split_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin, 33 unsigned int component) 34{ 35 nir_ssa_def *comp; 36 if (component == 0) 37 comp = nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa); 38 else 39 comp = nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa); 40 41 nir_intrinsic_instr *intr = nir_intrinsic_instr_create(b->shader, intrin->intrinsic); 42 nir_ssa_dest_init(&intr->instr, &intr->dest, 1, 32, NULL); 43 intr->const_index[0] = intrin->const_index[0]; 44 intr->const_index[1] = intrin->const_index[1]; 45 intr->src[0] = nir_src_for_ssa(comp); 46 if (nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2) 47 nir_src_copy(&intr->src[1], &intrin->src[1], intr); 48 49 intr->num_components = 1; 50 nir_builder_instr_insert(b, &intr->instr); 51 return intr; 52} 53 54static nir_ssa_def * 55lower_subgroup_op_to_32bit(nir_builder *b, nir_intrinsic_instr *intrin) 56{ 57 assert(intrin->src[0].ssa->bit_size == 64); 58 nir_intrinsic_instr *intr_x = lower_subgroups_64bit_split_intrinsic(b, intrin, 0); 59 nir_intrinsic_instr *intr_y = lower_subgroups_64bit_split_intrinsic(b, intrin, 1); 60 return nir_pack_64_2x32_split(b, &intr_x->dest.ssa, &intr_y->dest.ssa); 61} 62 63static nir_ssa_def * 64ballot_type_to_uint(nir_builder *b, nir_ssa_def *value, unsigned bit_size) 65{ 66 /* We only use this on uvec4 types */ 67 assert(value->num_components == 4 && value->bit_size == 32); 68 69 if (bit_size == 32) { 70 return nir_channel(b, value, 0); 71 } else { 72 assert(bit_size == 64); 73 return nir_pack_64_2x32_split(b, nir_channel(b, value, 0), 74 nir_channel(b, value, 1)); 75 } 76} 77 78/* Converts a uint32_t or uint64_t value to uint64_t or uvec4 */ 79static nir_ssa_def * 80uint_to_ballot_type(nir_builder *b, nir_ssa_def *value, 81 unsigned num_components, unsigned bit_size) 82{ 83 assert(value->num_components == 1); 84 
assert(value->bit_size == 32 || value->bit_size == 64); 85 86 nir_ssa_def *zero = nir_imm_int(b, 0); 87 if (num_components > 1) { 88 /* SPIR-V uses a uvec4 for ballot values */ 89 assert(num_components == 4); 90 assert(bit_size == 32); 91 92 if (value->bit_size == 32) { 93 return nir_vec4(b, value, zero, zero, zero); 94 } else { 95 assert(value->bit_size == 64); 96 return nir_vec4(b, nir_unpack_64_2x32_split_x(b, value), 97 nir_unpack_64_2x32_split_y(b, value), 98 zero, zero); 99 } 100 } else { 101 /* GLSL uses a uint64_t for ballot values */ 102 assert(num_components == 1); 103 assert(bit_size == 64); 104 105 if (value->bit_size == 32) { 106 return nir_pack_64_2x32_split(b, value, zero); 107 } else { 108 assert(value->bit_size == 64); 109 return value; 110 } 111 } 112} 113 114static nir_ssa_def * 115lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin, 116 bool lower_to_32bit) 117{ 118 /* This is safe to call on scalar things but it would be silly */ 119 assert(intrin->dest.ssa.num_components > 1); 120 121 nir_ssa_def *value = nir_ssa_for_src(b, intrin->src[0], 122 intrin->num_components); 123 nir_ssa_def *reads[4]; 124 125 for (unsigned i = 0; i < intrin->num_components; i++) { 126 nir_intrinsic_instr *chan_intrin = 127 nir_intrinsic_instr_create(b->shader, intrin->intrinsic); 128 nir_ssa_dest_init(&chan_intrin->instr, &chan_intrin->dest, 129 1, intrin->dest.ssa.bit_size, NULL); 130 chan_intrin->num_components = 1; 131 132 /* value */ 133 chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i)); 134 /* invocation */ 135 if (nir_intrinsic_infos[intrin->intrinsic].num_srcs > 1) { 136 assert(nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2); 137 nir_src_copy(&chan_intrin->src[1], &intrin->src[1], chan_intrin); 138 } 139 140 chan_intrin->const_index[0] = intrin->const_index[0]; 141 chan_intrin->const_index[1] = intrin->const_index[1]; 142 143 if (lower_to_32bit && chan_intrin->src[0].ssa->bit_size == 64) { 144 reads[i] = 
lower_subgroup_op_to_32bit(b, chan_intrin); 145 } else { 146 nir_builder_instr_insert(b, &chan_intrin->instr); 147 reads[i] = &chan_intrin->dest.ssa; 148 } 149 } 150 151 return nir_vec(b, reads, intrin->num_components); 152} 153 154static nir_ssa_def * 155lower_vote_eq_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin) 156{ 157 assert(intrin->src[0].is_ssa); 158 nir_ssa_def *value = intrin->src[0].ssa; 159 160 nir_ssa_def *result = NULL; 161 for (unsigned i = 0; i < intrin->num_components; i++) { 162 nir_intrinsic_instr *chan_intrin = 163 nir_intrinsic_instr_create(b->shader, intrin->intrinsic); 164 nir_ssa_dest_init(&chan_intrin->instr, &chan_intrin->dest, 165 1, intrin->dest.ssa.bit_size, NULL); 166 chan_intrin->num_components = 1; 167 chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i)); 168 nir_builder_instr_insert(b, &chan_intrin->instr); 169 170 if (result) { 171 result = nir_iand(b, result, &chan_intrin->dest.ssa); 172 } else { 173 result = &chan_intrin->dest.ssa; 174 } 175 } 176 177 return result; 178} 179 180static nir_ssa_def * 181lower_vote_eq_to_ballot(nir_builder *b, nir_intrinsic_instr *intrin, 182 const nir_lower_subgroups_options *options) 183{ 184 assert(intrin->src[0].is_ssa); 185 nir_ssa_def *value = intrin->src[0].ssa; 186 187 /* We have to implicitly lower to scalar */ 188 nir_ssa_def *all_eq = NULL; 189 for (unsigned i = 0; i < intrin->num_components; i++) { 190 nir_intrinsic_instr *rfi = 191 nir_intrinsic_instr_create(b->shader, 192 nir_intrinsic_read_first_invocation); 193 nir_ssa_dest_init(&rfi->instr, &rfi->dest, 194 1, value->bit_size, NULL); 195 rfi->num_components = 1; 196 rfi->src[0] = nir_src_for_ssa(nir_channel(b, value, i)); 197 nir_builder_instr_insert(b, &rfi->instr); 198 199 nir_ssa_def *is_eq; 200 if (intrin->intrinsic == nir_intrinsic_vote_feq) { 201 is_eq = nir_feq(b, &rfi->dest.ssa, nir_channel(b, value, i)); 202 } else { 203 is_eq = nir_ieq(b, &rfi->dest.ssa, nir_channel(b, value, i)); 204 } 205 206 if (all_eq 
== NULL) { 207 all_eq = is_eq; 208 } else { 209 all_eq = nir_iand(b, all_eq, is_eq); 210 } 211 } 212 213 nir_intrinsic_instr *ballot = 214 nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot); 215 nir_ssa_dest_init(&ballot->instr, &ballot->dest, 216 1, options->ballot_bit_size, NULL); 217 ballot->num_components = 1; 218 ballot->src[0] = nir_src_for_ssa(nir_inot(b, all_eq)); 219 nir_builder_instr_insert(b, &ballot->instr); 220 221 return nir_ieq(b, &ballot->dest.ssa, 222 nir_imm_intN_t(b, 0, options->ballot_bit_size)); 223} 224 225static nir_ssa_def * 226lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin, 227 bool lower_to_scalar, bool lower_to_32bit) 228{ 229 nir_ssa_def *index = nir_load_subgroup_invocation(b); 230 switch (intrin->intrinsic) { 231 case nir_intrinsic_shuffle_xor: 232 assert(intrin->src[1].is_ssa); 233 index = nir_ixor(b, index, intrin->src[1].ssa); 234 break; 235 case nir_intrinsic_shuffle_up: 236 assert(intrin->src[1].is_ssa); 237 index = nir_isub(b, index, intrin->src[1].ssa); 238 break; 239 case nir_intrinsic_shuffle_down: 240 assert(intrin->src[1].is_ssa); 241 index = nir_iadd(b, index, intrin->src[1].ssa); 242 break; 243 case nir_intrinsic_quad_broadcast: 244 assert(intrin->src[1].is_ssa); 245 index = nir_ior(b, nir_iand(b, index, nir_imm_int(b, ~0x3)), 246 intrin->src[1].ssa); 247 break; 248 case nir_intrinsic_quad_swap_horizontal: 249 /* For Quad operations, subgroups are divided into quads where 250 * (invocation % 4) is the index to a square arranged as follows: 251 * 252 * +---+---+ 253 * | 0 | 1 | 254 * +---+---+ 255 * | 2 | 3 | 256 * +---+---+ 257 */ 258 index = nir_ixor(b, index, nir_imm_int(b, 0x1)); 259 break; 260 case nir_intrinsic_quad_swap_vertical: 261 index = nir_ixor(b, index, nir_imm_int(b, 0x2)); 262 break; 263 case nir_intrinsic_quad_swap_diagonal: 264 index = nir_ixor(b, index, nir_imm_int(b, 0x3)); 265 break; 266 default: 267 unreachable("Invalid intrinsic"); 268 } 269 270 nir_intrinsic_instr *shuffle = 271 
nir_intrinsic_instr_create(b->shader, nir_intrinsic_shuffle); 272 shuffle->num_components = intrin->num_components; 273 nir_src_copy(&shuffle->src[0], &intrin->src[0], shuffle); 274 shuffle->src[1] = nir_src_for_ssa(index); 275 nir_ssa_dest_init(&shuffle->instr, &shuffle->dest, 276 intrin->dest.ssa.num_components, 277 intrin->dest.ssa.bit_size, NULL); 278 279 if (lower_to_scalar && shuffle->num_components > 1) { 280 return lower_subgroup_op_to_scalar(b, shuffle, lower_to_32bit); 281 } else if (lower_to_32bit && shuffle->src[0].ssa->bit_size == 64) { 282 return lower_subgroup_op_to_32bit(b, shuffle); 283 } else { 284 nir_builder_instr_insert(b, &shuffle->instr); 285 return &shuffle->dest.ssa; 286 } 287} 288 289static nir_ssa_def * 290lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin, 291 const nir_lower_subgroups_options *options) 292{ 293 switch (intrin->intrinsic) { 294 case nir_intrinsic_vote_any: 295 case nir_intrinsic_vote_all: 296 if (options->lower_vote_trivial) 297 return nir_ssa_for_src(b, intrin->src[0], 1); 298 break; 299 300 case nir_intrinsic_vote_feq: 301 case nir_intrinsic_vote_ieq: 302 if (options->lower_vote_trivial) 303 return nir_imm_true(b); 304 305 if (options->lower_vote_eq_to_ballot) 306 return lower_vote_eq_to_ballot(b, intrin, options); 307 308 if (options->lower_to_scalar && intrin->num_components > 1) 309 return lower_vote_eq_to_scalar(b, intrin); 310 break; 311 312 case nir_intrinsic_load_subgroup_size: 313 if (options->subgroup_size) 314 return nir_imm_int(b, options->subgroup_size); 315 break; 316 317 case nir_intrinsic_read_invocation: 318 case nir_intrinsic_read_first_invocation: 319 if (options->lower_to_scalar && intrin->num_components > 1) 320 return lower_subgroup_op_to_scalar(b, intrin, false); 321 break; 322 323 case nir_intrinsic_load_subgroup_eq_mask: 324 case nir_intrinsic_load_subgroup_ge_mask: 325 case nir_intrinsic_load_subgroup_gt_mask: 326 case nir_intrinsic_load_subgroup_le_mask: 327 case 
nir_intrinsic_load_subgroup_lt_mask: { 328 if (!options->lower_subgroup_masks) 329 return NULL; 330 331 /* If either the result or the requested bit size is 64-bits then we 332 * know that we have 64-bit types and using them will probably be more 333 * efficient than messing around with 32-bit shifts and packing. 334 */ 335 const unsigned bit_size = MAX2(options->ballot_bit_size, 336 intrin->dest.ssa.bit_size); 337 338 assert(options->subgroup_size <= 64); 339 uint64_t group_mask = ~0ull >> (64 - options->subgroup_size); 340 341 nir_ssa_def *count = nir_load_subgroup_invocation(b); 342 nir_ssa_def *val; 343 switch (intrin->intrinsic) { 344 case nir_intrinsic_load_subgroup_eq_mask: 345 val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count); 346 break; 347 case nir_intrinsic_load_subgroup_ge_mask: 348 val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count), 349 nir_imm_intN_t(b, group_mask, bit_size)); 350 break; 351 case nir_intrinsic_load_subgroup_gt_mask: 352 val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count), 353 nir_imm_intN_t(b, group_mask, bit_size)); 354 break; 355 case nir_intrinsic_load_subgroup_le_mask: 356 val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count)); 357 break; 358 case nir_intrinsic_load_subgroup_lt_mask: 359 val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count)); 360 break; 361 default: 362 unreachable("you seriously can't tell this is unreachable?"); 363 } 364 365 return uint_to_ballot_type(b, val, 366 intrin->dest.ssa.num_components, 367 intrin->dest.ssa.bit_size); 368 } 369 370 case nir_intrinsic_ballot: { 371 if (intrin->dest.ssa.num_components == 1 && 372 intrin->dest.ssa.bit_size == options->ballot_bit_size) 373 return NULL; 374 375 nir_intrinsic_instr *ballot = 376 nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot); 377 ballot->num_components = 1; 378 nir_ssa_dest_init(&ballot->instr, &ballot->dest, 379 1, options->ballot_bit_size, NULL); 380 
nir_src_copy(&ballot->src[0], &intrin->src[0], ballot); 381 nir_builder_instr_insert(b, &ballot->instr); 382 383 return uint_to_ballot_type(b, &ballot->dest.ssa, 384 intrin->dest.ssa.num_components, 385 intrin->dest.ssa.bit_size); 386 } 387 388 case nir_intrinsic_ballot_bitfield_extract: 389 case nir_intrinsic_ballot_bit_count_reduce: 390 case nir_intrinsic_ballot_find_lsb: 391 case nir_intrinsic_ballot_find_msb: { 392 assert(intrin->src[0].is_ssa); 393 nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa, 394 options->ballot_bit_size); 395 switch (intrin->intrinsic) { 396 case nir_intrinsic_ballot_bitfield_extract: 397 assert(intrin->src[1].is_ssa); 398 return nir_i2b(b, nir_iand(b, nir_ushr(b, int_val, 399 intrin->src[1].ssa), 400 nir_imm_intN_t(b, 1, options->ballot_bit_size))); 401 case nir_intrinsic_ballot_bit_count_reduce: 402 return nir_bit_count(b, int_val); 403 case nir_intrinsic_ballot_find_lsb: 404 return nir_find_lsb(b, int_val); 405 case nir_intrinsic_ballot_find_msb: 406 return nir_ufind_msb(b, int_val); 407 default: 408 unreachable("you seriously can't tell this is unreachable?"); 409 } 410 } 411 412 case nir_intrinsic_ballot_bit_count_exclusive: 413 case nir_intrinsic_ballot_bit_count_inclusive: { 414 nir_ssa_def *count = nir_load_subgroup_invocation(b); 415 nir_ssa_def *mask = nir_imm_intN_t(b, ~0ull, options->ballot_bit_size); 416 if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_inclusive) { 417 const unsigned bits = options->ballot_bit_size; 418 mask = nir_ushr(b, mask, nir_isub(b, nir_imm_int(b, bits - 1), count)); 419 } else { 420 mask = nir_inot(b, nir_ishl(b, mask, count)); 421 } 422 423 assert(intrin->src[0].is_ssa); 424 nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa, 425 options->ballot_bit_size); 426 427 return nir_bit_count(b, nir_iand(b, int_val, mask)); 428 } 429 430 case nir_intrinsic_elect: { 431 nir_intrinsic_instr *first = 432 nir_intrinsic_instr_create(b->shader, 433 
nir_intrinsic_first_invocation); 434 nir_ssa_dest_init(&first->instr, &first->dest, 1, 32, NULL); 435 nir_builder_instr_insert(b, &first->instr); 436 437 return nir_ieq(b, nir_load_subgroup_invocation(b), &first->dest.ssa); 438 } 439 440 case nir_intrinsic_shuffle: 441 if (options->lower_to_scalar && intrin->num_components > 1) 442 return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit); 443 else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64) 444 return lower_subgroup_op_to_32bit(b, intrin); 445 break; 446 447 case nir_intrinsic_shuffle_xor: 448 case nir_intrinsic_shuffle_up: 449 case nir_intrinsic_shuffle_down: 450 if (options->lower_shuffle) 451 return lower_shuffle(b, intrin, options->lower_to_scalar, options->lower_shuffle_to_32bit); 452 else if (options->lower_to_scalar && intrin->num_components > 1) 453 return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit); 454 else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64) 455 return lower_subgroup_op_to_32bit(b, intrin); 456 break; 457 458 case nir_intrinsic_quad_broadcast: 459 case nir_intrinsic_quad_swap_horizontal: 460 case nir_intrinsic_quad_swap_vertical: 461 case nir_intrinsic_quad_swap_diagonal: 462 if (options->lower_quad) 463 return lower_shuffle(b, intrin, options->lower_to_scalar, false); 464 else if (options->lower_to_scalar && intrin->num_components > 1) 465 return lower_subgroup_op_to_scalar(b, intrin, false); 466 break; 467 468 case nir_intrinsic_reduce: 469 case nir_intrinsic_inclusive_scan: 470 case nir_intrinsic_exclusive_scan: 471 if (options->lower_to_scalar && intrin->num_components > 1) 472 return lower_subgroup_op_to_scalar(b, intrin, false); 473 break; 474 475 default: 476 break; 477 } 478 479 return NULL; 480} 481 482static bool 483lower_subgroups_impl(nir_function_impl *impl, 484 const nir_lower_subgroups_options *options) 485{ 486 nir_builder b; 487 nir_builder_init(&b, impl); 488 bool 
progress = false; 489 490 nir_foreach_block(block, impl) { 491 nir_foreach_instr_safe(instr, block) { 492 if (instr->type != nir_instr_type_intrinsic) 493 continue; 494 495 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); 496 b.cursor = nir_before_instr(instr); 497 498 nir_ssa_def *lower = lower_subgroups_intrin(&b, intrin, options); 499 if (!lower) 500 continue; 501 502 nir_ssa_def_rewrite_uses(&intrin->dest.ssa, nir_src_for_ssa(lower)); 503 nir_instr_remove(instr); 504 progress = true; 505 } 506 } 507 508 return progress; 509} 510 511bool 512nir_lower_subgroups(nir_shader *shader, 513 const nir_lower_subgroups_options *options) 514{ 515 bool progress = false; 516 517 nir_foreach_function(function, shader) { 518 if (!function->impl) 519 continue; 520 521 if (lower_subgroups_impl(function->impl, options)) { 522 progress = true; 523 nir_metadata_preserve(function->impl, nir_metadata_block_index | 524 nir_metadata_dominance); 525 } 526 } 527 528 return progress; 529} 530