nir_search.c revision 01e04c3f
/*
 * Copyright © 2014 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * Authors:
 *    Jason Ekstrand (jason@jlekstrand.net)
 *
 */

#include <inttypes.h>
#include "nir_search.h"
#include "nir_builder.h"
#include "util/half_float.h"

/* Mutable state threaded through a single attempt to match a search
 * expression against an ALU instruction tree.
 */
struct match_state {
   /* True once any matched sub-expression was marked inexact in the
    * pattern; an inexact pattern must not match 'exact' ALU instructions.
    */
   bool inexact_match;
   /* True once any instruction in the matched tree had instr->exact set. */
   bool has_exact_alu;
   /* Bitmask (1 << variable index) of pattern variables already bound. */
   unsigned variables_seen;
   /* The ALU source each seen variable was bound to. */
   nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
};

static bool
match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
                 unsigned num_components, const uint8_t *swizzle,
                 struct match_state *state);

static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };

/**
 * Check if a source produces a value of the given type.
 *
 * Used for satisfying 'a@type' constraints.
 *
 * Only SSA sources can be classified.  For ALU producers, booleans are
 * recognized structurally (iand/ior/ixor/inot of booleans); everything else
 * is classified by the opcode's declared base output type.  A couple of
 * boolean-producing intrinsics are special-cased.  Returns false when the
 * type cannot be determined ("don't know" is treated as "no").
 */
static bool
src_is_type(nir_src src, nir_alu_type type)
{
   assert(type != nir_type_invalid);

   if (!src.is_ssa)
      return false;

   if (src.ssa->parent_instr->type == nir_instr_type_alu) {
      nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
      nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;

      if (type == nir_type_bool) {
         switch (src_alu->op) {
         case nir_op_iand:
         case nir_op_ior:
         case nir_op_ixor:
            /* A bitwise op of two booleans is still a boolean. */
            return src_is_type(src_alu->src[0].src, nir_type_bool) &&
                   src_is_type(src_alu->src[1].src, nir_type_bool);
         case nir_op_inot:
            return src_is_type(src_alu->src[0].src, nir_type_bool);
         default:
            break;
         }
      }

      return nir_alu_type_get_base_type(output_type) == type;
   } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);

      if (type == nir_type_bool) {
         return intr->intrinsic == nir_intrinsic_load_front_face ||
                intr->intrinsic == nir_intrinsic_load_helper_invocation;
      }
   }

   /* don't know */
   return false;
}

/**
 * Try to match a single search value (sub-expression, variable, or
 * constant) against source \p src of \p instr.
 *
 * \p swizzle maps the \p num_components components being matched onto the
 * components of \p instr's destination; it is composed with the source's
 * own swizzle to form the effective per-component mapping (new_swizzle).
 * On success, variable bindings are recorded in \p state.
 */
static bool
match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
            unsigned num_components, const uint8_t *swizzle,
            struct match_state *state)
{
   uint8_t new_swizzle[NIR_MAX_VEC_COMPONENTS];

   /* Searching only works on SSA values because, if it's not SSA, we can't
    * know if the value changed between one instance of that value in the
    * expression and another.  Also, the replace operation will place reads of
    * that value right before the last instruction in the expression we're
    * replacing so those reads will happen after the original reads and may
    * not be valid if they're register reads.
    */
   if (!instr->src[src].src.is_ssa)
      return false;

   /* If the source is an explicitly sized source, then we need to reset
    * both the number of components and the swizzle.
    */
   if (nir_op_infos[instr->op].input_sizes[src] != 0) {
      num_components = nir_op_infos[instr->op].input_sizes[src];
      swizzle = identity_swizzle;
   }

   /* Compose the incoming swizzle with this source's swizzle. */
   for (unsigned i = 0; i < num_components; ++i)
      new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];

   /* If the value has a specific bit size and it doesn't match, bail */
   if (value->bit_size &&
       nir_src_bit_size(instr->src[src].src) != value->bit_size)
      return false;

   switch (value->type) {
   case nir_search_value_expression:
      /* Sub-expressions can only match ALU-produced values. */
      if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
         return false;

      return match_expression(nir_search_value_as_expression(value),
                              nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
                              num_components, new_swizzle, state);

   case nir_search_value_variable: {
      nir_search_variable *var = nir_search_value_as_variable(value);
      assert(var->variable < NIR_SEARCH_MAX_VARIABLES);

      if (state->variables_seen & (1 << var->variable)) {
         /* Variable already bound: this use must refer to the same SSA def
          * with the same effective swizzle for the pattern to match.
          */
         if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
            return false;

         assert(!instr->src[src].abs && !instr->src[src].negate);

         for (unsigned i = 0; i < num_components; ++i) {
            if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
               return false;
         }

         return true;
      } else {
         /* First use of this variable: check its constraints, then bind it. */
         if (var->is_constant &&
             instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
            return false;

         if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
            return false;

         if (var->type != nir_type_invalid &&
             !src_is_type(instr->src[src].src, var->type))
            return false;

         state->variables_seen |= (1 << var->variable);
         state->variables[var->variable].src = instr->src[src].src;
         state->variables[var->variable].abs = false;
         state->variables[var->variable].negate = false;

         /* Record the effective swizzle; pad unused components with 0. */
         for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; ++i) {
            if (i < num_components)
               state->variables[var->variable].swizzle[i] = new_swizzle[i];
            else
               state->variables[var->variable].swizzle[i] = 0;
         }

         return true;
      }
   }

   case nir_search_value_constant: {
      nir_search_constant *const_val = nir_search_value_as_constant(value);

      if (!nir_src_is_const(instr->src[src].src))
         return false;

      switch (const_val->type) {
      case nir_type_float:
         /* Every matched component must equal the pattern constant. */
         for (unsigned i = 0; i < num_components; ++i) {
            double val = nir_src_comp_as_float(instr->src[src].src,
                                               new_swizzle[i]);
            if (val != const_val->data.d)
               return false;
         }
         return true;

      case nir_type_int:
      case nir_type_uint:
      case nir_type_bool: {
         /* Compare only the low bit_size bits so sign-extension and
          * high-garbage differences don't cause a spurious mismatch.
          */
         unsigned bit_size = nir_src_bit_size(instr->src[src].src);
         uint64_t mask = bit_size == 64 ? UINT64_MAX : (1ull << bit_size) - 1;
         for (unsigned i = 0; i < num_components; ++i) {
            uint64_t val = nir_src_comp_as_uint(instr->src[src].src,
                                                new_swizzle[i]);
            if ((val & mask) != (const_val->data.u & mask))
               return false;
         }
         return true;
      }

      default:
         unreachable("Invalid alu source type");
      }
   }

   default:
      unreachable("Invalid search value type");
   }
}

/**
 * Try to match search expression \p expr against ALU instruction \p instr,
 * recursing through match_value() for each source.
 *
 * For commutative two-source opcodes, a failed first attempt is retried
 * with the pattern's sources swapped (after restoring the variables_seen
 * bitmask so stale bindings from the failed attempt don't poison it).
 */
static bool
match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
                 unsigned num_components, const uint8_t *swizzle,
                 struct match_state *state)
{
   if (expr->cond && !expr->cond(instr))
      return false;

   if (instr->op != expr->opcode)
      return false;

   assert(instr->dest.dest.is_ssa);

   if (expr->value.bit_size &&
       instr->dest.dest.ssa.bit_size != expr->value.bit_size)
      return false;

   /* An inexact pattern must never match any instruction tree containing an
    * 'exact' ALU op; track both flags cumulatively and bail on conflict.
    */
   state->inexact_match = expr->inexact || state->inexact_match;
   state->has_exact_alu = instr->exact || state->has_exact_alu;
   if (state->inexact_match && state->has_exact_alu)
      return false;

   assert(!instr->dest.saturate);
   assert(nir_op_infos[instr->op].num_inputs > 0);

   /* If we have an explicitly sized destination, we can only handle the
    * identity swizzle.  While dot(vec3(a, b, c).zxy) is a valid
    * expression, we don't have the information right now to propagate that
    * swizzle through.  We can only properly propagate swizzles if the
    * instruction is vectorized.
    */
   if (nir_op_infos[instr->op].output_size != 0) {
      for (unsigned i = 0; i < num_components; i++) {
         if (swizzle[i] != i)
            return false;
      }
   }

   /* Stash off the current variables_seen bitmask.  This way we can
    * restore it prior to matching in the commutative case below.
    */
   unsigned variables_seen_stash = state->variables_seen;

   bool matched = true;
   for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
      if (!match_value(expr->srcs[i], instr, i, num_components,
                       swizzle, state)) {
         matched = false;
         break;
      }
   }

   if (matched)
      return true;

   if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
      assert(nir_op_infos[instr->op].num_inputs == 2);

      /* Restore the variables_seen bitmask.  If we don't do this, then we
       * could end up with an erroneous failure due to variables found in the
       * first match attempt above not matching those in the second.
       */
      state->variables_seen = variables_seen_stash;

      /* Retry with pattern sources swapped. */
      if (!match_value(expr->srcs[0], instr, 1, num_components,
                       swizzle, state))
         return false;

      return match_value(expr->srcs[1], instr, 0, num_components,
                         swizzle, state);
   } else {
      return false;
   }
}

/* One node per search value in the replacement expression, used to infer
 * the bit size of every value to be constructed.  Sizes flow up from sized
 * leaves/opcodes (bitsize_tree_filter_up) and then back down to fill in
 * anything still unsized (bitsize_tree_filter_down).
 */
typedef struct bitsize_tree {
   unsigned num_srcs;
   struct bitsize_tree *srcs[4];

   /* Bit size shared by this node's unsized sources/dest (0 = unknown). */
   unsigned common_size;
   /* Whether the opcode explicitly sizes each source / the destination. */
   bool is_src_sized[4];
   bool is_dest_sized;

   /* Resolved sizes (0 = not yet determined). */
   unsigned dest_size;
   unsigned src_size[4];
} bitsize_tree;

/**
 * Build the bitsize tree for replacement value \p value.
 *
 * Expression nodes take sizes from the opcode's declared input/output
 * types; variable leaves take the bit size of the SSA value they were
 * bound to during matching; constant leaves start unsized.  An explicit
 * '@N' annotation on the value forces common_size.
 */
static bitsize_tree *
build_bitsize_tree(void *mem_ctx, struct match_state *state,
                   const nir_search_value *value)
{
   bitsize_tree *tree = rzalloc(mem_ctx, bitsize_tree);

   switch (value->type) {
   case nir_search_value_expression: {
      nir_search_expression *expr = nir_search_value_as_expression(value);
      nir_op_info info = nir_op_infos[expr->opcode];
      tree->num_srcs = info.num_inputs;
      tree->common_size = 0;
      for (unsigned i = 0; i < info.num_inputs; i++) {
         tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]);
         if (tree->is_src_sized[i])
            tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]);
         tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]);
      }
      tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type);
      if (tree->is_dest_sized)
         tree->dest_size = nir_alu_type_get_type_size(info.output_type);
      break;
   }

   case nir_search_value_variable: {
      nir_search_variable *var = nir_search_value_as_variable(value);
      tree->num_srcs = 0;
      tree->is_dest_sized = true;
      tree->dest_size = nir_src_bit_size(state->variables[var->variable].src);
      break;
   }

   case nir_search_value_constant: {
      tree->num_srcs = 0;
      tree->is_dest_sized = false;
      tree->common_size = 0;
      break;
   }
   }

   if (value->bit_size) {
      assert(!tree->is_dest_sized || tree->dest_size == value->bit_size);
      tree->common_size = value->bit_size;
   }

   return tree;
}

/**
 * Bottom-up pass: propagate known bit sizes from children toward the
 * root, unifying unsized sources with the node's common_size and
 * asserting consistency where sizes are already pinned.
 *
 * Returns this node's resolved dest size (0 if still unknown).
 */
static unsigned
bitsize_tree_filter_up(bitsize_tree *tree)
{
   for (unsigned i = 0; i < tree->num_srcs; i++) {
      unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]);
      if (src_size == 0)
         continue;

      if (tree->is_src_sized[i]) {
         assert(src_size == tree->src_size[i]);
      } else if (tree->common_size != 0) {
         assert(src_size == tree->common_size);
         tree->src_size[i] = src_size;
      } else {
         tree->common_size = src_size;
         tree->src_size[i] = src_size;
      }
   }

   if (tree->num_srcs && tree->common_size) {
      if (tree->dest_size == 0)
         tree->dest_size = tree->common_size;
      else if (!tree->is_dest_sized)
         assert(tree->dest_size == tree->common_size);

      /* Fill in any sources still unsized from the unified common size. */
      for (unsigned i = 0; i < tree->num_srcs; i++) {
         if (!tree->src_size[i])
            tree->src_size[i] = tree->common_size;
      }
   }

   return tree->dest_size;
}

/**
 * Top-down pass: push a known destination size (ultimately the size of
 * the instruction being replaced) back down the tree so every node ends
 * up with concrete dest/src sizes.
 */
static void
bitsize_tree_filter_down(bitsize_tree *tree, unsigned size)
{
   if (tree->dest_size)
      assert(tree->dest_size == size);
   else
      tree->dest_size = size;

   if (!tree->is_dest_sized) {
      /* Unsized dest shares its size with the node's unsized sources. */
      if (tree->common_size)
         assert(tree->common_size == size);
      else
         tree->common_size = size;
   }

   for (unsigned i = 0; i < tree->num_srcs; i++) {
      if (!tree->src_size[i]) {
         assert(tree->common_size);
         tree->src_size[i] = tree->common_size;
      }
      bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]);
   }
}

/**
 * Build the NIR for replacement value \p value at \p build's cursor,
 * returning it as an ALU source.
 *
 * Expressions become new ALU instructions (sized per \p bitsize),
 * variables re-use the sources bound during matching, and constants
 * become immediate loads.
 */
static nir_alu_src
construct_value(nir_builder *build,
                const nir_search_value *value,
                unsigned num_components, bitsize_tree *bitsize,
                struct match_state *state,
                nir_instr *instr)
{
   switch (value->type) {
   case nir_search_value_expression: {
      const nir_search_expression *expr = nir_search_value_as_expression(value);

      if (nir_op_infos[expr->opcode].output_size != 0)
         num_components = nir_op_infos[expr->opcode].output_size;

      nir_alu_instr *alu = nir_alu_instr_create(build->shader, expr->opcode);
      nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
                        bitsize->dest_size, NULL);
      alu->dest.write_mask = (1 << num_components) - 1;
      alu->dest.saturate = false;

      /* We have no way of knowing what values in a given search expression
       * map to a particular replacement value.  Therefore, if the
       * expression we are replacing has any exact values, the entire
       * replacement should be exact.
       */
      alu->exact = state->has_exact_alu;

      for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) {
         /* If the source is an explicitly sized source, then we need to reset
          * the number of components to match.
          */
         if (nir_op_infos[alu->op].input_sizes[i] != 0)
            num_components = nir_op_infos[alu->op].input_sizes[i];

         alu->src[i] = construct_value(build, expr->srcs[i],
                                       num_components, bitsize->srcs[i],
                                       state, instr);
      }

      nir_builder_instr_insert(build, &alu->instr);

      nir_alu_src val;
      val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
      val.negate = false;
      val.abs = false,
      memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);

      return val;
   }

   case nir_search_value_variable: {
      const nir_search_variable *var = nir_search_value_as_variable(value);
      assert(state->variables_seen & (1 << var->variable));

      nir_alu_src val = { NIR_SRC_INIT };
      nir_alu_src_copy(&val, &state->variables[var->variable],
                       (void *)build->shader);
      assert(!var->is_constant);

      return val;
   }

   case nir_search_value_constant: {
      const nir_search_constant *c = nir_search_value_as_constant(value);

      nir_ssa_def *cval;
      switch (c->type) {
      case nir_type_float:
         cval = nir_imm_floatN_t(build, c->data.d, bitsize->dest_size);
         break;

      case nir_type_int:
      case nir_type_uint:
         cval = nir_imm_intN_t(build, c->data.i, bitsize->dest_size);
         break;

      case nir_type_bool:
         cval = nir_imm_bool(build, c->data.u);
         break;
      default:
         unreachable("Invalid alu source type");
      }

      nir_alu_src val;
      val.src = nir_src_for_ssa(cval);
      val.negate = false;
      val.abs = false,
      memset(val.swizzle, 0, sizeof val.swizzle);

      return val;
   }

   default:
      unreachable("Invalid search value type");
   }
}

/**
 * If \p search matches the expression tree rooted at \p instr, build the
 * \p replace expression before \p instr, rewrite all uses of \p instr's
 * destination to the new value, and remove \p instr.
 *
 * Returns the new SSA def on success, or NULL if the pattern did not
 * match.  Bit sizes for the replacement are inferred via the bitsize
 * tree, seeded by the destination size of \p instr.
 */
nir_ssa_def *
nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
                  const nir_search_expression *search,
                  const nir_search_value *replace)
{
   uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };

   /* Start matching with the identity swizzle on the destination. */
   for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
      swizzle[i] = i;

   assert(instr->dest.dest.is_ssa);

   struct match_state state;
   state.inexact_match = false;
   state.has_exact_alu = false;
   state.variables_seen = 0;

   if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
                         swizzle, &state))
      return NULL;

   void *bitsize_ctx = ralloc_context(NULL);
   bitsize_tree *tree = build_bitsize_tree(bitsize_ctx, &state, replace);
   bitsize_tree_filter_up(tree);
   bitsize_tree_filter_down(tree, instr->dest.dest.ssa.bit_size);

   build->cursor = nir_before_instr(&instr->instr);

   nir_alu_src val = construct_value(build, replace,
                                     instr->dest.dest.ssa.num_components,
                                     tree, &state, &instr->instr);

   /* Inserting a mov may be unnecessary.  However, it's much easier to
    * simply let copy propagation clean this up than to try to go through
    * and rewrite swizzles ourselves.
    */
   nir_ssa_def *ssa_val =
      nir_imov_alu(build, val, instr->dest.dest.ssa.num_components);
   nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val));

   /* We know this one has no more uses because we just rewrote them all,
    * so we can remove it.  The rest of the matched expression, however, we
    * don't know so much about.  We'll just let dead code clean them up.
    */
   nir_instr_remove(&instr->instr);

   ralloc_free(bitsize_ctx);

   return ssa_val;
}