1b8e80941Smrg/*
2b8e80941Smrg * Copyright © 2014 Intel Corporation
3b8e80941Smrg *
4b8e80941Smrg * Permission is hereby granted, free of charge, to any person obtaining a
5b8e80941Smrg * copy of this software and associated documentation files (the "Software"),
6b8e80941Smrg * to deal in the Software without restriction, including without limitation
7b8e80941Smrg * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8b8e80941Smrg * and/or sell copies of the Software, and to permit persons to whom the
9b8e80941Smrg * Software is furnished to do so, subject to the following conditions:
10b8e80941Smrg *
11b8e80941Smrg * The above copyright notice and this permission notice (including the next
12b8e80941Smrg * paragraph) shall be included in all copies or substantial portions of the
13b8e80941Smrg * Software.
14b8e80941Smrg *
15b8e80941Smrg * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16b8e80941Smrg * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17b8e80941Smrg * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18b8e80941Smrg * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19b8e80941Smrg * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20b8e80941Smrg * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21b8e80941Smrg * IN THE SOFTWARE.
22b8e80941Smrg *
23b8e80941Smrg * Authors:
24b8e80941Smrg *    Jason Ekstrand (jason@jlekstrand.net)
25b8e80941Smrg *
26b8e80941Smrg */
27b8e80941Smrg
28b8e80941Smrg#include <inttypes.h>
29b8e80941Smrg#include "nir_search.h"
30b8e80941Smrg#include "nir_builder.h"
31b8e80941Smrg#include "util/half_float.h"
32b8e80941Smrg
/* Number of commutative search expressions whose operand order is explored.
 * nir_replace_instr() tries 2^min(comm_exprs, this) direction combinations,
 * and match_expression() never flips an expression whose comm_expr_idx is at
 * or beyond this limit.
 */
33b8e80941Smrg#define NIR_SEARCH_MAX_COMM_OPS 4
34b8e80941Smrg
/* Bookkeeping shared by match_expression()/match_value() during one match
 * attempt and later consumed by construct_value().
 */
35b8e80941Smrgstruct match_state {
   /* True once any search expression visited so far was flagged inexact. */
36b8e80941Smrg   bool inexact_match;
   /* True once any ALU instruction visited so far had its exact flag set;
    * construct_value() copies this onto every instruction it creates.
    */
37b8e80941Smrg   bool has_exact_alu;
   /* Bitfield: bit N set means the sources of the commutative expression
    * with comm_expr_idx == N are matched in swapped order.
    */
38b8e80941Smrg   uint8_t comm_op_direction;
   /* Bitmask of search-variable indices that are already bound. */
39b8e80941Smrg   unsigned variables_seen;
   /* For each bound variable, the SSA source and swizzle it was bound to. */
40b8e80941Smrg   nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
41b8e80941Smrg};
42b8e80941Smrg
/* Forward declaration: match_value() and match_expression() call each other
 * while walking a search pattern.
 */
43b8e80941Smrgstatic bool
44b8e80941Smrgmatch_expression(const nir_search_expression *expr, nir_alu_instr *instr,
45b8e80941Smrg                 unsigned num_components, const uint8_t *swizzle,
46b8e80941Smrg                 struct match_state *state);
47b8e80941Smrg
/* Swizzle that selects component i for channel i.  Used when a source is
 * explicitly sized and for freshly constructed replacement values.
 * NOTE(review): only four entries are spelled out; any further elements would
 * be zero-initialized -- confirm NIR_MAX_VEC_COMPONENTS is 4 here.
 */
48b8e80941Smrgstatic const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
49b8e80941Smrg
50b8e80941Smrg/**
51b8e80941Smrg * Check if a source produces a value of the given type.
52b8e80941Smrg *
53b8e80941Smrg * Used for satisfying 'a@type' constraints.
54b8e80941Smrg */
55b8e80941Smrgstatic bool
56b8e80941Smrgsrc_is_type(nir_src src, nir_alu_type type)
57b8e80941Smrg{
58b8e80941Smrg   assert(type != nir_type_invalid);
59b8e80941Smrg
60b8e80941Smrg   if (!src.is_ssa)
61b8e80941Smrg      return false;
62b8e80941Smrg
63b8e80941Smrg   if (src.ssa->parent_instr->type == nir_instr_type_alu) {
64b8e80941Smrg      nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
65b8e80941Smrg      nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
66b8e80941Smrg
67b8e80941Smrg      if (type == nir_type_bool) {
68b8e80941Smrg         switch (src_alu->op) {
69b8e80941Smrg         case nir_op_iand:
70b8e80941Smrg         case nir_op_ior:
71b8e80941Smrg         case nir_op_ixor:
72b8e80941Smrg            return src_is_type(src_alu->src[0].src, nir_type_bool) &&
73b8e80941Smrg                   src_is_type(src_alu->src[1].src, nir_type_bool);
74b8e80941Smrg         case nir_op_inot:
75b8e80941Smrg            return src_is_type(src_alu->src[0].src, nir_type_bool);
76b8e80941Smrg         default:
77b8e80941Smrg            break;
78b8e80941Smrg         }
79b8e80941Smrg      }
80b8e80941Smrg
81b8e80941Smrg      return nir_alu_type_get_base_type(output_type) == type;
82b8e80941Smrg   } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
83b8e80941Smrg      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);
84b8e80941Smrg
85b8e80941Smrg      if (type == nir_type_bool) {
86b8e80941Smrg         return intr->intrinsic == nir_intrinsic_load_front_face ||
87b8e80941Smrg                intr->intrinsic == nir_intrinsic_load_helper_invocation;
88b8e80941Smrg      }
89b8e80941Smrg   }
90b8e80941Smrg
91b8e80941Smrg   /* don't know */
92b8e80941Smrg   return false;
93b8e80941Smrg}
94b8e80941Smrg
95b8e80941Smrgstatic bool
96b8e80941Smrgnir_op_matches_search_op(nir_op nop, uint16_t sop)
97b8e80941Smrg{
98b8e80941Smrg   if (sop <= nir_last_opcode)
99b8e80941Smrg      return nop == sop;
100b8e80941Smrg
101b8e80941Smrg#define MATCH_FCONV_CASE(op) \
102b8e80941Smrg   case nir_search_op_##op: \
103b8e80941Smrg      return nop == nir_op_##op##16 || \
104b8e80941Smrg             nop == nir_op_##op##32 || \
105b8e80941Smrg             nop == nir_op_##op##64;
106b8e80941Smrg
107b8e80941Smrg#define MATCH_ICONV_CASE(op) \
108b8e80941Smrg   case nir_search_op_##op: \
109b8e80941Smrg      return nop == nir_op_##op##8 || \
110b8e80941Smrg             nop == nir_op_##op##16 || \
111b8e80941Smrg             nop == nir_op_##op##32 || \
112b8e80941Smrg             nop == nir_op_##op##64;
113b8e80941Smrg
114b8e80941Smrg#define MATCH_BCONV_CASE(op) \
115b8e80941Smrg   case nir_search_op_##op: \
116b8e80941Smrg      return nop == nir_op_##op##1 || \
117b8e80941Smrg             nop == nir_op_##op##32;
118b8e80941Smrg
119b8e80941Smrg   switch (sop) {
120b8e80941Smrg   MATCH_FCONV_CASE(i2f)
121b8e80941Smrg   MATCH_FCONV_CASE(u2f)
122b8e80941Smrg   MATCH_FCONV_CASE(f2f)
123b8e80941Smrg   MATCH_ICONV_CASE(f2u)
124b8e80941Smrg   MATCH_ICONV_CASE(f2i)
125b8e80941Smrg   MATCH_ICONV_CASE(u2u)
126b8e80941Smrg   MATCH_ICONV_CASE(i2i)
127b8e80941Smrg   MATCH_FCONV_CASE(b2f)
128b8e80941Smrg   MATCH_ICONV_CASE(b2i)
129b8e80941Smrg   MATCH_BCONV_CASE(i2b)
130b8e80941Smrg   MATCH_BCONV_CASE(f2b)
131b8e80941Smrg   default:
132b8e80941Smrg      unreachable("Invalid nir_search_op");
133b8e80941Smrg   }
134b8e80941Smrg
135b8e80941Smrg#undef MATCH_FCONV_CASE
136b8e80941Smrg#undef MATCH_ICONV_CASE
137b8e80941Smrg#undef MATCH_BCONV_CASE
138b8e80941Smrg}
139b8e80941Smrg
140b8e80941Smrguint16_t
141b8e80941Smrgnir_search_op_for_nir_op(nir_op nop)
142b8e80941Smrg{
143b8e80941Smrg#define MATCH_FCONV_CASE(op) \
144b8e80941Smrg   case nir_op_##op##16: \
145b8e80941Smrg   case nir_op_##op##32: \
146b8e80941Smrg   case nir_op_##op##64: \
147b8e80941Smrg      return nir_search_op_##op;
148b8e80941Smrg
149b8e80941Smrg#define MATCH_ICONV_CASE(op) \
150b8e80941Smrg   case nir_op_##op##8: \
151b8e80941Smrg   case nir_op_##op##16: \
152b8e80941Smrg   case nir_op_##op##32: \
153b8e80941Smrg   case nir_op_##op##64: \
154b8e80941Smrg      return nir_search_op_##op;
155b8e80941Smrg
156b8e80941Smrg#define MATCH_BCONV_CASE(op) \
157b8e80941Smrg   case nir_op_##op##1: \
158b8e80941Smrg   case nir_op_##op##32: \
159b8e80941Smrg      return nir_search_op_##op;
160b8e80941Smrg
161b8e80941Smrg
162b8e80941Smrg   switch (nop) {
163b8e80941Smrg   MATCH_FCONV_CASE(i2f)
164b8e80941Smrg   MATCH_FCONV_CASE(u2f)
165b8e80941Smrg   MATCH_FCONV_CASE(f2f)
166b8e80941Smrg   MATCH_ICONV_CASE(f2u)
167b8e80941Smrg   MATCH_ICONV_CASE(f2i)
168b8e80941Smrg   MATCH_ICONV_CASE(u2u)
169b8e80941Smrg   MATCH_ICONV_CASE(i2i)
170b8e80941Smrg   MATCH_FCONV_CASE(b2f)
171b8e80941Smrg   MATCH_ICONV_CASE(b2i)
172b8e80941Smrg   MATCH_BCONV_CASE(i2b)
173b8e80941Smrg   MATCH_BCONV_CASE(f2b)
174b8e80941Smrg   default:
175b8e80941Smrg      return nop;
176b8e80941Smrg   }
177b8e80941Smrg
178b8e80941Smrg#undef MATCH_FCONV_CASE
179b8e80941Smrg#undef MATCH_ICONV_CASE
180b8e80941Smrg#undef MATCH_BCONV_CASE
181b8e80941Smrg}
182b8e80941Smrg
183b8e80941Smrgstatic nir_op
184b8e80941Smrgnir_op_for_search_op(uint16_t sop, unsigned bit_size)
185b8e80941Smrg{
186b8e80941Smrg   if (sop <= nir_last_opcode)
187b8e80941Smrg      return sop;
188b8e80941Smrg
189b8e80941Smrg#define RET_FCONV_CASE(op) \
190b8e80941Smrg   case nir_search_op_##op: \
191b8e80941Smrg      switch (bit_size) { \
192b8e80941Smrg      case 16: return nir_op_##op##16; \
193b8e80941Smrg      case 32: return nir_op_##op##32; \
194b8e80941Smrg      case 64: return nir_op_##op##64; \
195b8e80941Smrg      default: unreachable("Invalid bit size"); \
196b8e80941Smrg      }
197b8e80941Smrg
198b8e80941Smrg#define RET_ICONV_CASE(op) \
199b8e80941Smrg   case nir_search_op_##op: \
200b8e80941Smrg      switch (bit_size) { \
201b8e80941Smrg      case 8:  return nir_op_##op##8; \
202b8e80941Smrg      case 16: return nir_op_##op##16; \
203b8e80941Smrg      case 32: return nir_op_##op##32; \
204b8e80941Smrg      case 64: return nir_op_##op##64; \
205b8e80941Smrg      default: unreachable("Invalid bit size"); \
206b8e80941Smrg      }
207b8e80941Smrg
208b8e80941Smrg#define RET_BCONV_CASE(op) \
209b8e80941Smrg   case nir_search_op_##op: \
210b8e80941Smrg      switch (bit_size) { \
211b8e80941Smrg      case 1: return nir_op_##op##1; \
212b8e80941Smrg      case 32: return nir_op_##op##32; \
213b8e80941Smrg      default: unreachable("Invalid bit size"); \
214b8e80941Smrg      }
215b8e80941Smrg
216b8e80941Smrg   switch (sop) {
217b8e80941Smrg   RET_FCONV_CASE(i2f)
218b8e80941Smrg   RET_FCONV_CASE(u2f)
219b8e80941Smrg   RET_FCONV_CASE(f2f)
220b8e80941Smrg   RET_ICONV_CASE(f2u)
221b8e80941Smrg   RET_ICONV_CASE(f2i)
222b8e80941Smrg   RET_ICONV_CASE(u2u)
223b8e80941Smrg   RET_ICONV_CASE(i2i)
224b8e80941Smrg   RET_FCONV_CASE(b2f)
225b8e80941Smrg   RET_ICONV_CASE(b2i)
226b8e80941Smrg   RET_BCONV_CASE(i2b)
227b8e80941Smrg   RET_BCONV_CASE(f2b)
228b8e80941Smrg   default:
229b8e80941Smrg      unreachable("Invalid nir_search_op");
230b8e80941Smrg   }
231b8e80941Smrg
232b8e80941Smrg#undef RET_FCONV_CASE
233b8e80941Smrg#undef RET_ICONV_CASE
234b8e80941Smrg#undef RET_BCONV_CASE
235b8e80941Smrg}
236b8e80941Smrg
237b8e80941Smrgstatic bool
238b8e80941Smrgmatch_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
239b8e80941Smrg            unsigned num_components, const uint8_t *swizzle,
240b8e80941Smrg            struct match_state *state)
241b8e80941Smrg{
242b8e80941Smrg   uint8_t new_swizzle[NIR_MAX_VEC_COMPONENTS];
243b8e80941Smrg
244b8e80941Smrg   /* Searching only works on SSA values because, if it's not SSA, we can't
245b8e80941Smrg    * know if the value changed between one instance of that value in the
246b8e80941Smrg    * expression and another.  Also, the replace operation will place reads of
247b8e80941Smrg    * that value right before the last instruction in the expression we're
248b8e80941Smrg    * replacing so those reads will happen after the original reads and may
249b8e80941Smrg    * not be valid if they're register reads.
250b8e80941Smrg    */
251b8e80941Smrg   assert(instr->src[src].src.is_ssa);
252b8e80941Smrg
253b8e80941Smrg   /* If the source is an explicitly sized source, then we need to reset
254b8e80941Smrg    * both the number of components and the swizzle.
255b8e80941Smrg    */
256b8e80941Smrg   if (nir_op_infos[instr->op].input_sizes[src] != 0) {
257b8e80941Smrg      num_components = nir_op_infos[instr->op].input_sizes[src];
258b8e80941Smrg      swizzle = identity_swizzle;
259b8e80941Smrg   }
260b8e80941Smrg
261b8e80941Smrg   for (unsigned i = 0; i < num_components; ++i)
262b8e80941Smrg      new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
263b8e80941Smrg
264b8e80941Smrg   /* If the value has a specific bit size and it doesn't match, bail */
265b8e80941Smrg   if (value->bit_size > 0 &&
266b8e80941Smrg       nir_src_bit_size(instr->src[src].src) != value->bit_size)
267b8e80941Smrg      return false;
268b8e80941Smrg
269b8e80941Smrg   switch (value->type) {
270b8e80941Smrg   case nir_search_value_expression:
271b8e80941Smrg      if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
272b8e80941Smrg         return false;
273b8e80941Smrg
274b8e80941Smrg      return match_expression(nir_search_value_as_expression(value),
275b8e80941Smrg                              nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
276b8e80941Smrg                              num_components, new_swizzle, state);
277b8e80941Smrg
278b8e80941Smrg   case nir_search_value_variable: {
279b8e80941Smrg      nir_search_variable *var = nir_search_value_as_variable(value);
280b8e80941Smrg      assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
281b8e80941Smrg
282b8e80941Smrg      if (state->variables_seen & (1 << var->variable)) {
283b8e80941Smrg         if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
284b8e80941Smrg            return false;
285b8e80941Smrg
286b8e80941Smrg         assert(!instr->src[src].abs && !instr->src[src].negate);
287b8e80941Smrg
288b8e80941Smrg         for (unsigned i = 0; i < num_components; ++i) {
289b8e80941Smrg            if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
290b8e80941Smrg               return false;
291b8e80941Smrg         }
292b8e80941Smrg
293b8e80941Smrg         return true;
294b8e80941Smrg      } else {
295b8e80941Smrg         if (var->is_constant &&
296b8e80941Smrg             instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
297b8e80941Smrg            return false;
298b8e80941Smrg
299b8e80941Smrg         if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
300b8e80941Smrg            return false;
301b8e80941Smrg
302b8e80941Smrg         if (var->type != nir_type_invalid &&
303b8e80941Smrg             !src_is_type(instr->src[src].src, var->type))
304b8e80941Smrg            return false;
305b8e80941Smrg
306b8e80941Smrg         state->variables_seen |= (1 << var->variable);
307b8e80941Smrg         state->variables[var->variable].src = instr->src[src].src;
308b8e80941Smrg         state->variables[var->variable].abs = false;
309b8e80941Smrg         state->variables[var->variable].negate = false;
310b8e80941Smrg
311b8e80941Smrg         for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; ++i) {
312b8e80941Smrg            if (i < num_components)
313b8e80941Smrg               state->variables[var->variable].swizzle[i] = new_swizzle[i];
314b8e80941Smrg            else
315b8e80941Smrg               state->variables[var->variable].swizzle[i] = 0;
316b8e80941Smrg         }
317b8e80941Smrg
318b8e80941Smrg         return true;
319b8e80941Smrg      }
320b8e80941Smrg   }
321b8e80941Smrg
322b8e80941Smrg   case nir_search_value_constant: {
323b8e80941Smrg      nir_search_constant *const_val = nir_search_value_as_constant(value);
324b8e80941Smrg
325b8e80941Smrg      if (!nir_src_is_const(instr->src[src].src))
326b8e80941Smrg         return false;
327b8e80941Smrg
328b8e80941Smrg      switch (const_val->type) {
329b8e80941Smrg      case nir_type_float:
330b8e80941Smrg         for (unsigned i = 0; i < num_components; ++i) {
331b8e80941Smrg            double val = nir_src_comp_as_float(instr->src[src].src,
332b8e80941Smrg                                               new_swizzle[i]);
333b8e80941Smrg            if (val != const_val->data.d)
334b8e80941Smrg               return false;
335b8e80941Smrg         }
336b8e80941Smrg         return true;
337b8e80941Smrg
338b8e80941Smrg      case nir_type_int:
339b8e80941Smrg      case nir_type_uint:
340b8e80941Smrg      case nir_type_bool: {
341b8e80941Smrg         unsigned bit_size = nir_src_bit_size(instr->src[src].src);
342b8e80941Smrg         uint64_t mask = bit_size == 64 ? UINT64_MAX : (1ull << bit_size) - 1;
343b8e80941Smrg         for (unsigned i = 0; i < num_components; ++i) {
344b8e80941Smrg            uint64_t val = nir_src_comp_as_uint(instr->src[src].src,
345b8e80941Smrg                                                new_swizzle[i]);
346b8e80941Smrg            if ((val & mask) != (const_val->data.u & mask))
347b8e80941Smrg               return false;
348b8e80941Smrg         }
349b8e80941Smrg         return true;
350b8e80941Smrg      }
351b8e80941Smrg
352b8e80941Smrg      default:
353b8e80941Smrg         unreachable("Invalid alu source type");
354b8e80941Smrg      }
355b8e80941Smrg   }
356b8e80941Smrg
357b8e80941Smrg   default:
358b8e80941Smrg      unreachable("Invalid search value type");
359b8e80941Smrg   }
360b8e80941Smrg}
361b8e80941Smrg
362b8e80941Smrgstatic bool
363b8e80941Smrgmatch_expression(const nir_search_expression *expr, nir_alu_instr *instr,
364b8e80941Smrg                 unsigned num_components, const uint8_t *swizzle,
365b8e80941Smrg                 struct match_state *state)
366b8e80941Smrg{
367b8e80941Smrg   if (expr->cond && !expr->cond(instr))
368b8e80941Smrg      return false;
369b8e80941Smrg
370b8e80941Smrg   if (!nir_op_matches_search_op(instr->op, expr->opcode))
371b8e80941Smrg      return false;
372b8e80941Smrg
373b8e80941Smrg   assert(instr->dest.dest.is_ssa);
374b8e80941Smrg
375b8e80941Smrg   if (expr->value.bit_size > 0 &&
376b8e80941Smrg       instr->dest.dest.ssa.bit_size != expr->value.bit_size)
377b8e80941Smrg      return false;
378b8e80941Smrg
379b8e80941Smrg   state->inexact_match = expr->inexact || state->inexact_match;
380b8e80941Smrg   state->has_exact_alu = instr->exact || state->has_exact_alu;
381b8e80941Smrg   if (state->inexact_match && state->has_exact_alu)
382b8e80941Smrg      return false;
383b8e80941Smrg
384b8e80941Smrg   assert(!instr->dest.saturate);
385b8e80941Smrg   assert(nir_op_infos[instr->op].num_inputs > 0);
386b8e80941Smrg
387b8e80941Smrg   /* If we have an explicitly sized destination, we can only handle the
388b8e80941Smrg    * identity swizzle.  While dot(vec3(a, b, c).zxy) is a valid
389b8e80941Smrg    * expression, we don't have the information right now to propagate that
390b8e80941Smrg    * swizzle through.  We can only properly propagate swizzles if the
391b8e80941Smrg    * instruction is vectorized.
392b8e80941Smrg    */
393b8e80941Smrg   if (nir_op_infos[instr->op].output_size != 0) {
394b8e80941Smrg      for (unsigned i = 0; i < num_components; i++) {
395b8e80941Smrg         if (swizzle[i] != i)
396b8e80941Smrg            return false;
397b8e80941Smrg      }
398b8e80941Smrg   }
399b8e80941Smrg
400b8e80941Smrg   /* If this is a commutative expression and it's one of the first few, look
401b8e80941Smrg    * up its direction for the current search operation.  We'll use that value
402b8e80941Smrg    * to possibly flip the sources for the match.
403b8e80941Smrg    */
404b8e80941Smrg   unsigned comm_op_flip =
405b8e80941Smrg      (expr->comm_expr_idx >= 0 &&
406b8e80941Smrg       expr->comm_expr_idx < NIR_SEARCH_MAX_COMM_OPS) ?
407b8e80941Smrg      ((state->comm_op_direction >> expr->comm_expr_idx) & 1) : 0;
408b8e80941Smrg
409b8e80941Smrg   bool matched = true;
410b8e80941Smrg   for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
411b8e80941Smrg      if (!match_value(expr->srcs[i], instr, i ^ comm_op_flip,
412b8e80941Smrg                       num_components, swizzle, state)) {
413b8e80941Smrg         matched = false;
414b8e80941Smrg         break;
415b8e80941Smrg      }
416b8e80941Smrg   }
417b8e80941Smrg
418b8e80941Smrg   return matched;
419b8e80941Smrg}
420b8e80941Smrg
421b8e80941Smrgstatic unsigned
422b8e80941Smrgreplace_bitsize(const nir_search_value *value, unsigned search_bitsize,
423b8e80941Smrg                struct match_state *state)
424b8e80941Smrg{
425b8e80941Smrg   if (value->bit_size > 0)
426b8e80941Smrg      return value->bit_size;
427b8e80941Smrg   if (value->bit_size < 0)
428b8e80941Smrg      return nir_src_bit_size(state->variables[-value->bit_size - 1].src);
429b8e80941Smrg   return search_bitsize;
430b8e80941Smrg}
431b8e80941Smrg
432b8e80941Smrgstatic nir_alu_src
433b8e80941Smrgconstruct_value(nir_builder *build,
434b8e80941Smrg                const nir_search_value *value,
435b8e80941Smrg                unsigned num_components, unsigned search_bitsize,
436b8e80941Smrg                struct match_state *state,
437b8e80941Smrg                nir_instr *instr)
438b8e80941Smrg{
439b8e80941Smrg   switch (value->type) {
440b8e80941Smrg   case nir_search_value_expression: {
441b8e80941Smrg      const nir_search_expression *expr = nir_search_value_as_expression(value);
442b8e80941Smrg      unsigned dst_bit_size = replace_bitsize(value, search_bitsize, state);
443b8e80941Smrg      nir_op op = nir_op_for_search_op(expr->opcode, dst_bit_size);
444b8e80941Smrg
445b8e80941Smrg      if (nir_op_infos[op].output_size != 0)
446b8e80941Smrg         num_components = nir_op_infos[op].output_size;
447b8e80941Smrg
448b8e80941Smrg      nir_alu_instr *alu = nir_alu_instr_create(build->shader, op);
449b8e80941Smrg      nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
450b8e80941Smrg                        dst_bit_size, NULL);
451b8e80941Smrg      alu->dest.write_mask = (1 << num_components) - 1;
452b8e80941Smrg      alu->dest.saturate = false;
453b8e80941Smrg
454b8e80941Smrg      /* We have no way of knowing what values in a given search expression
455b8e80941Smrg       * map to a particular replacement value.  Therefore, if the
456b8e80941Smrg       * expression we are replacing has any exact values, the entire
457b8e80941Smrg       * replacement should be exact.
458b8e80941Smrg       */
459b8e80941Smrg      alu->exact = state->has_exact_alu;
460b8e80941Smrg
461b8e80941Smrg      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
462b8e80941Smrg         /* If the source is an explicitly sized source, then we need to reset
463b8e80941Smrg          * the number of components to match.
464b8e80941Smrg          */
465b8e80941Smrg         if (nir_op_infos[alu->op].input_sizes[i] != 0)
466b8e80941Smrg            num_components = nir_op_infos[alu->op].input_sizes[i];
467b8e80941Smrg
468b8e80941Smrg         alu->src[i] = construct_value(build, expr->srcs[i],
469b8e80941Smrg                                       num_components, search_bitsize,
470b8e80941Smrg                                       state, instr);
471b8e80941Smrg      }
472b8e80941Smrg
473b8e80941Smrg      nir_builder_instr_insert(build, &alu->instr);
474b8e80941Smrg
475b8e80941Smrg      nir_alu_src val;
476b8e80941Smrg      val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
477b8e80941Smrg      val.negate = false;
478b8e80941Smrg      val.abs = false,
479b8e80941Smrg      memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
480b8e80941Smrg
481b8e80941Smrg      return val;
482b8e80941Smrg   }
483b8e80941Smrg
484b8e80941Smrg   case nir_search_value_variable: {
485b8e80941Smrg      const nir_search_variable *var = nir_search_value_as_variable(value);
486b8e80941Smrg      assert(state->variables_seen & (1 << var->variable));
487b8e80941Smrg
488b8e80941Smrg      nir_alu_src val = { NIR_SRC_INIT };
489b8e80941Smrg      nir_alu_src_copy(&val, &state->variables[var->variable],
490b8e80941Smrg                       (void *)build->shader);
491b8e80941Smrg      assert(!var->is_constant);
492b8e80941Smrg
493b8e80941Smrg      return val;
494b8e80941Smrg   }
495b8e80941Smrg
496b8e80941Smrg   case nir_search_value_constant: {
497b8e80941Smrg      const nir_search_constant *c = nir_search_value_as_constant(value);
498b8e80941Smrg      unsigned bit_size = replace_bitsize(value, search_bitsize, state);
499b8e80941Smrg
500b8e80941Smrg      nir_ssa_def *cval;
501b8e80941Smrg      switch (c->type) {
502b8e80941Smrg      case nir_type_float:
503b8e80941Smrg         cval = nir_imm_floatN_t(build, c->data.d, bit_size);
504b8e80941Smrg         break;
505b8e80941Smrg
506b8e80941Smrg      case nir_type_int:
507b8e80941Smrg      case nir_type_uint:
508b8e80941Smrg         cval = nir_imm_intN_t(build, c->data.i, bit_size);
509b8e80941Smrg         break;
510b8e80941Smrg
511b8e80941Smrg      case nir_type_bool:
512b8e80941Smrg         cval = nir_imm_boolN_t(build, c->data.u, bit_size);
513b8e80941Smrg         break;
514b8e80941Smrg
515b8e80941Smrg      default:
516b8e80941Smrg         unreachable("Invalid alu source type");
517b8e80941Smrg      }
518b8e80941Smrg
519b8e80941Smrg      nir_alu_src val;
520b8e80941Smrg      val.src = nir_src_for_ssa(cval);
521b8e80941Smrg      val.negate = false;
522b8e80941Smrg      val.abs = false,
523b8e80941Smrg      memset(val.swizzle, 0, sizeof val.swizzle);
524b8e80941Smrg
525b8e80941Smrg      return val;
526b8e80941Smrg   }
527b8e80941Smrg
528b8e80941Smrg   default:
529b8e80941Smrg      unreachable("Invalid search value type");
530b8e80941Smrg   }
531b8e80941Smrg}
532b8e80941Smrg
533b8e80941SmrgMAYBE_UNUSED static void dump_value(const nir_search_value *val)
534b8e80941Smrg{
535b8e80941Smrg   switch (val->type) {
536b8e80941Smrg   case nir_search_value_constant: {
537b8e80941Smrg      const nir_search_constant *sconst = nir_search_value_as_constant(val);
538b8e80941Smrg      switch (sconst->type) {
539b8e80941Smrg      case nir_type_float:
540b8e80941Smrg         printf("%f", sconst->data.d);
541b8e80941Smrg         break;
542b8e80941Smrg      case nir_type_int:
543b8e80941Smrg         printf("%"PRId64, sconst->data.i);
544b8e80941Smrg         break;
545b8e80941Smrg      case nir_type_uint:
546b8e80941Smrg         printf("0x%"PRIx64, sconst->data.u);
547b8e80941Smrg         break;
548b8e80941Smrg      default:
549b8e80941Smrg         unreachable("bad const type");
550b8e80941Smrg      }
551b8e80941Smrg      break;
552b8e80941Smrg   }
553b8e80941Smrg
554b8e80941Smrg   case nir_search_value_variable: {
555b8e80941Smrg      const nir_search_variable *var = nir_search_value_as_variable(val);
556b8e80941Smrg      if (var->is_constant)
557b8e80941Smrg         printf("#");
558b8e80941Smrg      printf("%c", var->variable + 'a');
559b8e80941Smrg      break;
560b8e80941Smrg   }
561b8e80941Smrg
562b8e80941Smrg   case nir_search_value_expression: {
563b8e80941Smrg      const nir_search_expression *expr = nir_search_value_as_expression(val);
564b8e80941Smrg      printf("(");
565b8e80941Smrg      if (expr->inexact)
566b8e80941Smrg         printf("~");
567b8e80941Smrg      switch (expr->opcode) {
568b8e80941Smrg#define CASE(n) \
569b8e80941Smrg      case nir_search_op_##n: printf(#n); break;
570b8e80941Smrg      CASE(f2b)
571b8e80941Smrg      CASE(b2f)
572b8e80941Smrg      CASE(b2i)
573b8e80941Smrg      CASE(i2b)
574b8e80941Smrg      CASE(i2i)
575b8e80941Smrg      CASE(f2i)
576b8e80941Smrg      CASE(i2f)
577b8e80941Smrg#undef CASE
578b8e80941Smrg      default:
579b8e80941Smrg         printf("%s", nir_op_infos[expr->opcode].name);
580b8e80941Smrg      }
581b8e80941Smrg
582b8e80941Smrg      unsigned num_srcs = 1;
583b8e80941Smrg      if (expr->opcode <= nir_last_opcode)
584b8e80941Smrg         num_srcs = nir_op_infos[expr->opcode].num_inputs;
585b8e80941Smrg
586b8e80941Smrg      for (unsigned i = 0; i < num_srcs; i++) {
587b8e80941Smrg         printf(" ");
588b8e80941Smrg         dump_value(expr->srcs[i]);
589b8e80941Smrg      }
590b8e80941Smrg
591b8e80941Smrg      printf(")");
592b8e80941Smrg      break;
593b8e80941Smrg   }
594b8e80941Smrg   }
595b8e80941Smrg
596b8e80941Smrg   if (val->bit_size > 0)
597b8e80941Smrg      printf("@%d", val->bit_size);
598b8e80941Smrg}
599b8e80941Smrg
600b8e80941Smrgnir_ssa_def *
601b8e80941Smrgnir_replace_instr(nir_builder *build, nir_alu_instr *instr,
602b8e80941Smrg                  const nir_search_expression *search,
603b8e80941Smrg                  const nir_search_value *replace)
604b8e80941Smrg{
605b8e80941Smrg   uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
606b8e80941Smrg
607b8e80941Smrg   for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
608b8e80941Smrg      swizzle[i] = i;
609b8e80941Smrg
610b8e80941Smrg   assert(instr->dest.dest.is_ssa);
611b8e80941Smrg
612b8e80941Smrg   struct match_state state;
613b8e80941Smrg   state.inexact_match = false;
614b8e80941Smrg   state.has_exact_alu = false;
615b8e80941Smrg
616b8e80941Smrg   unsigned comm_expr_combinations =
617b8e80941Smrg      1 << MIN2(search->comm_exprs, NIR_SEARCH_MAX_COMM_OPS);
618b8e80941Smrg
619b8e80941Smrg   bool found = false;
620b8e80941Smrg   for (unsigned comb = 0; comb < comm_expr_combinations; comb++) {
621b8e80941Smrg      /* The bitfield of directions is just the current iteration.  Hooray for
622b8e80941Smrg       * binary.
623b8e80941Smrg       */
624b8e80941Smrg      state.comm_op_direction = comb;
625b8e80941Smrg      state.variables_seen = 0;
626b8e80941Smrg
627b8e80941Smrg      if (match_expression(search, instr,
628b8e80941Smrg                           instr->dest.dest.ssa.num_components,
629b8e80941Smrg                           swizzle, &state)) {
630b8e80941Smrg         found = true;
631b8e80941Smrg         break;
632b8e80941Smrg      }
633b8e80941Smrg   }
634b8e80941Smrg   if (!found)
635b8e80941Smrg      return NULL;
636b8e80941Smrg
637b8e80941Smrg#if 0
638b8e80941Smrg   printf("matched: ");
639b8e80941Smrg   dump_value(&search->value);
640b8e80941Smrg   printf(" -> ");
641b8e80941Smrg   dump_value(replace);
642b8e80941Smrg   printf(" ssa_%d\n", instr->dest.dest.ssa.index);
643b8e80941Smrg#endif
644b8e80941Smrg
645b8e80941Smrg   build->cursor = nir_before_instr(&instr->instr);
646b8e80941Smrg
647b8e80941Smrg   nir_alu_src val = construct_value(build, replace,
648b8e80941Smrg                                     instr->dest.dest.ssa.num_components,
649b8e80941Smrg                                     instr->dest.dest.ssa.bit_size,
650b8e80941Smrg                                     &state, &instr->instr);
651b8e80941Smrg
652b8e80941Smrg   /* Inserting a mov may be unnecessary.  However, it's much easier to
653b8e80941Smrg    * simply let copy propagation clean this up than to try to go through
654b8e80941Smrg    * and rewrite swizzles ourselves.
655b8e80941Smrg    */
656b8e80941Smrg   nir_ssa_def *ssa_val =
657b8e80941Smrg      nir_imov_alu(build, val, instr->dest.dest.ssa.num_components);
658b8e80941Smrg   nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val));
659b8e80941Smrg
660b8e80941Smrg   /* We know this one has no more uses because we just rewrote them all,
661b8e80941Smrg    * so we can remove it.  The rest of the matched expression, however, we
662b8e80941Smrg    * don't know so much about.  We'll just let dead code clean them up.
663b8e80941Smrg    */
664b8e80941Smrg   nir_instr_remove(&instr->instr);
665b8e80941Smrg
666b8e80941Smrg   return ssa_val;
667b8e80941Smrg}
668