1#include "bifrost_nir.h"
2
3#include "nir.h"
4#include "nir_builder.h"
5#include "nir_search.h"
6#include "nir_search_helpers.h"
7
8/* What follows is NIR algebraic transform code for the following 6
9 * transforms:
10 *    ('fmul', 'a', 2.0) => ('fadd', 'a', 'a')
11 *    ('fmin', ('fmax', 'a', -1.0), 1.0) => ('fsat_signed_mali', 'a')
12 *    ('fmax', ('fmin', 'a', 1.0), -1.0) => ('fsat_signed_mali', 'a')
13 *    ('fmax', 'a', 0.0) => ('fclamp_pos_mali', 'a')
14 *    ('fabs', ('fddx', 'a')) => ('fabs', ('fddx_must_abs_mali', 'a'))
15 *    ('fabs', ('fddy', 'b')) => ('fabs', ('fddy_must_abs_mali', 'b'))
16 */
17
18
19   static const nir_search_variable search0_0 = {
20   { nir_search_value_variable, -1 },
21   0, /* a */
22   false,
23   nir_type_invalid,
24   NULL,
25   {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
26};
27
28static const nir_search_constant search0_1 = {
29   { nir_search_value_constant, -1 },
30   nir_type_float, { 0x4000000000000000 /* 2.0 */ },
31};
32static const nir_search_expression search0 = {
33   { nir_search_value_expression, -1 },
34   false, false,
35   0, 1,
36   nir_op_fmul,
37   { &search0_0.value, &search0_1.value },
38   NULL,
39};
40
41   /* replace0_0 -> search0_0 in the cache */
42
43/* replace0_1 -> search0_0 in the cache */
44static const nir_search_expression replace0 = {
45   { nir_search_value_expression, -1 },
46   false, false,
47   -1, 0,
48   nir_op_fadd,
49   { &search0_0.value, &search0_0.value },
50   NULL,
51};
52
53   /* search1_0_0 -> search0_0 in the cache */
54
55static const nir_search_constant search1_0_1 = {
56   { nir_search_value_constant, -1 },
57   nir_type_float, { 0xbff0000000000000 /* -1.0 */ },
58};
59static const nir_search_expression search1_0 = {
60   { nir_search_value_expression, -1 },
61   false, false,
62   1, 1,
63   nir_op_fmax,
64   { &search0_0.value, &search1_0_1.value },
65   NULL,
66};
67
68static const nir_search_constant search1_1 = {
69   { nir_search_value_constant, -1 },
70   nir_type_float, { 0x3ff0000000000000 /* 1.0 */ },
71};
72static const nir_search_expression search1 = {
73   { nir_search_value_expression, -1 },
74   false, false,
75   0, 2,
76   nir_op_fmin,
77   { &search1_0.value, &search1_1.value },
78   NULL,
79};
80
81   /* replace1_0 -> search0_0 in the cache */
82static const nir_search_expression replace1 = {
83   { nir_search_value_expression, -1 },
84   false, false,
85   -1, 0,
86   nir_op_fsat_signed_mali,
87   { &search0_0.value },
88   NULL,
89};
90
91   /* search2_0_0 -> search0_0 in the cache */
92
93/* search2_0_1 -> search1_1 in the cache */
94static const nir_search_expression search2_0 = {
95   { nir_search_value_expression, -1 },
96   false, false,
97   1, 1,
98   nir_op_fmin,
99   { &search0_0.value, &search1_1.value },
100   NULL,
101};
102
103/* search2_1 -> search1_0_1 in the cache */
104static const nir_search_expression search2 = {
105   { nir_search_value_expression, -1 },
106   false, false,
107   0, 2,
108   nir_op_fmax,
109   { &search2_0.value, &search1_0_1.value },
110   NULL,
111};
112
113   /* replace2_0 -> search0_0 in the cache */
114/* replace2 -> replace1 in the cache */
115
116   /* search3_0 -> search0_0 in the cache */
117
118static const nir_search_constant search3_1 = {
119   { nir_search_value_constant, -1 },
120   nir_type_float, { 0x0 /* 0.0 */ },
121};
122static const nir_search_expression search3 = {
123   { nir_search_value_expression, -1 },
124   false, false,
125   0, 1,
126   nir_op_fmax,
127   { &search0_0.value, &search3_1.value },
128   NULL,
129};
130
131   /* replace3_0 -> search0_0 in the cache */
132static const nir_search_expression replace3 = {
133   { nir_search_value_expression, -1 },
134   false, false,
135   -1, 0,
136   nir_op_fclamp_pos_mali,
137   { &search0_0.value },
138   NULL,
139};
140
141   /* search4_0_0 -> search0_0 in the cache */
142static const nir_search_expression search4_0 = {
143   { nir_search_value_expression, -1 },
144   false, false,
145   -1, 0,
146   nir_op_fddx,
147   { &search0_0.value },
148   NULL,
149};
150static const nir_search_expression search4 = {
151   { nir_search_value_expression, -1 },
152   false, false,
153   -1, 0,
154   nir_op_fabs,
155   { &search4_0.value },
156   NULL,
157};
158
159   /* replace4_0_0 -> search0_0 in the cache */
160static const nir_search_expression replace4_0 = {
161   { nir_search_value_expression, -1 },
162   false, false,
163   -1, 0,
164   nir_op_fddx_must_abs_mali,
165   { &search0_0.value },
166   NULL,
167};
168static const nir_search_expression replace4 = {
169   { nir_search_value_expression, -1 },
170   false, false,
171   -1, 0,
172   nir_op_fabs,
173   { &replace4_0.value },
174   NULL,
175};
176
177   static const nir_search_variable search5_0_0 = {
178   { nir_search_value_variable, -1 },
179   0, /* b */
180   false,
181   nir_type_invalid,
182   NULL,
183   {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
184};
185static const nir_search_expression search5_0 = {
186   { nir_search_value_expression, -1 },
187   false, false,
188   -1, 0,
189   nir_op_fddy,
190   { &search5_0_0.value },
191   NULL,
192};
193static const nir_search_expression search5 = {
194   { nir_search_value_expression, -1 },
195   false, false,
196   -1, 0,
197   nir_op_fabs,
198   { &search5_0.value },
199   NULL,
200};
201
202   /* replace5_0_0 -> search5_0_0 in the cache */
203static const nir_search_expression replace5_0 = {
204   { nir_search_value_expression, -1 },
205   false, false,
206   -1, 0,
207   nir_op_fddy_must_abs_mali,
208   { &search5_0_0.value },
209   NULL,
210};
211static const nir_search_expression replace5 = {
212   { nir_search_value_expression, -1 },
213   false, false,
214   -1, 0,
215   nir_op_fabs,
216   { &replace5_0.value },
217   NULL,
218};
219
220
221static const struct transform bifrost_nir_lower_algebraic_late_state2_xforms[] = {
222  { &search0, &replace0.value, 0 },
223};
224static const struct transform bifrost_nir_lower_algebraic_late_state4_xforms[] = {
225  { &search3, &replace3.value, 0 },
226};
227static const struct transform bifrost_nir_lower_algebraic_late_state7_xforms[] = {
228  { &search2, &replace1.value, 0 },
229  { &search3, &replace3.value, 0 },
230};
231static const struct transform bifrost_nir_lower_algebraic_late_state8_xforms[] = {
232  { &search1, &replace1.value, 0 },
233};
234static const struct transform bifrost_nir_lower_algebraic_late_state9_xforms[] = {
235  { &search4, &replace4.value, 0 },
236};
237static const struct transform bifrost_nir_lower_algebraic_late_state10_xforms[] = {
238  { &search5, &replace5.value, 0 },
239};
240
241static const struct per_op_table bifrost_nir_lower_algebraic_late_table[nir_num_search_ops] = {
242   [nir_op_fmul] = {
243      .filter = (uint16_t []) {
244         0,
245         1,
246         0,
247         0,
248         0,
249         0,
250         0,
251         0,
252         0,
253         0,
254         0,
255      },
256
257      .num_filtered_states = 2,
258      .table = (uint16_t []) {
259
260         0,
261         2,
262         2,
263         2,
264      },
265   },
266   [nir_op_fmin] = {
267      .filter = (uint16_t []) {
268         0,
269         1,
270         0,
271         0,
272         2,
273         0,
274         0,
275         2,
276         0,
277         0,
278         0,
279      },
280
281      .num_filtered_states = 3,
282      .table = (uint16_t []) {
283
284         0,
285         3,
286         0,
287         3,
288         3,
289         8,
290         0,
291         8,
292         0,
293      },
294   },
295   [nir_op_fmax] = {
296      .filter = (uint16_t []) {
297         0,
298         1,
299         0,
300         2,
301         0,
302         0,
303         0,
304         0,
305         2,
306         0,
307         0,
308      },
309
310      .num_filtered_states = 3,
311      .table = (uint16_t []) {
312
313         0,
314         4,
315         0,
316         4,
317         4,
318         7,
319         0,
320         7,
321         0,
322      },
323   },
324   [nir_op_fabs] = {
325      .filter = (uint16_t []) {
326         0,
327         0,
328         0,
329         0,
330         0,
331         1,
332         2,
333         0,
334         0,
335         0,
336         0,
337      },
338
339      .num_filtered_states = 3,
340      .table = (uint16_t []) {
341
342         0,
343         9,
344         10,
345      },
346   },
347   [nir_op_fddx] = {
348      .filter = (uint16_t []) {
349         0,
350         0,
351         0,
352         0,
353         0,
354         0,
355         0,
356         0,
357         0,
358         0,
359         0,
360      },
361
362      .num_filtered_states = 1,
363      .table = (uint16_t []) {
364
365         5,
366      },
367   },
368   [nir_op_fddy] = {
369      .filter = (uint16_t []) {
370         0,
371         0,
372         0,
373         0,
374         0,
375         0,
376         0,
377         0,
378         0,
379         0,
380         0,
381      },
382
383      .num_filtered_states = 1,
384      .table = (uint16_t []) {
385
386         6,
387      },
388   },
389};
390
391const struct transform *bifrost_nir_lower_algebraic_late_transforms[] = {
392   NULL,
393   NULL,
394   bifrost_nir_lower_algebraic_late_state2_xforms,
395   NULL,
396   bifrost_nir_lower_algebraic_late_state4_xforms,
397   NULL,
398   NULL,
399   bifrost_nir_lower_algebraic_late_state7_xforms,
400   bifrost_nir_lower_algebraic_late_state8_xforms,
401   bifrost_nir_lower_algebraic_late_state9_xforms,
402   bifrost_nir_lower_algebraic_late_state10_xforms,
403};
404
405const uint16_t bifrost_nir_lower_algebraic_late_transform_counts[] = {
406   0,
407   0,
408   (uint16_t)ARRAY_SIZE(bifrost_nir_lower_algebraic_late_state2_xforms),
409   0,
410   (uint16_t)ARRAY_SIZE(bifrost_nir_lower_algebraic_late_state4_xforms),
411   0,
412   0,
413   (uint16_t)ARRAY_SIZE(bifrost_nir_lower_algebraic_late_state7_xforms),
414   (uint16_t)ARRAY_SIZE(bifrost_nir_lower_algebraic_late_state8_xforms),
415   (uint16_t)ARRAY_SIZE(bifrost_nir_lower_algebraic_late_state9_xforms),
416   (uint16_t)ARRAY_SIZE(bifrost_nir_lower_algebraic_late_state10_xforms),
417};
418
419bool
420bifrost_nir_lower_algebraic_late(nir_shader *shader)
421{
422   bool progress = false;
423   bool condition_flags[1];
424   const nir_shader_compiler_options *options = shader->options;
425   const shader_info *info = &shader->info;
426   (void) options;
427   (void) info;
428
429   condition_flags[0] = true;
430
431   nir_foreach_function(function, shader) {
432      if (function->impl) {
433         progress |= nir_algebraic_impl(function->impl, condition_flags,
434                                        bifrost_nir_lower_algebraic_late_transforms,
435                                        bifrost_nir_lower_algebraic_late_transform_counts,
436                                        bifrost_nir_lower_algebraic_late_table);
437      }
438   }
439
440   return progress;
441}
442
443