brw_nir_trig_workarounds.c revision b8e80941
1#include "brw_nir.h"
2
3#include "nir.h"
4#include "nir_builder.h"
5#include "nir_search.h"
6#include "nir_search_helpers.h"
7
8#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
9#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS
10
11struct transform {
12   const nir_search_expression *search;
13   const nir_search_value *replace;
14   unsigned condition_offset;
15};
16
17struct per_op_table {
18   const uint16_t *filter;
19   unsigned num_filtered_states;
20   const uint16_t *table;
21};
22
23/* Note: these must match the start states created in
24 * TreeAutomaton._build_table()
25 */
26
27/* WILDCARD_STATE = 0 is set by zeroing the state array */
28static const uint16_t CONST_STATE = 1;
29
30#endif
31
32
33   static const nir_search_variable search0_0 = {
34   { nir_search_value_variable, -1 },
35   0, /* x */
36   false,
37   nir_type_invalid,
38   NULL,
39};
40static const nir_search_expression search0 = {
41   { nir_search_value_expression, -1 },
42   false,
43   -1, 0,
44   nir_op_fsin,
45   { &search0_0.value },
46   NULL,
47};
48
49   /* replace0_0_0 -> search0_0 in the cache */
50/* replace0_0 -> search0 in the cache */
51
52static const nir_search_constant replace0_1 = {
53   { nir_search_value_constant, -1 },
54   nir_type_float, { 0x3fefffc115df6556 /* 0.99997 */ },
55};
56static const nir_search_expression replace0 = {
57   { nir_search_value_expression, -1 },
58   false,
59   0, 1,
60   nir_op_fmul,
61   { &search0.value, &replace0_1.value },
62   NULL,
63};
64
65   /* search1_0 -> search0_0 in the cache */
66static const nir_search_expression search1 = {
67   { nir_search_value_expression, -1 },
68   false,
69   -1, 0,
70   nir_op_fcos,
71   { &search0_0.value },
72   NULL,
73};
74
75   /* replace1_0_0 -> search0_0 in the cache */
76/* replace1_0 -> search1 in the cache */
77
78/* replace1_1 -> replace0_1 in the cache */
79static const nir_search_expression replace1 = {
80   { nir_search_value_expression, -1 },
81   false,
82   0, 1,
83   nir_op_fmul,
84   { &search1.value, &replace0_1.value },
85   NULL,
86};
87
88
89static const struct transform brw_nir_apply_trig_workarounds_state2_xforms[] = {
90  { &search0, &replace0.value, 0 },
91};
92static const struct transform brw_nir_apply_trig_workarounds_state3_xforms[] = {
93  { &search1, &replace1.value, 0 },
94};
95
96static const struct per_op_table brw_nir_apply_trig_workarounds_table[nir_num_search_ops] = {
97   [nir_op_fsin] = {
98      .filter = (uint16_t []) {
99         0,
100         0,
101         0,
102         0,
103      },
104
105      .num_filtered_states = 1,
106      .table = (uint16_t []) {
107
108         2,
109      },
110   },
111   [nir_op_fcos] = {
112      .filter = (uint16_t []) {
113         0,
114         0,
115         0,
116         0,
117      },
118
119      .num_filtered_states = 1,
120      .table = (uint16_t []) {
121
122         3,
123      },
124   },
125};
126
127static void
128brw_nir_apply_trig_workarounds_pre_block(nir_block *block, uint16_t *states)
129{
130   nir_foreach_instr(instr, block) {
131      switch (instr->type) {
132      case nir_instr_type_alu: {
133         nir_alu_instr *alu = nir_instr_as_alu(instr);
134         nir_op op = alu->op;
135         uint16_t search_op = nir_search_op_for_nir_op(op);
136         const struct per_op_table *tbl = &brw_nir_apply_trig_workarounds_table[search_op];
137         if (tbl->num_filtered_states == 0)
138            continue;
139
140         /* Calculate the index into the transition table. Note the index
141          * calculated must match the iteration order of Python's
142          * itertools.product(), which was used to emit the transition
143          * table.
144          */
145         uint16_t index = 0;
146         for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
147            index *= tbl->num_filtered_states;
148            index += tbl->filter[states[alu->src[i].src.ssa->index]];
149         }
150         states[alu->dest.dest.ssa.index] = tbl->table[index];
151         break;
152      }
153
154      case nir_instr_type_load_const: {
155         nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
156         states[load_const->def.index] = CONST_STATE;
157         break;
158      }
159
160      default:
161         break;
162      }
163   }
164}
165
166static bool
167brw_nir_apply_trig_workarounds_block(nir_builder *build, nir_block *block,
168                   const uint16_t *states, const bool *condition_flags)
169{
170   bool progress = false;
171
172   nir_foreach_instr_reverse_safe(instr, block) {
173      if (instr->type != nir_instr_type_alu)
174         continue;
175
176      nir_alu_instr *alu = nir_instr_as_alu(instr);
177      if (!alu->dest.dest.is_ssa)
178         continue;
179
180      switch (states[alu->dest.dest.ssa.index]) {
181      case 0:
182         break;
183      case 1:
184         break;
185      case 2:
186         for (unsigned i = 0; i < ARRAY_SIZE(brw_nir_apply_trig_workarounds_state2_xforms); i++) {
187            const struct transform *xform = &brw_nir_apply_trig_workarounds_state2_xforms[i];
188            if (condition_flags[xform->condition_offset] &&
189                nir_replace_instr(build, alu, xform->search, xform->replace)) {
190               progress = true;
191               break;
192            }
193         }
194         break;
195      case 3:
196         for (unsigned i = 0; i < ARRAY_SIZE(brw_nir_apply_trig_workarounds_state3_xforms); i++) {
197            const struct transform *xform = &brw_nir_apply_trig_workarounds_state3_xforms[i];
198            if (condition_flags[xform->condition_offset] &&
199                nir_replace_instr(build, alu, xform->search, xform->replace)) {
200               progress = true;
201               break;
202            }
203         }
204         break;
205      default: assert(0);
206      }
207   }
208
209   return progress;
210}
211
212static bool
213brw_nir_apply_trig_workarounds_impl(nir_function_impl *impl, const bool *condition_flags)
214{
215   bool progress = false;
216
217   nir_builder build;
218   nir_builder_init(&build, impl);
219
220   /* Note: it's important here that we're allocating a zeroed array, since
221    * state 0 is the default state, which means we don't have to visit
222    * anything other than constants and ALU instructions.
223    */
224   uint16_t *states = calloc(impl->ssa_alloc, sizeof(*states));
225
226   nir_foreach_block(block, impl) {
227      brw_nir_apply_trig_workarounds_pre_block(block, states);
228   }
229
230   nir_foreach_block_reverse(block, impl) {
231      progress |= brw_nir_apply_trig_workarounds_block(&build, block, states, condition_flags);
232   }
233
234   free(states);
235
236   if (progress) {
237      nir_metadata_preserve(impl, nir_metadata_block_index |
238                                  nir_metadata_dominance);
239    } else {
240#ifndef NDEBUG
241      impl->valid_metadata &= ~nir_metadata_not_properly_reset;
242#endif
243    }
244
245   return progress;
246}
247
248
249bool
250brw_nir_apply_trig_workarounds(nir_shader *shader)
251{
252   bool progress = false;
253   bool condition_flags[1];
254   const nir_shader_compiler_options *options = shader->options;
255   const shader_info *info = &shader->info;
256   (void) options;
257   (void) info;
258
259   condition_flags[0] = true;
260
261   nir_foreach_function(function, shader) {
262      if (function->impl)
263         progress |= brw_nir_apply_trig_workarounds_impl(function->impl, condition_flags);
264   }
265
266   return progress;
267}
268
269