1/* 2 * Copyright © 2014 Intel Corporation 3 * 4 * Permission is hereby granted, free of charge, to any person obtaining a 5 * copy of this software and associated documentation files (the "Software"), 6 * to deal in the Software without restriction, including without limitation 7 * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 * and/or sell copies of the Software, and to permit persons to whom the 9 * Software is furnished to do so, subject to the following conditions: 10 * 11 * The above copyright notice and this permission notice (including the next 12 * paragraph) shall be included in all copies or substantial portions of the 13 * Software. 14 * 15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 21 * IN THE SOFTWARE. 22 * 23 * Authors: 24 * Jason Ekstrand (jason@jlekstrand.net) 25 * 26 */ 27 28#include <inttypes.h> 29#include "nir_search.h" 30#include "nir_builder.h" 31#include "util/half_float.h" 32 33#define NIR_SEARCH_MAX_COMM_OPS 4 34 35struct match_state { 36 bool inexact_match; 37 bool has_exact_alu; 38 uint8_t comm_op_direction; 39 unsigned variables_seen; 40 nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES]; 41}; 42 43static bool 44match_expression(const nir_search_expression *expr, nir_alu_instr *instr, 45 unsigned num_components, const uint8_t *swizzle, 46 struct match_state *state); 47 48static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 }; 49 50/** 51 * Check if a source produces a value of the given type. 52 * 53 * Used for satisfying 'a@type' constraints. 54 */ 55static bool 56src_is_type(nir_src src, nir_alu_type type) 57{ 58 assert(type != nir_type_invalid); 59 60 if (!src.is_ssa) 61 return false; 62 63 if (src.ssa->parent_instr->type == nir_instr_type_alu) { 64 nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr); 65 nir_alu_type output_type = nir_op_infos[src_alu->op].output_type; 66 67 if (type == nir_type_bool) { 68 switch (src_alu->op) { 69 case nir_op_iand: 70 case nir_op_ior: 71 case nir_op_ixor: 72 return src_is_type(src_alu->src[0].src, nir_type_bool) && 73 src_is_type(src_alu->src[1].src, nir_type_bool); 74 case nir_op_inot: 75 return src_is_type(src_alu->src[0].src, nir_type_bool); 76 default: 77 break; 78 } 79 } 80 81 return nir_alu_type_get_base_type(output_type) == type; 82 } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) { 83 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr); 84 85 if (type == nir_type_bool) { 86 return intr->intrinsic == nir_intrinsic_load_front_face || 87 intr->intrinsic == nir_intrinsic_load_helper_invocation; 88 } 89 } 90 91 /* don't know */ 92 return false; 93} 94 95static bool 96nir_op_matches_search_op(nir_op nop, uint16_t sop) 97{ 98 if (sop <= nir_last_opcode) 99 return nop == sop; 100 101#define MATCH_FCONV_CASE(op) \ 102 case nir_search_op_##op: \ 103 return nop == nir_op_##op##16 || \ 104 nop == nir_op_##op##32 || \ 105 nop == nir_op_##op##64; 106 107#define MATCH_ICONV_CASE(op) \ 108 case nir_search_op_##op: \ 109 return nop == nir_op_##op##8 || \ 110 nop == nir_op_##op##16 || \ 111 nop == nir_op_##op##32 || \ 112 nop == nir_op_##op##64; 113 114#define MATCH_BCONV_CASE(op) \ 115 case nir_search_op_##op: \ 116 return nop == nir_op_##op##1 || \ 117 nop == nir_op_##op##32; 118 119 switch (sop) { 120 MATCH_FCONV_CASE(i2f) 121 MATCH_FCONV_CASE(u2f) 122 MATCH_FCONV_CASE(f2f) 123 MATCH_ICONV_CASE(f2u) 124 MATCH_ICONV_CASE(f2i) 125 MATCH_ICONV_CASE(u2u) 126 MATCH_ICONV_CASE(i2i) 127 MATCH_FCONV_CASE(b2f) 128 MATCH_ICONV_CASE(b2i) 129 MATCH_BCONV_CASE(i2b) 130 MATCH_BCONV_CASE(f2b) 131 default: 132 unreachable("Invalid nir_search_op"); 133 } 134 135#undef MATCH_FCONV_CASE 136#undef MATCH_ICONV_CASE 137#undef MATCH_BCONV_CASE 138} 139 140uint16_t 141nir_search_op_for_nir_op(nir_op nop) 142{ 143#define MATCH_FCONV_CASE(op) \ 144 case nir_op_##op##16: \ 145 case nir_op_##op##32: \ 146 case nir_op_##op##64: \ 147 return nir_search_op_##op; 148 149#define MATCH_ICONV_CASE(op) \ 150 case nir_op_##op##8: \ 151 case nir_op_##op##16: \ 152 case nir_op_##op##32: \ 153 case nir_op_##op##64: \ 154 return nir_search_op_##op; 155 156#define MATCH_BCONV_CASE(op) \ 157 case nir_op_##op##1: \ 158 case nir_op_##op##32: \ 159 return nir_search_op_##op; 160 161 162 switch (nop) { 163 MATCH_FCONV_CASE(i2f) 164 MATCH_FCONV_CASE(u2f) 165 MATCH_FCONV_CASE(f2f) 166 MATCH_ICONV_CASE(f2u) 167 MATCH_ICONV_CASE(f2i) 168 MATCH_ICONV_CASE(u2u) 169 MATCH_ICONV_CASE(i2i) 170 MATCH_FCONV_CASE(b2f) 171 MATCH_ICONV_CASE(b2i) 172 MATCH_BCONV_CASE(i2b) 173 MATCH_BCONV_CASE(f2b) 174 default: 175 return nop; 176 } 177 178#undef MATCH_FCONV_CASE 179#undef MATCH_ICONV_CASE 180#undef MATCH_BCONV_CASE 181} 182 183static nir_op 184nir_op_for_search_op(uint16_t sop, unsigned bit_size) 185{ 186 if (sop <= nir_last_opcode) 187 return sop; 188 189#define RET_FCONV_CASE(op) \ 190 case nir_search_op_##op: \ 191 switch (bit_size) { \ 192 case 16: return nir_op_##op##16; \ 193 case 32: return nir_op_##op##32; \ 194 case 64: return nir_op_##op##64; \ 195 default: unreachable("Invalid bit size"); \ 196 } 197 198#define RET_ICONV_CASE(op) \ 199 case nir_search_op_##op: \ 200 switch (bit_size) { \ 201 case 8: return nir_op_##op##8; \ 202 case 16: return nir_op_##op##16; \ 203 case 32: return nir_op_##op##32; \ 204 case 64: return nir_op_##op##64; \ 205 default: unreachable("Invalid bit size"); \ 206 } 207 208#define RET_BCONV_CASE(op) \ 209 case nir_search_op_##op: \ 210 switch (bit_size) { \ 211 case 1: return nir_op_##op##1; \ 212 case 32: return nir_op_##op##32; \ 213 default: unreachable("Invalid bit size"); \ 214 } 215 216 switch (sop) { 217 RET_FCONV_CASE(i2f) 218 RET_FCONV_CASE(u2f) 219 RET_FCONV_CASE(f2f) 220 RET_ICONV_CASE(f2u) 221 RET_ICONV_CASE(f2i) 222 RET_ICONV_CASE(u2u) 223 RET_ICONV_CASE(i2i) 224 RET_FCONV_CASE(b2f) 225 RET_ICONV_CASE(b2i) 226 RET_BCONV_CASE(i2b) 227 RET_BCONV_CASE(f2b) 228 default: 229 unreachable("Invalid nir_search_op"); 230 } 231 232#undef RET_FCONV_CASE 233#undef RET_ICONV_CASE 234#undef RET_BCONV_CASE 235} 236 237static bool 238match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src, 239 unsigned num_components, const uint8_t *swizzle, 240 struct match_state *state) 241{ 242 uint8_t new_swizzle[NIR_MAX_VEC_COMPONENTS]; 243 244 /* Searching only works on SSA values because, if it's not SSA, we can't 245 * know if the value changed between one instance of that value in the 246 * expression and another. Also, the replace operation will place reads of 247 * that value right before the last instruction in the expression we're 248 * replacing so those reads will happen after the original reads and may 249 * not be valid if they're register reads. 250 */ 251 assert(instr->src[src].src.is_ssa); 252 253 /* If the source is an explicitly sized source, then we need to reset 254 * both the number of components and the swizzle. 255 */ 256 if (nir_op_infos[instr->op].input_sizes[src] != 0) { 257 num_components = nir_op_infos[instr->op].input_sizes[src]; 258 swizzle = identity_swizzle; 259 } 260 261 for (unsigned i = 0; i < num_components; ++i) 262 new_swizzle[i] = instr->src[src].swizzle[swizzle[i]]; 263 264 /* If the value has a specific bit size and it doesn't match, bail */ 265 if (value->bit_size > 0 && 266 nir_src_bit_size(instr->src[src].src) != value->bit_size) 267 return false; 268 269 switch (value->type) { 270 case nir_search_value_expression: 271 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu) 272 return false; 273 274 return match_expression(nir_search_value_as_expression(value), 275 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr), 276 num_components, new_swizzle, state); 277 278 case nir_search_value_variable: { 279 nir_search_variable *var = nir_search_value_as_variable(value); 280 assert(var->variable < NIR_SEARCH_MAX_VARIABLES); 281 282 if (state->variables_seen & (1 << var->variable)) { 283 if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa) 284 return false; 285 286 assert(!instr->src[src].abs && !instr->src[src].negate); 287 288 for (unsigned i = 0; i < num_components; ++i) { 289 if (state->variables[var->variable].swizzle[i] != new_swizzle[i]) 290 return false; 291 } 292 293 return true; 294 } else { 295 if (var->is_constant && 296 instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const) 297 return false; 298 299 if (var->cond && !var->cond(instr, src, num_components, new_swizzle)) 300 return false; 301 302 if (var->type != nir_type_invalid && 303 !src_is_type(instr->src[src].src, var->type)) 304 return false; 305 306 state->variables_seen |= (1 << var->variable); 307 state->variables[var->variable].src = instr->src[src].src; 308 state->variables[var->variable].abs = false; 309 state->variables[var->variable].negate = false; 310 311 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; ++i) { 312 if (i < num_components) 313 state->variables[var->variable].swizzle[i] = new_swizzle[i]; 314 else 315 state->variables[var->variable].swizzle[i] = 0; 316 } 317 318 return true; 319 } 320 } 321 322 case nir_search_value_constant: { 323 nir_search_constant *const_val = nir_search_value_as_constant(value); 324 325 if (!nir_src_is_const(instr->src[src].src)) 326 return false; 327 328 switch (const_val->type) { 329 case nir_type_float: 330 for (unsigned i = 0; i < num_components; ++i) { 331 double val = nir_src_comp_as_float(instr->src[src].src, 332 new_swizzle[i]); 333 if (val != const_val->data.d) 334 return false; 335 } 336 return true; 337 338 case nir_type_int: 339 case nir_type_uint: 340 case nir_type_bool: { 341 unsigned bit_size = nir_src_bit_size(instr->src[src].src); 342 uint64_t mask = bit_size == 64 ? UINT64_MAX : (1ull << bit_size) - 1; 343 for (unsigned i = 0; i < num_components; ++i) { 344 uint64_t val = nir_src_comp_as_uint(instr->src[src].src, 345 new_swizzle[i]); 346 if ((val & mask) != (const_val->data.u & mask)) 347 return false; 348 } 349 return true; 350 } 351 352 default: 353 unreachable("Invalid alu source type"); 354 } 355 } 356 357 default: 358 unreachable("Invalid search value type"); 359 } 360} 361 362static bool 363match_expression(const nir_search_expression *expr, nir_alu_instr *instr, 364 unsigned num_components, const uint8_t *swizzle, 365 struct match_state *state) 366{ 367 if (expr->cond && !expr->cond(instr)) 368 return false; 369 370 if (!nir_op_matches_search_op(instr->op, expr->opcode)) 371 return false; 372 373 assert(instr->dest.dest.is_ssa); 374 375 if (expr->value.bit_size > 0 && 376 instr->dest.dest.ssa.bit_size != expr->value.bit_size) 377 return false; 378 379 state->inexact_match = expr->inexact || state->inexact_match; 380 state->has_exact_alu = instr->exact || state->has_exact_alu; 381 if (state->inexact_match && state->has_exact_alu) 382 return false; 383 384 assert(!instr->dest.saturate); 385 assert(nir_op_infos[instr->op].num_inputs > 0); 386 387 /* If we have an explicitly sized destination, we can only handle the 388 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid 389 * expression, we don't have the information right now to propagate that 390 * swizzle through. We can only properly propagate swizzles if the 391 * instruction is vectorized. 392 */ 393 if (nir_op_infos[instr->op].output_size != 0) { 394 for (unsigned i = 0; i < num_components; i++) { 395 if (swizzle[i] != i) 396 return false; 397 } 398 } 399 400 /* If this is a commutative expression and it's one of the first few, look 401 * up its direction for the current search operation. We'll use that value 402 * to possibly flip the sources for the match. 403 */ 404 unsigned comm_op_flip = 405 (expr->comm_expr_idx >= 0 && 406 expr->comm_expr_idx < NIR_SEARCH_MAX_COMM_OPS) ? 407 ((state->comm_op_direction >> expr->comm_expr_idx) & 1) : 0; 408 409 bool matched = true; 410 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) { 411 if (!match_value(expr->srcs[i], instr, i ^ comm_op_flip, 412 num_components, swizzle, state)) { 413 matched = false; 414 break; 415 } 416 } 417 418 return matched; 419} 420 421static unsigned 422replace_bitsize(const nir_search_value *value, unsigned search_bitsize, 423 struct match_state *state) 424{ 425 if (value->bit_size > 0) 426 return value->bit_size; 427 if (value->bit_size < 0) 428 return nir_src_bit_size(state->variables[-value->bit_size - 1].src); 429 return search_bitsize; 430} 431 432static nir_alu_src 433construct_value(nir_builder *build, 434 const nir_search_value *value, 435 unsigned num_components, unsigned search_bitsize, 436 struct match_state *state, 437 nir_instr *instr) 438{ 439 switch (value->type) { 440 case nir_search_value_expression: { 441 const nir_search_expression *expr = nir_search_value_as_expression(value); 442 unsigned dst_bit_size = replace_bitsize(value, search_bitsize, state); 443 nir_op op = nir_op_for_search_op(expr->opcode, dst_bit_size); 444 445 if (nir_op_infos[op].output_size != 0) 446 num_components = nir_op_infos[op].output_size; 447 448 nir_alu_instr *alu = nir_alu_instr_create(build->shader, op); 449 nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components, 450 dst_bit_size, NULL); 451 alu->dest.write_mask = (1 << num_components) - 1; 452 alu->dest.saturate = false; 453 454 /* We have no way of knowing what values in a given search expression 455 * map to a particular replacement value. Therefore, if the 456 * expression we are replacing has any exact values, the entire 457 * replacement should be exact. 458 */ 459 alu->exact = state->has_exact_alu; 460 461 for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) { 462 /* If the source is an explicitly sized source, then we need to reset 463 * the number of components to match. 464 */ 465 if (nir_op_infos[alu->op].input_sizes[i] != 0) 466 num_components = nir_op_infos[alu->op].input_sizes[i]; 467 468 alu->src[i] = construct_value(build, expr->srcs[i], 469 num_components, search_bitsize, 470 state, instr); 471 } 472 473 nir_builder_instr_insert(build, &alu->instr); 474 475 nir_alu_src val; 476 val.src = nir_src_for_ssa(&alu->dest.dest.ssa); 477 val.negate = false; 478 val.abs = false, 479 memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle); 480 481 return val; 482 } 483 484 case nir_search_value_variable: { 485 const nir_search_variable *var = nir_search_value_as_variable(value); 486 assert(state->variables_seen & (1 << var->variable)); 487 488 nir_alu_src val = { NIR_SRC_INIT }; 489 nir_alu_src_copy(&val, &state->variables[var->variable], 490 (void *)build->shader); 491 assert(!var->is_constant); 492 493 return val; 494 } 495 496 case nir_search_value_constant: { 497 const nir_search_constant *c = nir_search_value_as_constant(value); 498 unsigned bit_size = replace_bitsize(value, search_bitsize, state); 499 500 nir_ssa_def *cval; 501 switch (c->type) { 502 case nir_type_float: 503 cval = nir_imm_floatN_t(build, c->data.d, bit_size); 504 break; 505 506 case nir_type_int: 507 case nir_type_uint: 508 cval = nir_imm_intN_t(build, c->data.i, bit_size); 509 break; 510 511 case nir_type_bool: 512 cval = nir_imm_boolN_t(build, c->data.u, bit_size); 513 break; 514 515 default: 516 unreachable("Invalid alu source type"); 517 } 518 519 nir_alu_src val; 520 val.src = nir_src_for_ssa(cval); 521 val.negate = false; 522 val.abs = false, 523 memset(val.swizzle, 0, sizeof val.swizzle); 524 525 return val; 526 } 527 528 default: 529 unreachable("Invalid search value type"); 530 } 531} 532 533MAYBE_UNUSED static void dump_value(const nir_search_value *val) 534{ 535 switch (val->type) { 536 case nir_search_value_constant: { 537 const nir_search_constant *sconst = nir_search_value_as_constant(val); 538 switch (sconst->type) { 539 case nir_type_float: 540 printf("%f", sconst->data.d); 541 break; 542 case nir_type_int: 543 printf("%"PRId64, sconst->data.i); 544 break; 545 case nir_type_uint: 546 printf("0x%"PRIx64, sconst->data.u); 547 break; 548 default: 549 unreachable("bad const type"); 550 } 551 break; 552 } 553 554 case nir_search_value_variable: { 555 const nir_search_variable *var = nir_search_value_as_variable(val); 556 if (var->is_constant) 557 printf("#"); 558 printf("%c", var->variable + 'a'); 559 break; 560 } 561 562 case nir_search_value_expression: { 563 const nir_search_expression *expr = nir_search_value_as_expression(val); 564 printf("("); 565 if (expr->inexact) 566 printf("~"); 567 switch (expr->opcode) { 568#define CASE(n) \ 569 case nir_search_op_##n: printf(#n); break; 570 CASE(f2b) 571 CASE(b2f) 572 CASE(b2i) 573 CASE(i2b) 574 CASE(i2i) 575 CASE(f2i) 576 CASE(i2f) 577#undef CASE 578 default: 579 printf("%s", nir_op_infos[expr->opcode].name); 580 } 581 582 unsigned num_srcs = 1; 583 if (expr->opcode <= nir_last_opcode) 584 num_srcs = nir_op_infos[expr->opcode].num_inputs; 585 586 for (unsigned i = 0; i < num_srcs; i++) { 587 printf(" "); 588 dump_value(expr->srcs[i]); 589 } 590 591 printf(")"); 592 break; 593 } 594 } 595 596 if (val->bit_size > 0) 597 printf("@%d", val->bit_size); 598} 599 600nir_ssa_def * 601nir_replace_instr(nir_builder *build, nir_alu_instr *instr, 602 const nir_search_expression *search, 603 const nir_search_value *replace) 604{ 605 uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 }; 606 607 for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i) 608 swizzle[i] = i; 609 610 assert(instr->dest.dest.is_ssa); 611 612 struct match_state state; 613 state.inexact_match = false; 614 state.has_exact_alu = false; 615 616 unsigned comm_expr_combinations = 617 1 << MIN2(search->comm_exprs, NIR_SEARCH_MAX_COMM_OPS); 618 619 bool found = false; 620 for (unsigned comb = 0; comb < comm_expr_combinations; comb++) { 621 /* The bitfield of directions is just the current iteration. Hooray for 622 * binary. 623 */ 624 state.comm_op_direction = comb; 625 state.variables_seen = 0; 626 627 if (match_expression(search, instr, 628 instr->dest.dest.ssa.num_components, 629 swizzle, &state)) { 630 found = true; 631 break; 632 } 633 } 634 if (!found) 635 return NULL; 636 637#if 0 638 printf("matched: "); 639 dump_value(&search->value); 640 printf(" -> "); 641 dump_value(replace); 642 printf(" ssa_%d\n", instr->dest.dest.ssa.index); 643#endif 644 645 build->cursor = nir_before_instr(&instr->instr); 646 647 nir_alu_src val = construct_value(build, replace, 648 instr->dest.dest.ssa.num_components, 649 instr->dest.dest.ssa.bit_size, 650 &state, &instr->instr); 651 652 /* Inserting a mov may be unnecessary. However, it's much easier to 653 * simply let copy propagation clean this up than to try to go through 654 * and rewrite swizzles ourselves. 655 */ 656 nir_ssa_def *ssa_val = 657 nir_imov_alu(build, val, instr->dest.dest.ssa.num_components); 658 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val)); 659 660 /* We know this one has no more uses because we just rewrote them all, 661 * so we can remove it. The rest of the matched expression, however, we 662 * don't know so much about. We'll just let dead code clean them up. 663 */ 664 nir_instr_remove(&instr->instr); 665 666 return ssa_val; 667} 668