/*
 * Copyright © 2017 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
22b8e80941Smrg */ 23b8e80941Smrg 24b8e80941Smrg#include "nir.h" 25b8e80941Smrg#include "nir_builder.h" 26b8e80941Smrg 27b8e80941Smrg/** 28b8e80941Smrg * \file nir_opt_intrinsics.c 29b8e80941Smrg */ 30b8e80941Smrg 31b8e80941Smrgstatic nir_intrinsic_instr * 32b8e80941Smrglower_subgroups_64bit_split_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin, 33b8e80941Smrg unsigned int component) 34b8e80941Smrg{ 35b8e80941Smrg nir_ssa_def *comp; 36b8e80941Smrg if (component == 0) 37b8e80941Smrg comp = nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa); 38b8e80941Smrg else 39b8e80941Smrg comp = nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa); 40b8e80941Smrg 41b8e80941Smrg nir_intrinsic_instr *intr = nir_intrinsic_instr_create(b->shader, intrin->intrinsic); 42b8e80941Smrg nir_ssa_dest_init(&intr->instr, &intr->dest, 1, 32, NULL); 43b8e80941Smrg intr->const_index[0] = intrin->const_index[0]; 44b8e80941Smrg intr->const_index[1] = intrin->const_index[1]; 45b8e80941Smrg intr->src[0] = nir_src_for_ssa(comp); 46b8e80941Smrg if (nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2) 47b8e80941Smrg nir_src_copy(&intr->src[1], &intrin->src[1], intr); 48b8e80941Smrg 49b8e80941Smrg intr->num_components = 1; 50b8e80941Smrg nir_builder_instr_insert(b, &intr->instr); 51b8e80941Smrg return intr; 52b8e80941Smrg} 53b8e80941Smrg 54b8e80941Smrgstatic nir_ssa_def * 55b8e80941Smrglower_subgroup_op_to_32bit(nir_builder *b, nir_intrinsic_instr *intrin) 56b8e80941Smrg{ 57b8e80941Smrg assert(intrin->src[0].ssa->bit_size == 64); 58b8e80941Smrg nir_intrinsic_instr *intr_x = lower_subgroups_64bit_split_intrinsic(b, intrin, 0); 59b8e80941Smrg nir_intrinsic_instr *intr_y = lower_subgroups_64bit_split_intrinsic(b, intrin, 1); 60b8e80941Smrg return nir_pack_64_2x32_split(b, &intr_x->dest.ssa, &intr_y->dest.ssa); 61b8e80941Smrg} 62b8e80941Smrg 63b8e80941Smrgstatic nir_ssa_def * 64b8e80941Smrgballot_type_to_uint(nir_builder *b, nir_ssa_def *value, unsigned bit_size) 65b8e80941Smrg{ 66b8e80941Smrg /* We only 
use this on uvec4 types */ 67b8e80941Smrg assert(value->num_components == 4 && value->bit_size == 32); 68b8e80941Smrg 69b8e80941Smrg if (bit_size == 32) { 70b8e80941Smrg return nir_channel(b, value, 0); 71b8e80941Smrg } else { 72b8e80941Smrg assert(bit_size == 64); 73b8e80941Smrg return nir_pack_64_2x32_split(b, nir_channel(b, value, 0), 74b8e80941Smrg nir_channel(b, value, 1)); 75b8e80941Smrg } 76b8e80941Smrg} 77b8e80941Smrg 78b8e80941Smrg/* Converts a uint32_t or uint64_t value to uint64_t or uvec4 */ 79b8e80941Smrgstatic nir_ssa_def * 80b8e80941Smrguint_to_ballot_type(nir_builder *b, nir_ssa_def *value, 81b8e80941Smrg unsigned num_components, unsigned bit_size) 82b8e80941Smrg{ 83b8e80941Smrg assert(value->num_components == 1); 84b8e80941Smrg assert(value->bit_size == 32 || value->bit_size == 64); 85b8e80941Smrg 86b8e80941Smrg nir_ssa_def *zero = nir_imm_int(b, 0); 87b8e80941Smrg if (num_components > 1) { 88b8e80941Smrg /* SPIR-V uses a uvec4 for ballot values */ 89b8e80941Smrg assert(num_components == 4); 90b8e80941Smrg assert(bit_size == 32); 91b8e80941Smrg 92b8e80941Smrg if (value->bit_size == 32) { 93b8e80941Smrg return nir_vec4(b, value, zero, zero, zero); 94b8e80941Smrg } else { 95b8e80941Smrg assert(value->bit_size == 64); 96b8e80941Smrg return nir_vec4(b, nir_unpack_64_2x32_split_x(b, value), 97b8e80941Smrg nir_unpack_64_2x32_split_y(b, value), 98b8e80941Smrg zero, zero); 99b8e80941Smrg } 100b8e80941Smrg } else { 101b8e80941Smrg /* GLSL uses a uint64_t for ballot values */ 102b8e80941Smrg assert(num_components == 1); 103b8e80941Smrg assert(bit_size == 64); 104b8e80941Smrg 105b8e80941Smrg if (value->bit_size == 32) { 106b8e80941Smrg return nir_pack_64_2x32_split(b, value, zero); 107b8e80941Smrg } else { 108b8e80941Smrg assert(value->bit_size == 64); 109b8e80941Smrg return value; 110b8e80941Smrg } 111b8e80941Smrg } 112b8e80941Smrg} 113b8e80941Smrg 114b8e80941Smrgstatic nir_ssa_def * 115b8e80941Smrglower_subgroup_op_to_scalar(nir_builder *b, 
nir_intrinsic_instr *intrin, 116b8e80941Smrg bool lower_to_32bit) 117b8e80941Smrg{ 118b8e80941Smrg /* This is safe to call on scalar things but it would be silly */ 119b8e80941Smrg assert(intrin->dest.ssa.num_components > 1); 120b8e80941Smrg 121b8e80941Smrg nir_ssa_def *value = nir_ssa_for_src(b, intrin->src[0], 122b8e80941Smrg intrin->num_components); 123b8e80941Smrg nir_ssa_def *reads[4]; 124b8e80941Smrg 125b8e80941Smrg for (unsigned i = 0; i < intrin->num_components; i++) { 126b8e80941Smrg nir_intrinsic_instr *chan_intrin = 127b8e80941Smrg nir_intrinsic_instr_create(b->shader, intrin->intrinsic); 128b8e80941Smrg nir_ssa_dest_init(&chan_intrin->instr, &chan_intrin->dest, 129b8e80941Smrg 1, intrin->dest.ssa.bit_size, NULL); 130b8e80941Smrg chan_intrin->num_components = 1; 131b8e80941Smrg 132b8e80941Smrg /* value */ 133b8e80941Smrg chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i)); 134b8e80941Smrg /* invocation */ 135b8e80941Smrg if (nir_intrinsic_infos[intrin->intrinsic].num_srcs > 1) { 136b8e80941Smrg assert(nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2); 137b8e80941Smrg nir_src_copy(&chan_intrin->src[1], &intrin->src[1], chan_intrin); 138b8e80941Smrg } 139b8e80941Smrg 140b8e80941Smrg chan_intrin->const_index[0] = intrin->const_index[0]; 141b8e80941Smrg chan_intrin->const_index[1] = intrin->const_index[1]; 142b8e80941Smrg 143b8e80941Smrg if (lower_to_32bit && chan_intrin->src[0].ssa->bit_size == 64) { 144b8e80941Smrg reads[i] = lower_subgroup_op_to_32bit(b, chan_intrin); 145b8e80941Smrg } else { 146b8e80941Smrg nir_builder_instr_insert(b, &chan_intrin->instr); 147b8e80941Smrg reads[i] = &chan_intrin->dest.ssa; 148b8e80941Smrg } 149b8e80941Smrg } 150b8e80941Smrg 151b8e80941Smrg return nir_vec(b, reads, intrin->num_components); 152b8e80941Smrg} 153b8e80941Smrg 154b8e80941Smrgstatic nir_ssa_def * 155b8e80941Smrglower_vote_eq_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin) 156b8e80941Smrg{ 157b8e80941Smrg assert(intrin->src[0].is_ssa); 
158b8e80941Smrg nir_ssa_def *value = intrin->src[0].ssa; 159b8e80941Smrg 160b8e80941Smrg nir_ssa_def *result = NULL; 161b8e80941Smrg for (unsigned i = 0; i < intrin->num_components; i++) { 162b8e80941Smrg nir_intrinsic_instr *chan_intrin = 163b8e80941Smrg nir_intrinsic_instr_create(b->shader, intrin->intrinsic); 164b8e80941Smrg nir_ssa_dest_init(&chan_intrin->instr, &chan_intrin->dest, 165b8e80941Smrg 1, intrin->dest.ssa.bit_size, NULL); 166b8e80941Smrg chan_intrin->num_components = 1; 167b8e80941Smrg chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i)); 168b8e80941Smrg nir_builder_instr_insert(b, &chan_intrin->instr); 169b8e80941Smrg 170b8e80941Smrg if (result) { 171b8e80941Smrg result = nir_iand(b, result, &chan_intrin->dest.ssa); 172b8e80941Smrg } else { 173b8e80941Smrg result = &chan_intrin->dest.ssa; 174b8e80941Smrg } 175b8e80941Smrg } 176b8e80941Smrg 177b8e80941Smrg return result; 178b8e80941Smrg} 179b8e80941Smrg 180b8e80941Smrgstatic nir_ssa_def * 181b8e80941Smrglower_vote_eq_to_ballot(nir_builder *b, nir_intrinsic_instr *intrin, 182b8e80941Smrg const nir_lower_subgroups_options *options) 183b8e80941Smrg{ 184b8e80941Smrg assert(intrin->src[0].is_ssa); 185b8e80941Smrg nir_ssa_def *value = intrin->src[0].ssa; 186b8e80941Smrg 187b8e80941Smrg /* We have to implicitly lower to scalar */ 188b8e80941Smrg nir_ssa_def *all_eq = NULL; 189b8e80941Smrg for (unsigned i = 0; i < intrin->num_components; i++) { 190b8e80941Smrg nir_intrinsic_instr *rfi = 191b8e80941Smrg nir_intrinsic_instr_create(b->shader, 192b8e80941Smrg nir_intrinsic_read_first_invocation); 193b8e80941Smrg nir_ssa_dest_init(&rfi->instr, &rfi->dest, 194b8e80941Smrg 1, value->bit_size, NULL); 195b8e80941Smrg rfi->num_components = 1; 196b8e80941Smrg rfi->src[0] = nir_src_for_ssa(nir_channel(b, value, i)); 197b8e80941Smrg nir_builder_instr_insert(b, &rfi->instr); 198b8e80941Smrg 199b8e80941Smrg nir_ssa_def *is_eq; 200b8e80941Smrg if (intrin->intrinsic == nir_intrinsic_vote_feq) { 201b8e80941Smrg 
is_eq = nir_feq(b, &rfi->dest.ssa, nir_channel(b, value, i)); 202b8e80941Smrg } else { 203b8e80941Smrg is_eq = nir_ieq(b, &rfi->dest.ssa, nir_channel(b, value, i)); 204b8e80941Smrg } 205b8e80941Smrg 206b8e80941Smrg if (all_eq == NULL) { 207b8e80941Smrg all_eq = is_eq; 208b8e80941Smrg } else { 209b8e80941Smrg all_eq = nir_iand(b, all_eq, is_eq); 210b8e80941Smrg } 211b8e80941Smrg } 212b8e80941Smrg 213b8e80941Smrg nir_intrinsic_instr *ballot = 214b8e80941Smrg nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot); 215b8e80941Smrg nir_ssa_dest_init(&ballot->instr, &ballot->dest, 216b8e80941Smrg 1, options->ballot_bit_size, NULL); 217b8e80941Smrg ballot->num_components = 1; 218b8e80941Smrg ballot->src[0] = nir_src_for_ssa(nir_inot(b, all_eq)); 219b8e80941Smrg nir_builder_instr_insert(b, &ballot->instr); 220b8e80941Smrg 221b8e80941Smrg return nir_ieq(b, &ballot->dest.ssa, 222b8e80941Smrg nir_imm_intN_t(b, 0, options->ballot_bit_size)); 223b8e80941Smrg} 224b8e80941Smrg 225b8e80941Smrgstatic nir_ssa_def * 226b8e80941Smrglower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin, 227b8e80941Smrg bool lower_to_scalar, bool lower_to_32bit) 228b8e80941Smrg{ 229b8e80941Smrg nir_ssa_def *index = nir_load_subgroup_invocation(b); 230b8e80941Smrg switch (intrin->intrinsic) { 231b8e80941Smrg case nir_intrinsic_shuffle_xor: 232b8e80941Smrg assert(intrin->src[1].is_ssa); 233b8e80941Smrg index = nir_ixor(b, index, intrin->src[1].ssa); 234b8e80941Smrg break; 235b8e80941Smrg case nir_intrinsic_shuffle_up: 236b8e80941Smrg assert(intrin->src[1].is_ssa); 237b8e80941Smrg index = nir_isub(b, index, intrin->src[1].ssa); 238b8e80941Smrg break; 239b8e80941Smrg case nir_intrinsic_shuffle_down: 240b8e80941Smrg assert(intrin->src[1].is_ssa); 241b8e80941Smrg index = nir_iadd(b, index, intrin->src[1].ssa); 242b8e80941Smrg break; 243b8e80941Smrg case nir_intrinsic_quad_broadcast: 244b8e80941Smrg assert(intrin->src[1].is_ssa); 245b8e80941Smrg index = nir_ior(b, nir_iand(b, index, nir_imm_int(b, 
~0x3)), 246b8e80941Smrg intrin->src[1].ssa); 247b8e80941Smrg break; 248b8e80941Smrg case nir_intrinsic_quad_swap_horizontal: 249b8e80941Smrg /* For Quad operations, subgroups are divided into quads where 250b8e80941Smrg * (invocation % 4) is the index to a square arranged as follows: 251b8e80941Smrg * 252b8e80941Smrg * +---+---+ 253b8e80941Smrg * | 0 | 1 | 254b8e80941Smrg * +---+---+ 255b8e80941Smrg * | 2 | 3 | 256b8e80941Smrg * +---+---+ 257b8e80941Smrg */ 258b8e80941Smrg index = nir_ixor(b, index, nir_imm_int(b, 0x1)); 259b8e80941Smrg break; 260b8e80941Smrg case nir_intrinsic_quad_swap_vertical: 261b8e80941Smrg index = nir_ixor(b, index, nir_imm_int(b, 0x2)); 262b8e80941Smrg break; 263b8e80941Smrg case nir_intrinsic_quad_swap_diagonal: 264b8e80941Smrg index = nir_ixor(b, index, nir_imm_int(b, 0x3)); 265b8e80941Smrg break; 266b8e80941Smrg default: 267b8e80941Smrg unreachable("Invalid intrinsic"); 268b8e80941Smrg } 269b8e80941Smrg 270b8e80941Smrg nir_intrinsic_instr *shuffle = 271b8e80941Smrg nir_intrinsic_instr_create(b->shader, nir_intrinsic_shuffle); 272b8e80941Smrg shuffle->num_components = intrin->num_components; 273b8e80941Smrg nir_src_copy(&shuffle->src[0], &intrin->src[0], shuffle); 274b8e80941Smrg shuffle->src[1] = nir_src_for_ssa(index); 275b8e80941Smrg nir_ssa_dest_init(&shuffle->instr, &shuffle->dest, 276b8e80941Smrg intrin->dest.ssa.num_components, 277b8e80941Smrg intrin->dest.ssa.bit_size, NULL); 278b8e80941Smrg 279b8e80941Smrg if (lower_to_scalar && shuffle->num_components > 1) { 280b8e80941Smrg return lower_subgroup_op_to_scalar(b, shuffle, lower_to_32bit); 281b8e80941Smrg } else if (lower_to_32bit && shuffle->src[0].ssa->bit_size == 64) { 282b8e80941Smrg return lower_subgroup_op_to_32bit(b, shuffle); 283b8e80941Smrg } else { 284b8e80941Smrg nir_builder_instr_insert(b, &shuffle->instr); 285b8e80941Smrg return &shuffle->dest.ssa; 286b8e80941Smrg } 287b8e80941Smrg} 288b8e80941Smrg 289b8e80941Smrgstatic nir_ssa_def * 
290b8e80941Smrglower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin, 291b8e80941Smrg const nir_lower_subgroups_options *options) 292b8e80941Smrg{ 293b8e80941Smrg switch (intrin->intrinsic) { 294b8e80941Smrg case nir_intrinsic_vote_any: 295b8e80941Smrg case nir_intrinsic_vote_all: 296b8e80941Smrg if (options->lower_vote_trivial) 297b8e80941Smrg return nir_ssa_for_src(b, intrin->src[0], 1); 298b8e80941Smrg break; 299b8e80941Smrg 300b8e80941Smrg case nir_intrinsic_vote_feq: 301b8e80941Smrg case nir_intrinsic_vote_ieq: 302b8e80941Smrg if (options->lower_vote_trivial) 303b8e80941Smrg return nir_imm_true(b); 304b8e80941Smrg 305b8e80941Smrg if (options->lower_vote_eq_to_ballot) 306b8e80941Smrg return lower_vote_eq_to_ballot(b, intrin, options); 307b8e80941Smrg 308b8e80941Smrg if (options->lower_to_scalar && intrin->num_components > 1) 309b8e80941Smrg return lower_vote_eq_to_scalar(b, intrin); 310b8e80941Smrg break; 311b8e80941Smrg 312b8e80941Smrg case nir_intrinsic_load_subgroup_size: 313b8e80941Smrg if (options->subgroup_size) 314b8e80941Smrg return nir_imm_int(b, options->subgroup_size); 315b8e80941Smrg break; 316b8e80941Smrg 317b8e80941Smrg case nir_intrinsic_read_invocation: 318b8e80941Smrg case nir_intrinsic_read_first_invocation: 319b8e80941Smrg if (options->lower_to_scalar && intrin->num_components > 1) 320b8e80941Smrg return lower_subgroup_op_to_scalar(b, intrin, false); 321b8e80941Smrg break; 322b8e80941Smrg 323b8e80941Smrg case nir_intrinsic_load_subgroup_eq_mask: 324b8e80941Smrg case nir_intrinsic_load_subgroup_ge_mask: 325b8e80941Smrg case nir_intrinsic_load_subgroup_gt_mask: 326b8e80941Smrg case nir_intrinsic_load_subgroup_le_mask: 327b8e80941Smrg case nir_intrinsic_load_subgroup_lt_mask: { 328b8e80941Smrg if (!options->lower_subgroup_masks) 329b8e80941Smrg return NULL; 330b8e80941Smrg 331b8e80941Smrg /* If either the result or the requested bit size is 64-bits then we 332b8e80941Smrg * know that we have 64-bit types and using them will probably 
be more 333b8e80941Smrg * efficient than messing around with 32-bit shifts and packing. 334b8e80941Smrg */ 335b8e80941Smrg const unsigned bit_size = MAX2(options->ballot_bit_size, 336b8e80941Smrg intrin->dest.ssa.bit_size); 337b8e80941Smrg 338b8e80941Smrg assert(options->subgroup_size <= 64); 339b8e80941Smrg uint64_t group_mask = ~0ull >> (64 - options->subgroup_size); 340b8e80941Smrg 341b8e80941Smrg nir_ssa_def *count = nir_load_subgroup_invocation(b); 342b8e80941Smrg nir_ssa_def *val; 343b8e80941Smrg switch (intrin->intrinsic) { 344b8e80941Smrg case nir_intrinsic_load_subgroup_eq_mask: 345b8e80941Smrg val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count); 346b8e80941Smrg break; 347b8e80941Smrg case nir_intrinsic_load_subgroup_ge_mask: 348b8e80941Smrg val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count), 349b8e80941Smrg nir_imm_intN_t(b, group_mask, bit_size)); 350b8e80941Smrg break; 351b8e80941Smrg case nir_intrinsic_load_subgroup_gt_mask: 352b8e80941Smrg val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count), 353b8e80941Smrg nir_imm_intN_t(b, group_mask, bit_size)); 354b8e80941Smrg break; 355b8e80941Smrg case nir_intrinsic_load_subgroup_le_mask: 356b8e80941Smrg val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count)); 357b8e80941Smrg break; 358b8e80941Smrg case nir_intrinsic_load_subgroup_lt_mask: 359b8e80941Smrg val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count)); 360b8e80941Smrg break; 361b8e80941Smrg default: 362b8e80941Smrg unreachable("you seriously can't tell this is unreachable?"); 363b8e80941Smrg } 364b8e80941Smrg 365b8e80941Smrg return uint_to_ballot_type(b, val, 366b8e80941Smrg intrin->dest.ssa.num_components, 367b8e80941Smrg intrin->dest.ssa.bit_size); 368b8e80941Smrg } 369b8e80941Smrg 370b8e80941Smrg case nir_intrinsic_ballot: { 371b8e80941Smrg if (intrin->dest.ssa.num_components == 1 && 372b8e80941Smrg intrin->dest.ssa.bit_size == options->ballot_bit_size) 
373b8e80941Smrg return NULL; 374b8e80941Smrg 375b8e80941Smrg nir_intrinsic_instr *ballot = 376b8e80941Smrg nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot); 377b8e80941Smrg ballot->num_components = 1; 378b8e80941Smrg nir_ssa_dest_init(&ballot->instr, &ballot->dest, 379b8e80941Smrg 1, options->ballot_bit_size, NULL); 380b8e80941Smrg nir_src_copy(&ballot->src[0], &intrin->src[0], ballot); 381b8e80941Smrg nir_builder_instr_insert(b, &ballot->instr); 382b8e80941Smrg 383b8e80941Smrg return uint_to_ballot_type(b, &ballot->dest.ssa, 384b8e80941Smrg intrin->dest.ssa.num_components, 385b8e80941Smrg intrin->dest.ssa.bit_size); 386b8e80941Smrg } 387b8e80941Smrg 388b8e80941Smrg case nir_intrinsic_ballot_bitfield_extract: 389b8e80941Smrg case nir_intrinsic_ballot_bit_count_reduce: 390b8e80941Smrg case nir_intrinsic_ballot_find_lsb: 391b8e80941Smrg case nir_intrinsic_ballot_find_msb: { 392b8e80941Smrg assert(intrin->src[0].is_ssa); 393b8e80941Smrg nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa, 394b8e80941Smrg options->ballot_bit_size); 395b8e80941Smrg switch (intrin->intrinsic) { 396b8e80941Smrg case nir_intrinsic_ballot_bitfield_extract: 397b8e80941Smrg assert(intrin->src[1].is_ssa); 398b8e80941Smrg return nir_i2b(b, nir_iand(b, nir_ushr(b, int_val, 399b8e80941Smrg intrin->src[1].ssa), 400b8e80941Smrg nir_imm_intN_t(b, 1, options->ballot_bit_size))); 401b8e80941Smrg case nir_intrinsic_ballot_bit_count_reduce: 402b8e80941Smrg return nir_bit_count(b, int_val); 403b8e80941Smrg case nir_intrinsic_ballot_find_lsb: 404b8e80941Smrg return nir_find_lsb(b, int_val); 405b8e80941Smrg case nir_intrinsic_ballot_find_msb: 406b8e80941Smrg return nir_ufind_msb(b, int_val); 407b8e80941Smrg default: 408b8e80941Smrg unreachable("you seriously can't tell this is unreachable?"); 409b8e80941Smrg } 410b8e80941Smrg } 411b8e80941Smrg 412b8e80941Smrg case nir_intrinsic_ballot_bit_count_exclusive: 413b8e80941Smrg case nir_intrinsic_ballot_bit_count_inclusive: { 
414b8e80941Smrg nir_ssa_def *count = nir_load_subgroup_invocation(b); 415b8e80941Smrg nir_ssa_def *mask = nir_imm_intN_t(b, ~0ull, options->ballot_bit_size); 416b8e80941Smrg if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_inclusive) { 417b8e80941Smrg const unsigned bits = options->ballot_bit_size; 418b8e80941Smrg mask = nir_ushr(b, mask, nir_isub(b, nir_imm_int(b, bits - 1), count)); 419b8e80941Smrg } else { 420b8e80941Smrg mask = nir_inot(b, nir_ishl(b, mask, count)); 421b8e80941Smrg } 422b8e80941Smrg 423b8e80941Smrg assert(intrin->src[0].is_ssa); 424b8e80941Smrg nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa, 425b8e80941Smrg options->ballot_bit_size); 426b8e80941Smrg 427b8e80941Smrg return nir_bit_count(b, nir_iand(b, int_val, mask)); 428b8e80941Smrg } 429b8e80941Smrg 430b8e80941Smrg case nir_intrinsic_elect: { 431b8e80941Smrg nir_intrinsic_instr *first = 432b8e80941Smrg nir_intrinsic_instr_create(b->shader, 433b8e80941Smrg nir_intrinsic_first_invocation); 434b8e80941Smrg nir_ssa_dest_init(&first->instr, &first->dest, 1, 32, NULL); 435b8e80941Smrg nir_builder_instr_insert(b, &first->instr); 436b8e80941Smrg 437b8e80941Smrg return nir_ieq(b, nir_load_subgroup_invocation(b), &first->dest.ssa); 438b8e80941Smrg } 439b8e80941Smrg 440b8e80941Smrg case nir_intrinsic_shuffle: 441b8e80941Smrg if (options->lower_to_scalar && intrin->num_components > 1) 442b8e80941Smrg return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit); 443b8e80941Smrg else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64) 444b8e80941Smrg return lower_subgroup_op_to_32bit(b, intrin); 445b8e80941Smrg break; 446b8e80941Smrg 447b8e80941Smrg case nir_intrinsic_shuffle_xor: 448b8e80941Smrg case nir_intrinsic_shuffle_up: 449b8e80941Smrg case nir_intrinsic_shuffle_down: 450b8e80941Smrg if (options->lower_shuffle) 451b8e80941Smrg return lower_shuffle(b, intrin, options->lower_to_scalar, options->lower_shuffle_to_32bit); 452b8e80941Smrg 
else if (options->lower_to_scalar && intrin->num_components > 1) 453b8e80941Smrg return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit); 454b8e80941Smrg else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64) 455b8e80941Smrg return lower_subgroup_op_to_32bit(b, intrin); 456b8e80941Smrg break; 457b8e80941Smrg 458b8e80941Smrg case nir_intrinsic_quad_broadcast: 459b8e80941Smrg case nir_intrinsic_quad_swap_horizontal: 460b8e80941Smrg case nir_intrinsic_quad_swap_vertical: 461b8e80941Smrg case nir_intrinsic_quad_swap_diagonal: 462b8e80941Smrg if (options->lower_quad) 463b8e80941Smrg return lower_shuffle(b, intrin, options->lower_to_scalar, false); 464b8e80941Smrg else if (options->lower_to_scalar && intrin->num_components > 1) 465b8e80941Smrg return lower_subgroup_op_to_scalar(b, intrin, false); 466b8e80941Smrg break; 467b8e80941Smrg 468b8e80941Smrg case nir_intrinsic_reduce: 469b8e80941Smrg case nir_intrinsic_inclusive_scan: 470b8e80941Smrg case nir_intrinsic_exclusive_scan: 471b8e80941Smrg if (options->lower_to_scalar && intrin->num_components > 1) 472b8e80941Smrg return lower_subgroup_op_to_scalar(b, intrin, false); 473b8e80941Smrg break; 474b8e80941Smrg 475b8e80941Smrg default: 476b8e80941Smrg break; 477b8e80941Smrg } 478b8e80941Smrg 479b8e80941Smrg return NULL; 480b8e80941Smrg} 481b8e80941Smrg 482b8e80941Smrgstatic bool 483b8e80941Smrglower_subgroups_impl(nir_function_impl *impl, 484b8e80941Smrg const nir_lower_subgroups_options *options) 485b8e80941Smrg{ 486b8e80941Smrg nir_builder b; 487b8e80941Smrg nir_builder_init(&b, impl); 488b8e80941Smrg bool progress = false; 489b8e80941Smrg 490b8e80941Smrg nir_foreach_block(block, impl) { 491b8e80941Smrg nir_foreach_instr_safe(instr, block) { 492b8e80941Smrg if (instr->type != nir_instr_type_intrinsic) 493b8e80941Smrg continue; 494b8e80941Smrg 495b8e80941Smrg nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); 496b8e80941Smrg b.cursor = nir_before_instr(instr); 
497b8e80941Smrg 498b8e80941Smrg nir_ssa_def *lower = lower_subgroups_intrin(&b, intrin, options); 499b8e80941Smrg if (!lower) 500b8e80941Smrg continue; 501b8e80941Smrg 502b8e80941Smrg nir_ssa_def_rewrite_uses(&intrin->dest.ssa, nir_src_for_ssa(lower)); 503b8e80941Smrg nir_instr_remove(instr); 504b8e80941Smrg progress = true; 505b8e80941Smrg } 506b8e80941Smrg } 507b8e80941Smrg 508b8e80941Smrg return progress; 509b8e80941Smrg} 510b8e80941Smrg 511b8e80941Smrgbool 512b8e80941Smrgnir_lower_subgroups(nir_shader *shader, 513b8e80941Smrg const nir_lower_subgroups_options *options) 514b8e80941Smrg{ 515b8e80941Smrg bool progress = false; 516b8e80941Smrg 517b8e80941Smrg nir_foreach_function(function, shader) { 518b8e80941Smrg if (!function->impl) 519b8e80941Smrg continue; 520b8e80941Smrg 521b8e80941Smrg if (lower_subgroups_impl(function->impl, options)) { 522b8e80941Smrg progress = true; 523b8e80941Smrg nir_metadata_preserve(function->impl, nir_metadata_block_index | 524b8e80941Smrg nir_metadata_dominance); 525b8e80941Smrg } 526b8e80941Smrg } 527b8e80941Smrg 528b8e80941Smrg return progress; 529b8e80941Smrg} 530