1b8e80941Smrg#include "brw_nir.h"
2b8e80941Smrg
3b8e80941Smrg#include "nir.h"
4b8e80941Smrg#include "nir_builder.h"
5b8e80941Smrg#include "nir_search.h"
6b8e80941Smrg#include "nir_search_helpers.h"
7b8e80941Smrg
8b8e80941Smrg#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
9b8e80941Smrg#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS
10b8e80941Smrg
/* One rewrite rule of the pass: a pattern to look for, the value that
 * replaces a match, and which entry of the pass's condition_flags array
 * must be true for the rule to be attempted.
 */
struct transform {
   const nir_search_expression *search; /* pattern passed to nir_replace_instr() */
   const nir_search_value *replace;     /* replacement passed to nir_replace_instr() */
   unsigned condition_offset;           /* index into condition_flags[] */
};
16b8e80941Smrg
/* Automaton data for a single search opcode.  The pre-pass skips any
 * opcode whose num_filtered_states is 0.
 */
struct per_op_table {
   const uint16_t *filter;       /* indexed by a source's state; yields its filtered state */
   unsigned num_filtered_states; /* radix used to combine the filtered states of all sources */
   const uint16_t *table;        /* indexed by the combined value; yields the result's state */
};
22b8e80941Smrg
23b8e80941Smrg/* Note: these must match the start states created in
24b8e80941Smrg * TreeAutomaton._build_table()
25b8e80941Smrg */
26b8e80941Smrg
/* WILDCARD_STATE = 0 is set by zeroing the state array */
/* State recorded for every SSA value defined by a load_const instruction. */
static const uint16_t CONST_STATE = 1;
29b8e80941Smrg
30b8e80941Smrg#endif
31b8e80941Smrg
32b8e80941Smrg
33b8e80941Smrg   static const nir_search_variable search0_0 = {
34b8e80941Smrg   { nir_search_value_variable, -1 },
35b8e80941Smrg   0, /* x */
36b8e80941Smrg   false,
37b8e80941Smrg   nir_type_invalid,
38b8e80941Smrg   NULL,
39b8e80941Smrg};
40b8e80941Smrgstatic const nir_search_expression search0 = {
41b8e80941Smrg   { nir_search_value_expression, -1 },
42b8e80941Smrg   false,
43b8e80941Smrg   -1, 0,
44b8e80941Smrg   nir_op_fsin,
45b8e80941Smrg   { &search0_0.value },
46b8e80941Smrg   NULL,
47b8e80941Smrg};
48b8e80941Smrg
49b8e80941Smrg   /* replace0_0_0 -> search0_0 in the cache */
50b8e80941Smrg/* replace0_0 -> search0 in the cache */
51b8e80941Smrg
52b8e80941Smrgstatic const nir_search_constant replace0_1 = {
53b8e80941Smrg   { nir_search_value_constant, -1 },
54b8e80941Smrg   nir_type_float, { 0x3fefffc115df6556 /* 0.99997 */ },
55b8e80941Smrg};
56b8e80941Smrgstatic const nir_search_expression replace0 = {
57b8e80941Smrg   { nir_search_value_expression, -1 },
58b8e80941Smrg   false,
59b8e80941Smrg   0, 1,
60b8e80941Smrg   nir_op_fmul,
61b8e80941Smrg   { &search0.value, &replace0_1.value },
62b8e80941Smrg   NULL,
63b8e80941Smrg};
64b8e80941Smrg
65b8e80941Smrg   /* search1_0 -> search0_0 in the cache */
66b8e80941Smrgstatic const nir_search_expression search1 = {
67b8e80941Smrg   { nir_search_value_expression, -1 },
68b8e80941Smrg   false,
69b8e80941Smrg   -1, 0,
70b8e80941Smrg   nir_op_fcos,
71b8e80941Smrg   { &search0_0.value },
72b8e80941Smrg   NULL,
73b8e80941Smrg};
74b8e80941Smrg
75b8e80941Smrg   /* replace1_0_0 -> search0_0 in the cache */
76b8e80941Smrg/* replace1_0 -> search1 in the cache */
77b8e80941Smrg
78b8e80941Smrg/* replace1_1 -> replace0_1 in the cache */
79b8e80941Smrgstatic const nir_search_expression replace1 = {
80b8e80941Smrg   { nir_search_value_expression, -1 },
81b8e80941Smrg   false,
82b8e80941Smrg   0, 1,
83b8e80941Smrg   nir_op_fmul,
84b8e80941Smrg   { &search1.value, &replace0_1.value },
85b8e80941Smrg   NULL,
86b8e80941Smrg};
87b8e80941Smrg
88b8e80941Smrg
89b8e80941Smrgstatic const struct transform brw_nir_apply_trig_workarounds_state2_xforms[] = {
90b8e80941Smrg  { &search0, &replace0.value, 0 },
91b8e80941Smrg};
92b8e80941Smrgstatic const struct transform brw_nir_apply_trig_workarounds_state3_xforms[] = {
93b8e80941Smrg  { &search1, &replace1.value, 0 },
94b8e80941Smrg};
95b8e80941Smrg
96b8e80941Smrgstatic const struct per_op_table brw_nir_apply_trig_workarounds_table[nir_num_search_ops] = {
97b8e80941Smrg   [nir_op_fsin] = {
98b8e80941Smrg      .filter = (uint16_t []) {
99b8e80941Smrg         0,
100b8e80941Smrg         0,
101b8e80941Smrg         0,
102b8e80941Smrg         0,
103b8e80941Smrg      },
104b8e80941Smrg
105b8e80941Smrg      .num_filtered_states = 1,
106b8e80941Smrg      .table = (uint16_t []) {
107b8e80941Smrg
108b8e80941Smrg         2,
109b8e80941Smrg      },
110b8e80941Smrg   },
111b8e80941Smrg   [nir_op_fcos] = {
112b8e80941Smrg      .filter = (uint16_t []) {
113b8e80941Smrg         0,
114b8e80941Smrg         0,
115b8e80941Smrg         0,
116b8e80941Smrg         0,
117b8e80941Smrg      },
118b8e80941Smrg
119b8e80941Smrg      .num_filtered_states = 1,
120b8e80941Smrg      .table = (uint16_t []) {
121b8e80941Smrg
122b8e80941Smrg         3,
123b8e80941Smrg      },
124b8e80941Smrg   },
125b8e80941Smrg};
126b8e80941Smrg
127b8e80941Smrgstatic void
128b8e80941Smrgbrw_nir_apply_trig_workarounds_pre_block(nir_block *block, uint16_t *states)
129b8e80941Smrg{
130b8e80941Smrg   nir_foreach_instr(instr, block) {
131b8e80941Smrg      switch (instr->type) {
132b8e80941Smrg      case nir_instr_type_alu: {
133b8e80941Smrg         nir_alu_instr *alu = nir_instr_as_alu(instr);
134b8e80941Smrg         nir_op op = alu->op;
135b8e80941Smrg         uint16_t search_op = nir_search_op_for_nir_op(op);
136b8e80941Smrg         const struct per_op_table *tbl = &brw_nir_apply_trig_workarounds_table[search_op];
137b8e80941Smrg         if (tbl->num_filtered_states == 0)
138b8e80941Smrg            continue;
139b8e80941Smrg
140b8e80941Smrg         /* Calculate the index into the transition table. Note the index
141b8e80941Smrg          * calculated must match the iteration order of Python's
142b8e80941Smrg          * itertools.product(), which was used to emit the transition
143b8e80941Smrg          * table.
144b8e80941Smrg          */
145b8e80941Smrg         uint16_t index = 0;
146b8e80941Smrg         for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
147b8e80941Smrg            index *= tbl->num_filtered_states;
148b8e80941Smrg            index += tbl->filter[states[alu->src[i].src.ssa->index]];
149b8e80941Smrg         }
150b8e80941Smrg         states[alu->dest.dest.ssa.index] = tbl->table[index];
151b8e80941Smrg         break;
152b8e80941Smrg      }
153b8e80941Smrg
154b8e80941Smrg      case nir_instr_type_load_const: {
155b8e80941Smrg         nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
156b8e80941Smrg         states[load_const->def.index] = CONST_STATE;
157b8e80941Smrg         break;
158b8e80941Smrg      }
159b8e80941Smrg
160b8e80941Smrg      default:
161b8e80941Smrg         break;
162b8e80941Smrg      }
163b8e80941Smrg   }
164b8e80941Smrg}
165b8e80941Smrg
166b8e80941Smrgstatic bool
167b8e80941Smrgbrw_nir_apply_trig_workarounds_block(nir_builder *build, nir_block *block,
168b8e80941Smrg                   const uint16_t *states, const bool *condition_flags)
169b8e80941Smrg{
170b8e80941Smrg   bool progress = false;
171b8e80941Smrg
172b8e80941Smrg   nir_foreach_instr_reverse_safe(instr, block) {
173b8e80941Smrg      if (instr->type != nir_instr_type_alu)
174b8e80941Smrg         continue;
175b8e80941Smrg
176b8e80941Smrg      nir_alu_instr *alu = nir_instr_as_alu(instr);
177b8e80941Smrg      if (!alu->dest.dest.is_ssa)
178b8e80941Smrg         continue;
179b8e80941Smrg
180b8e80941Smrg      switch (states[alu->dest.dest.ssa.index]) {
181b8e80941Smrg      case 0:
182b8e80941Smrg         break;
183b8e80941Smrg      case 1:
184b8e80941Smrg         break;
185b8e80941Smrg      case 2:
186b8e80941Smrg         for (unsigned i = 0; i < ARRAY_SIZE(brw_nir_apply_trig_workarounds_state2_xforms); i++) {
187b8e80941Smrg            const struct transform *xform = &brw_nir_apply_trig_workarounds_state2_xforms[i];
188b8e80941Smrg            if (condition_flags[xform->condition_offset] &&
189b8e80941Smrg                nir_replace_instr(build, alu, xform->search, xform->replace)) {
190b8e80941Smrg               progress = true;
191b8e80941Smrg               break;
192b8e80941Smrg            }
193b8e80941Smrg         }
194b8e80941Smrg         break;
195b8e80941Smrg      case 3:
196b8e80941Smrg         for (unsigned i = 0; i < ARRAY_SIZE(brw_nir_apply_trig_workarounds_state3_xforms); i++) {
197b8e80941Smrg            const struct transform *xform = &brw_nir_apply_trig_workarounds_state3_xforms[i];
198b8e80941Smrg            if (condition_flags[xform->condition_offset] &&
199b8e80941Smrg                nir_replace_instr(build, alu, xform->search, xform->replace)) {
200b8e80941Smrg               progress = true;
201b8e80941Smrg               break;
202b8e80941Smrg            }
203b8e80941Smrg         }
204b8e80941Smrg         break;
205b8e80941Smrg      default: assert(0);
206b8e80941Smrg      }
207b8e80941Smrg   }
208b8e80941Smrg
209b8e80941Smrg   return progress;
210b8e80941Smrg}
211b8e80941Smrg
212b8e80941Smrgstatic bool
213b8e80941Smrgbrw_nir_apply_trig_workarounds_impl(nir_function_impl *impl, const bool *condition_flags)
214b8e80941Smrg{
215b8e80941Smrg   bool progress = false;
216b8e80941Smrg
217b8e80941Smrg   nir_builder build;
218b8e80941Smrg   nir_builder_init(&build, impl);
219b8e80941Smrg
220b8e80941Smrg   /* Note: it's important here that we're allocating a zeroed array, since
221b8e80941Smrg    * state 0 is the default state, which means we don't have to visit
222b8e80941Smrg    * anything other than constants and ALU instructions.
223b8e80941Smrg    */
224b8e80941Smrg   uint16_t *states = calloc(impl->ssa_alloc, sizeof(*states));
225b8e80941Smrg
226b8e80941Smrg   nir_foreach_block(block, impl) {
227b8e80941Smrg      brw_nir_apply_trig_workarounds_pre_block(block, states);
228b8e80941Smrg   }
229b8e80941Smrg
230b8e80941Smrg   nir_foreach_block_reverse(block, impl) {
231b8e80941Smrg      progress |= brw_nir_apply_trig_workarounds_block(&build, block, states, condition_flags);
232b8e80941Smrg   }
233b8e80941Smrg
234b8e80941Smrg   free(states);
235b8e80941Smrg
236b8e80941Smrg   if (progress) {
237b8e80941Smrg      nir_metadata_preserve(impl, nir_metadata_block_index |
238b8e80941Smrg                                  nir_metadata_dominance);
239b8e80941Smrg    } else {
240b8e80941Smrg#ifndef NDEBUG
241b8e80941Smrg      impl->valid_metadata &= ~nir_metadata_not_properly_reset;
242b8e80941Smrg#endif
243b8e80941Smrg    }
244b8e80941Smrg
245b8e80941Smrg   return progress;
246b8e80941Smrg}
247b8e80941Smrg
248b8e80941Smrg
249b8e80941Smrgbool
250b8e80941Smrgbrw_nir_apply_trig_workarounds(nir_shader *shader)
251b8e80941Smrg{
252b8e80941Smrg   bool progress = false;
253b8e80941Smrg   bool condition_flags[1];
254b8e80941Smrg   const nir_shader_compiler_options *options = shader->options;
255b8e80941Smrg   const shader_info *info = &shader->info;
256b8e80941Smrg   (void) options;
257b8e80941Smrg   (void) info;
258b8e80941Smrg
259b8e80941Smrg   condition_flags[0] = true;
260b8e80941Smrg
261b8e80941Smrg   nir_foreach_function(function, shader) {
262b8e80941Smrg      if (function->impl)
263b8e80941Smrg         progress |= brw_nir_apply_trig_workarounds_impl(function->impl, condition_flags);
264b8e80941Smrg   }
265b8e80941Smrg
266b8e80941Smrg   return progress;
267b8e80941Smrg}
268b8e80941Smrg
269