1/*
2 * Copyright (C) 2020 Google, Inc.
3 * Copyright (C) 2021 Advanced Micro Devices, Inc.
4 *
5 * Permission is hereby granted, free of charge, to any person obtaining a
6 * copy of this software and associated documentation files (the "Software"),
7 * to deal in the Software without restriction, including without limitation
8 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
9 * and/or sell copies of the Software, and to permit persons to whom the
10 * Software is furnished to do so, subject to the following conditions:
11 *
12 * The above copyright notice and this permission notice (including the next
13 * paragraph) shall be included in all copies or substantial portions of the
14 * Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#include "nir.h"
26#include "nir_builder.h"
27
28/**
29 * Return the intrinsic if it matches the mask in "modes", else return NULL.
30 */
31static nir_intrinsic_instr *
32get_io_intrinsic(nir_instr *instr, nir_variable_mode modes,
33                 nir_variable_mode *out_mode)
34{
35   if (instr->type != nir_instr_type_intrinsic)
36      return NULL;
37
38   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
39
40   switch (intr->intrinsic) {
41   case nir_intrinsic_load_input:
42   case nir_intrinsic_load_input_vertex:
43   case nir_intrinsic_load_interpolated_input:
44   case nir_intrinsic_load_per_vertex_input:
45      *out_mode = nir_var_shader_in;
46      return modes & nir_var_shader_in ? intr : NULL;
47   case nir_intrinsic_load_output:
48   case nir_intrinsic_load_per_vertex_output:
49   case nir_intrinsic_store_output:
50   case nir_intrinsic_store_per_vertex_output:
51      *out_mode = nir_var_shader_out;
52      return modes & nir_var_shader_out ? intr : NULL;
53   default:
54      return NULL;
55   }
56}
57
/**
 * Recompute the IO "base" indices from scratch to remove holes or to fix
 * incorrect base values due to changes in IO locations by using IO locations
 * to assign new bases. The mapping from locations to bases becomes
 * monotonically increasing.
 */
bool
nir_recompute_io_bases(nir_function_impl *impl, nir_variable_mode modes)
{
   /* One bit per varying slot, tracking which input/output locations are
    * referenced by any IO intrinsic in "impl".
    */
   BITSET_DECLARE(inputs, NUM_TOTAL_VARYING_SLOTS);
   BITSET_DECLARE(outputs, NUM_TOTAL_VARYING_SLOTS);
   BITSET_ZERO(inputs);
   BITSET_ZERO(outputs);

   /* Gather the bitmasks of used locations. */
   nir_foreach_block_safe (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         unsigned num_slots = sem.num_slots;
         /* Two mediump (16-bit) slots share one full slot, so they consume
          * half as many bases; high_16bits shifts the start within the pair
          * and the +1 rounds up.
          */
         if (sem.medium_precision)
            num_slots = (num_slots + sem.high_16bits + 1) / 2;

         if (mode == nir_var_shader_in) {
            for (unsigned i = 0; i < num_slots; i++)
               BITSET_SET(inputs, sem.location + i);
         } else if (!sem.dual_source_blend_index) {
            /* Dual-source blend outputs are intentionally excluded here;
             * they are assigned the base past all other outputs below.
             */
            for (unsigned i = 0; i < num_slots; i++)
               BITSET_SET(outputs, sem.location + i);
         }
      }
   }

   /* Renumber bases: each intrinsic's base becomes the number of used
    * locations below its own, i.e. a prefix sum over the bitset.
    */
   bool changed = false;

   nir_foreach_block_safe (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         unsigned num_slots = sem.num_slots;
         if (sem.medium_precision)
            num_slots = (num_slots + sem.high_16bits + 1) / 2;

         if (mode == nir_var_shader_in) {
            nir_intrinsic_set_base(intr,
                                   BITSET_PREFIX_SUM(inputs, sem.location));
         } else if (sem.dual_source_blend_index) {
            /* Dual-source blend output: base = total count of regular
             * outputs, i.e. one past the last renumbered output base.
             */
            nir_intrinsic_set_base(intr,
                                   BITSET_PREFIX_SUM(outputs, NUM_TOTAL_VARYING_SLOTS));
         } else {
            nir_intrinsic_set_base(intr,
                                   BITSET_PREFIX_SUM(outputs, sem.location));
         }
         /* NOTE(review): set for every matched intrinsic even when the new
          * base equals the old one, so metadata may be discarded needlessly.
          */
         changed = true;
      }
   }

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_dominance |
                                  nir_metadata_block_index);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return changed;
}
133
/**
 * Lower mediump inputs and/or outputs to 16 bits.
 *
 * \param modes            Whether to lower inputs, outputs, or both.
 * \param varying_mask     Determines which varyings to skip (VS inputs,
 *    FS outputs, and patch varyings ignore this mask).
 * \param use_16bit_slots  Remap lowered slots to VARYING_SLOT_VARn_16BIT.
 */
bool
nir_lower_mediump_io(nir_shader *nir, nir_variable_mode modes,
                     uint64_t varying_mask, bool use_16bit_slots)
{
   bool changed = false;
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
   assert(impl);

   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block_safe (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         nir_ssa_def *(*convert)(nir_builder *, nir_ssa_def *);
         /* VS inputs and FS outputs are not inter-stage varyings; everything
          * else handled here is.
          */
         bool is_varying = !(nir->info.stage == MESA_SHADER_VERTEX &&
                             mode == nir_var_shader_in) &&
                           !(nir->info.stage == MESA_SHADER_FRAGMENT &&
                             mode == nir_var_shader_out);

         /* Skip highp IO, and skip varyings (up to VAR31, i.e. those that
          * varying_mask can address) that the caller didn't select.
          */
         if (!sem.medium_precision ||
             (is_varying && sem.location <= VARYING_SLOT_VAR31 &&
              !(varying_mask & BITFIELD64_BIT(sem.location))))
            continue; /* can't lower */

         if (nir_intrinsic_has_src_type(intr)) {
            /* Stores. */
            nir_alu_type type = nir_intrinsic_src_type(intr);

            switch (type) {
            case nir_type_float32:
               convert = nir_f2fmp;
               break;
            case nir_type_int32:
            case nir_type_uint32:
               /* Signedness doesn't matter for a 32->16 truncation. */
               convert = nir_i2imp;
               break;
            default:
               continue; /* already lowered? */
            }

            /* Convert the 32-bit store into a 16-bit store.
             * The conversion is inserted right before the store, and the
             * stored value is replaced by its 16-bit form.
             */
            b.cursor = nir_before_instr(&intr->instr);
            nir_instr_rewrite_src_ssa(&intr->instr, &intr->src[0],
                                      convert(&b, intr->src[0].ssa));
            /* nir_alu_type encodes the bit size in its low bits, so this
             * swaps the size from 32 to 16 while keeping the base type.
             */
            nir_intrinsic_set_src_type(intr, (type & ~32) | 16);
         } else {
            /* Loads. */
            nir_alu_type type = nir_intrinsic_dest_type(intr);

            switch (type) {
            case nir_type_float32:
               convert = nir_f2f32;
               break;
            case nir_type_int32:
               convert = nir_i2i32;
               break;
            case nir_type_uint32:
               convert = nir_u2u32;
               break;
            default:
               continue; /* already lowered? */
            }

            /* Convert the 32-bit load into a 16-bit load, then widen the
             * result back to 32 bits after the load so all existing users
             * keep seeing a 32-bit value.
             */
            b.cursor = nir_after_instr(&intr->instr);
            intr->dest.ssa.bit_size = 16;
            nir_intrinsic_set_dest_type(intr, (type & ~32) | 16);
            nir_ssa_def *dst = convert(&b, &intr->dest.ssa);
            /* "after dst" so the conversion itself keeps reading the load. */
            nir_ssa_def_rewrite_uses_after(&intr->dest.ssa, dst,
                                           dst->parent_instr);
         }

         /* Optionally pack two 16-bit varyings into one *_16BIT slot pair:
          * VARn maps to VAR(n/2)_16BIT with high_16bits = n % 2.
          */
         if (use_16bit_slots && is_varying &&
             sem.location >= VARYING_SLOT_VAR0 &&
             sem.location <= VARYING_SLOT_VAR31) {
            unsigned index = sem.location - VARYING_SLOT_VAR0;

            sem.location = VARYING_SLOT_VAR0_16BIT + index / 2;
            sem.high_16bits = index % 2;
            nir_intrinsic_set_io_semantics(intr, sem);
         }
         changed = true;
      }
   }

   /* Slot remapping invalidates the driver-facing "base" indices. */
   if (changed && use_16bit_slots)
      nir_recompute_io_bases(impl, modes);

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_dominance |
                                  nir_metadata_block_index);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return changed;
}
245
246/**
247 * Set the mediump precision bit for those shader inputs and outputs that are
248 * set in the "modes" mask. Non-generic varyings (that GLES3 doesn't have)
249 * are ignored. The "types" mask can be (nir_type_float | nir_type_int), etc.
250 */
251bool
252nir_force_mediump_io(nir_shader *nir, nir_variable_mode modes,
253                     nir_alu_type types)
254{
255   bool changed = false;
256   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
257   assert(impl);
258
259   nir_builder b;
260   nir_builder_init(&b, impl);
261
262   nir_foreach_block_safe (block, impl) {
263      nir_foreach_instr_safe (instr, block) {
264         nir_variable_mode mode;
265         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
266         if (!intr)
267            continue;
268
269         nir_alu_type type;
270         if (nir_intrinsic_has_src_type(intr))
271            type = nir_intrinsic_src_type(intr);
272         else
273            type = nir_intrinsic_dest_type(intr);
274         if (!(type & types))
275            continue;
276
277         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
278
279         if (nir->info.stage == MESA_SHADER_FRAGMENT &&
280             mode == nir_var_shader_out) {
281            /* Only accept FS outputs. */
282            if (sem.location < FRAG_RESULT_DATA0 &&
283                sem.location != FRAG_RESULT_COLOR)
284               continue;
285         } else if (nir->info.stage == MESA_SHADER_VERTEX &&
286                    mode == nir_var_shader_in) {
287            /* Accept all VS inputs. */
288         } else {
289            /* Only accept generic varyings. */
290            if (sem.location < VARYING_SLOT_VAR0 ||
291                sem.location > VARYING_SLOT_VAR31)
292            continue;
293         }
294
295         sem.medium_precision = 1;
296         nir_intrinsic_set_io_semantics(intr, sem);
297         changed = true;
298      }
299   }
300
301   if (changed) {
302      nir_metadata_preserve(impl, nir_metadata_dominance |
303                                  nir_metadata_block_index);
304   } else {
305      nir_metadata_preserve(impl, nir_metadata_all);
306   }
307
308   return changed;
309}
310
311/**
312 * Remap 16-bit varying slots to the original 32-bit varying slots.
313 * This only changes IO semantics and bases.
314 */
315bool
316nir_unpack_16bit_varying_slots(nir_shader *nir, nir_variable_mode modes)
317{
318   bool changed = false;
319   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
320   assert(impl);
321
322   nir_foreach_block_safe (block, impl) {
323      nir_foreach_instr_safe (instr, block) {
324         nir_variable_mode mode;
325         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
326         if (!intr)
327            continue;
328
329         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
330
331         if (sem.location < VARYING_SLOT_VAR0_16BIT ||
332             sem.location > VARYING_SLOT_VAR15_16BIT)
333            continue;
334
335         sem.location = VARYING_SLOT_VAR0 +
336                        (sem.location - VARYING_SLOT_VAR0_16BIT) * 2 +
337                        sem.high_16bits;
338         sem.high_16bits = 0;
339         nir_intrinsic_set_io_semantics(intr, sem);
340         changed = true;
341      }
342   }
343
344   if (changed)
345      nir_recompute_io_bases(impl, modes);
346
347   if (changed) {
348      nir_metadata_preserve(impl, nir_metadata_dominance |
349                                  nir_metadata_block_index);
350   } else {
351      nir_metadata_preserve(impl, nir_metadata_all);
352   }
353
354   return changed;
355}
356
357static bool
358is_n_to_m_conversion(nir_instr *instr, unsigned n, nir_op m)
359{
360   if (instr->type != nir_instr_type_alu)
361      return false;
362
363   nir_alu_instr *alu = nir_instr_as_alu(instr);
364   return alu->op == m && alu->src[0].src.ssa->bit_size == n;
365}
366
367static bool
368is_f16_to_f32_conversion(nir_instr *instr)
369{
370   return is_n_to_m_conversion(instr, 16, nir_op_f2f32);
371}
372
373static bool
374is_f32_to_f16_conversion(nir_instr *instr)
375{
376   return is_n_to_m_conversion(instr, 32, nir_op_f2f16) ||
377          is_n_to_m_conversion(instr, 32, nir_op_f2f16_rtne) ||
378          is_n_to_m_conversion(instr, 32, nir_op_f2fmp);
379}
380
381static bool
382is_i16_to_i32_conversion(nir_instr *instr)
383{
384   return is_n_to_m_conversion(instr, 16, nir_op_i2i32);
385}
386
387static bool
388is_u16_to_u32_conversion(nir_instr *instr)
389{
390   return is_n_to_m_conversion(instr, 16, nir_op_u2u32);
391}
392
393static bool
394is_i32_to_i16_conversion(nir_instr *instr)
395{
396   return is_n_to_m_conversion(instr, 32, nir_op_i2i16);
397}
398
399static void
400replace_with_mov(nir_builder *b, nir_instr *instr, nir_src *src,
401                 nir_alu_instr *alu)
402{
403   nir_ssa_def *mov = nir_mov_alu(b, alu->src[0],
404                                  nir_dest_num_components(alu->dest.dest));
405   assert(!alu->dest.saturate);
406   nir_instr_rewrite_src_ssa(instr, src, mov);
407}
408
/**
 * If texture source operands use f16->f32 conversions or return values are
 * followed by f16->f32 or f32->f16, remove those conversions. This benefits
 * drivers that have texture opcodes that can accept and return 16-bit types.
 *
 * "tex_src_types" is a mask of nir_tex_src_* operands that should be handled.
 * It's always done for the destination.
 *
 * This should be run after late algebraic optimizations.
 * Copy propagation and DCE should be run after this.
 */
bool
nir_fold_16bit_sampler_conversions(nir_shader *nir,
                                   unsigned tex_src_types)
{
   bool changed = false;
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
   assert(impl);

   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block_safe (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         if (instr->type != nir_instr_type_tex)
            continue;

         nir_tex_instr *tex = nir_instr_as_tex(instr);
         nir_instr *src;
         nir_alu_instr *src_alu;

         /* Skip because AMD doesn't support 16-bit types with these. */
         if ((tex->op == nir_texop_txs ||
              tex->op == nir_texop_query_levels) ||
             tex->sampler_dim == GLSL_SAMPLER_DIM_CUBE)
            continue;

         /* Optimize source operands. */
         for (unsigned i = 0; i < tex->num_srcs; i++) {
            /* Filter out sources that should be ignored. */
            if (!(BITFIELD_BIT(tex->src[i].src_type) & tex_src_types))
               continue;

            src = tex->src[i].src.ssa->parent_instr;
            if (src->type != nir_instr_type_alu)
               continue;

            src_alu = nir_instr_as_alu(src);
            b.cursor = nir_before_instr(src);

            if (src_alu->op == nir_op_mov) {
               assert(!"The IR shouldn't contain any movs to make this pass"
                       " effective.");
               continue;
            }

            /* Handle vector sources that are made of scalar instructions. */
            if (nir_op_is_vec(src_alu->op)) {
               /* See if the vector is made of f16->f32 opcodes. */
               unsigned num = nir_dest_num_components(src_alu->dest.dest);
               bool is_f16_to_f32 = true;
               bool is_u16_to_u32 = true;

               for (unsigned comp = 0; comp < num; comp++) {
                  nir_instr *instr = src_alu->src[comp].src.ssa->parent_instr;
                  is_f16_to_f32 &= is_f16_to_f32_conversion(instr);
                  /* Zero-extension (u16) and sign-extension (i16) have
                   * the same behavior here - txf returns 0 if bit 15 is set
                   * because it's out of bounds and the higher bits don't
                   * matter.
                   */
                  is_u16_to_u32 &= is_u16_to_u32_conversion(instr) ||
                                   is_i16_to_i32_conversion(instr);
               }

               /* Every component must be a widening conversion of the same
                * kind, else the vector can't be folded to 16 bits.
                */
               if (!is_f16_to_f32 && !is_u16_to_u32)
                  continue;

               /* Clone the vec so other users of the original keep their
                * 32-bit value; only the texture source switches to the
                * 16-bit clone.
                */
               nir_alu_instr *new_vec = nir_alu_instr_clone(nir, src_alu);
               nir_instr_insert_after(&src_alu->instr, &new_vec->instr);

               /* Replace conversions with mov. */
               for (unsigned comp = 0; comp < num; comp++) {
                  nir_instr *instr = new_vec->src[comp].src.ssa->parent_instr;
                  replace_with_mov(&b, &new_vec->instr,
                                   &new_vec->src[comp].src,
                                   nir_instr_as_alu(instr));
               }

               /* The clone now reads 16-bit values directly. */
               new_vec->dest.dest.ssa.bit_size =
                  new_vec->src[0].src.ssa->bit_size;
               nir_instr_rewrite_src_ssa(&tex->instr, &tex->src[i].src,
                                         &new_vec->dest.dest.ssa);
               changed = true;
            } else if (is_f16_to_f32_conversion(&src_alu->instr) ||
                       is_u16_to_u32_conversion(&src_alu->instr) ||
                       is_i16_to_i32_conversion(&src_alu->instr)) {
               /* Handle scalar sources. */
               replace_with_mov(&b, &tex->instr, &tex->src[i].src, src_alu);
               changed = true;
            }
         }

         /* Optimize the destination: fold conversions into the texture
          * result only if EVERY use is the same kind of conversion.
          */
         bool is_f16_to_f32 = true;
         bool is_f32_to_f16 = true;
         bool is_i16_to_i32 = true;
         bool is_i32_to_i16 = true; /* same behavior for int and uint */
         bool is_u16_to_u32 = true;

         nir_foreach_use(use, &tex->dest.ssa) {
            is_f16_to_f32 &= is_f16_to_f32_conversion(use->parent_instr);
            is_f32_to_f16 &= is_f32_to_f16_conversion(use->parent_instr);
            is_i16_to_i32 &= is_i16_to_i32_conversion(use->parent_instr);
            is_i32_to_i16 &= is_i32_to_i16_conversion(use->parent_instr);
            is_u16_to_u32 &= is_u16_to_u32_conversion(use->parent_instr);
         }

         /* NOTE(review): if the destination has no uses, all flags stay
          * true and "changed" is set without modifying anything, which
          * only discards metadata needlessly - confirm intended.
          */
         if (is_f16_to_f32 || is_f32_to_f16 || is_i16_to_i32 ||
             is_i32_to_i16 || is_u16_to_u32) {
            /* All uses are the same conversions. Replace them with mov. */
            nir_foreach_use(use, &tex->dest.ssa) {
               nir_alu_instr *conv = nir_instr_as_alu(use->parent_instr);
               conv->op = nir_op_mov;
               /* The texture now returns the conversions' bit size, and
                * dest_type keeps its base type with the size bits swapped
                * (sizes are encoded in the low bits of nir_alu_type).
                */
               tex->dest.ssa.bit_size = conv->dest.dest.ssa.bit_size;
               tex->dest_type = (tex->dest_type & (~16 & ~32 & ~64)) |
                                conv->dest.dest.ssa.bit_size;
            }
            changed = true;
         }
      }
   }

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_dominance |
                                  nir_metadata_block_index);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return changed;
}
551
/**
 * Fix types of source operands of texture opcodes according to
 * the constraints by inserting the appropriate conversion opcodes.
 *
 * For example, if the type of derivatives must be equal to texture
 * coordinates and the type of the texture bias must be 32-bit, there
 * will be 2 constraints describing that.
 */
bool
nir_legalize_16bit_sampler_srcs(nir_shader *nir,
                                nir_tex_src_type_constraints constraints)
{
   bool changed = false;
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
   assert(impl);

   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block_safe (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         if (instr->type != nir_instr_type_tex)
            continue;

         nir_tex_instr *tex = nir_instr_as_tex(instr);
         /* map[src_type] = index into tex->src, or -1 if absent. */
         int8_t map[nir_num_tex_src_types];
         memset(map, -1, sizeof(map));

         /* Create a mapping from src_type to src[i]. */
         for (unsigned i = 0; i < tex->num_srcs; i++)
            map[tex->src[i].src_type] = i;

         /* Legalize src types. */
         for (unsigned i = 0; i < tex->num_srcs; i++) {
            nir_tex_src_type_constraint c = constraints[tex->src[i].src_type];

            if (!c.legalize_type)
               continue;

            /* Determine the required bit size for the src: either a fixed
             * size, or "match the bit size of another source".
             */
            unsigned bit_size;
            if (c.bit_size) {
               bit_size = c.bit_size;
            } else {
               if (map[c.match_src] == -1)
                  continue; /* e.g. txs */

               bit_size = tex->src[map[c.match_src]].src.ssa->bit_size;
            }

            /* Check if the type is legal. */
            if (bit_size == tex->src[i].src.ssa->bit_size)
               continue;

            /* Fix the bit size. Texel offsets are signed; txf-style fetch
             * coordinates are unsigned; everything else is float.
             */
            bool is_sint = i == nir_tex_src_offset;
            bool is_uint = !is_sint &&
                           (tex->op == nir_texop_txf ||
                            tex->op == nir_texop_txf_ms ||
                            tex->op == nir_texop_txs ||
                            tex->op == nir_texop_samples_identical);
            nir_ssa_def *(*convert)(nir_builder *, nir_ssa_def *);

            switch (bit_size) {
            case 16:
               convert = is_sint ? nir_i2i16 :
                         is_uint ? nir_u2u16 : nir_f2f16;
               break;
            case 32:
               convert = is_sint ? nir_i2i32 :
                         is_uint ? nir_u2u32 : nir_f2f32;
               break;
            default:
               assert(!"unexpected bit size");
               continue;
            }

            /* Insert the conversion right before the texture instruction
             * and make the source read the converted value.
             */
            b.cursor = nir_before_instr(&tex->instr);
            nir_ssa_def *conv =
               convert(&b, nir_ssa_for_src(&b, tex->src[i].src,
                                           tex->src[i].src.ssa->num_components));
            nir_instr_rewrite_src_ssa(&tex->instr, &tex->src[i].src, conv);
            changed = true;
         }
      }
   }

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_dominance |
                                  nir_metadata_block_index);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return changed;
}
648