1/*
2 * Copyright © 2020 Google LLC
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24/**
25 * @file
26 *
27 * Trims off the unused trailing components of SSA defs.
28 *
29 * Due to various optimization passes (or frontend implementations,
30 * particularly prog_to_nir), we may have instructions generating vectors
31 * whose components don't get read by any instruction. As it can be tricky
32 * to eliminate unused low components or channels in the middle of a writemask
33 * (you might need to increment some offset from a load_uniform, for example),
34 * it is trivial to just drop the trailing components. For vector ALU only used
35 * by ALU, this pass eliminates arbitrary channels and reswizzles the uses.
36 *
37 * This pass is probably only of use to vector backends -- scalar backends
38 * typically get unused def channel trimming by scalarizing and dead code
39 * elimination.
40 */
41
42#include "nir.h"
43#include "nir_builder.h"
44
45static bool
46shrink_dest_to_read_mask(nir_ssa_def *def)
47{
48   /* early out if there's nothing to do. */
49   if (def->num_components == 1)
50      return false;
51
52   /* don't remove any channels if used by an intrinsic */
53   nir_foreach_use(use_src, def) {
54      if (use_src->parent_instr->type == nir_instr_type_intrinsic)
55         return false;
56   }
57
58   unsigned mask = nir_ssa_def_components_read(def);
59   int last_bit = util_last_bit(mask);
60
61   /* If nothing was read, leave it up to DCE. */
62   if (!mask)
63      return false;
64
65   if (def->num_components > last_bit) {
66      def->num_components = last_bit;
67      return true;
68   }
69
70   return false;
71}
72
/* Shrinks the destination of a vector ALU instruction to only the channels
 * that are actually read, reswizzling all (ALU-only) uses to match.  For
 * vecN instructions the whole instruction is replaced by a smaller vecN;
 * for other vector ALU ops the source swizzles and write mask are compacted
 * in place.  Returns true on progress.
 */
static bool
opt_shrink_vectors_alu(nir_builder *b, nir_alu_instr *instr)
{
   nir_ssa_def *def = &instr->dest.dest.ssa;

   /* Nothing to shrink */
   if (def->num_components == 1)
      return false;

   bool is_vec = false;
   switch (instr->op) {
      /* don't use nir_op_is_vec() as not all vector sizes are supported. */
      case nir_op_vec4:
      case nir_op_vec3:
      case nir_op_vec2:
         is_vec = true;
         break;
      default:
         /* Ops with a fixed output size (dot products, etc.) can't have
          * channels removed, so bail on those.
          */
         if (nir_op_infos[instr->op].output_size != 0)
            return false;
         break;
   }

   /* don't remove any channels if used by an intrinsic */
   nir_foreach_use(use_src, def) {
      if (use_src->parent_instr->type == nir_instr_type_intrinsic)
         return false;
   }

   unsigned mask = nir_ssa_def_components_read(def);
   unsigned last_bit = util_last_bit(mask);
   unsigned num_components = util_bitcount(mask);

   /* return, if there is nothing to do */
   if (mask == 0 || num_components == def->num_components)
      return false;

   /* If the read channels form a dense prefix (e.g. xyz of a vec4), we only
    * need to drop trailing channels; no reswizzling of uses is required.
    */
   const bool is_bitfield_mask = last_bit == num_components;

   if (is_vec) {
      /* replace vecN with smaller version */
      nir_ssa_def *srcs[NIR_MAX_VEC_COMPONENTS] = { 0 };
      unsigned index = 0;
      for (int i = 0; i < last_bit; i++) {
         if ((mask >> i) & 0x1)
            srcs[index++] = nir_ssa_for_alu_src(b, instr, i);
      }
      assert(index == num_components);
      nir_ssa_def *new_vec = nir_vec(b, srcs, num_components);
      nir_ssa_def_rewrite_uses(def, new_vec);
      /* From here on, operate on the replacement def: its uses are the
       * ones that may still need reswizzling below.
       */
      def = new_vec;
   }

   if (is_bitfield_mask) {
      /* just reduce the number of components and return */
      def->num_components = num_components;
      instr->dest.write_mask = mask;
      return true;
   }

   if (!is_vec) {
      /* update sources: compact each source's swizzle so entry j for a
       * read channel moves down to the next free slot.
       */
      for (int i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
         unsigned index = 0;
         for (int j = 0; j < last_bit; j++) {
            if ((mask >> j) & 0x1)
               instr->src[i].swizzle[index++] = instr->src[i].swizzle[j];
         }
         assert(index == num_components);
      }

      /* update dest */
      def->num_components = num_components;
      instr->dest.write_mask = BITFIELD_MASK(num_components);
   }

   /* compute new dest swizzles: reswizzle[old channel] = new channel.
    * Entries for unread channels stay 0; they are never referenced by a
    * live use's swizzle, so the value is harmless.
    */
   uint8_t reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
   unsigned index = 0;
   for (int i = 0; i < last_bit; i++) {
      if ((mask >> i) & 0x1)
         reswizzle[i] = index++;
   }
   assert(index == num_components);

   /* update uses: only ALU uses remain (intrinsic uses bailed out above),
    * so every use carries a swizzle we can remap.
    */
   nir_foreach_use(use_src, def) {
      assert(use_src->parent_instr->type == nir_instr_type_alu);
      nir_alu_src *alu_src = (nir_alu_src*)use_src;
      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
         alu_src->swizzle[i] = reswizzle[alu_src->swizzle[i]];
   }

   return true;
}
168
169static bool
170opt_shrink_vectors_image_store(nir_builder *b, nir_intrinsic_instr *instr)
171{
172   enum pipe_format format;
173   if (instr->intrinsic == nir_intrinsic_image_deref_store) {
174      nir_deref_instr *deref = nir_src_as_deref(instr->src[0]);
175      format = nir_deref_instr_get_variable(deref)->data.image.format;
176   } else {
177      format = nir_intrinsic_format(instr);
178   }
179   if (format == PIPE_FORMAT_NONE)
180      return false;
181
182   unsigned components = util_format_get_nr_components(format);
183   if (components >= instr->num_components)
184      return false;
185
186   nir_ssa_def *data = nir_channels(b, instr->src[3].ssa, BITSET_MASK(components));
187   nir_instr_rewrite_src(&instr->instr, &instr->src[3], nir_src_for_ssa(data));
188   instr->num_components = components;
189
190   return true;
191}
192
193static bool
194opt_shrink_vectors_intrinsic(nir_builder *b, nir_intrinsic_instr *instr, bool shrink_image_store)
195{
196   switch (instr->intrinsic) {
197   case nir_intrinsic_load_uniform:
198   case nir_intrinsic_load_ubo:
199   case nir_intrinsic_load_input:
200   case nir_intrinsic_load_input_vertex:
201   case nir_intrinsic_load_per_vertex_input:
202   case nir_intrinsic_load_interpolated_input:
203   case nir_intrinsic_load_ssbo:
204   case nir_intrinsic_load_push_constant:
205   case nir_intrinsic_load_constant:
206   case nir_intrinsic_load_shared:
207   case nir_intrinsic_load_global:
208   case nir_intrinsic_load_global_constant:
209   case nir_intrinsic_load_kernel_input:
210   case nir_intrinsic_load_scratch:
211   case nir_intrinsic_store_output:
212   case nir_intrinsic_store_per_vertex_output:
213   case nir_intrinsic_store_ssbo:
214   case nir_intrinsic_store_shared:
215   case nir_intrinsic_store_global:
216   case nir_intrinsic_store_scratch:
217      break;
218   case nir_intrinsic_bindless_image_store:
219   case nir_intrinsic_image_deref_store:
220   case nir_intrinsic_image_store:
221      return shrink_image_store && opt_shrink_vectors_image_store(b, instr);
222   default:
223      return false;
224   }
225
226   /* Must be a vectorized intrinsic that we can resize. */
227   assert(instr->num_components != 0);
228
229   if (nir_intrinsic_infos[instr->intrinsic].has_dest) {
230      /* loads: Trim the dest to the used channels */
231
232      if (shrink_dest_to_read_mask(&instr->dest.ssa)) {
233         instr->num_components = instr->dest.ssa.num_components;
234         return true;
235      }
236   } else {
237      /* Stores: trim the num_components stored according to the write
238       * mask.
239       */
240      unsigned write_mask = nir_intrinsic_write_mask(instr);
241      unsigned last_bit = util_last_bit(write_mask);
242      if (last_bit < instr->num_components && instr->src[0].is_ssa) {
243         nir_ssa_def *def = nir_channels(b, instr->src[0].ssa,
244                                         BITSET_MASK(last_bit));
245         nir_instr_rewrite_src(&instr->instr,
246                               &instr->src[0],
247                               nir_src_for_ssa(def));
248         instr->num_components = last_bit;
249
250         return true;
251      }
252   }
253
254   return false;
255}
256
257static bool
258opt_shrink_vectors_load_const(nir_load_const_instr *instr)
259{
260   return shrink_dest_to_read_mask(&instr->def);
261}
262
263static bool
264opt_shrink_vectors_ssa_undef(nir_ssa_undef_instr *instr)
265{
266   return shrink_dest_to_read_mask(&instr->def);
267}
268
269static bool
270opt_shrink_vectors_instr(nir_builder *b, nir_instr *instr, bool shrink_image_store)
271{
272   b->cursor = nir_before_instr(instr);
273
274   switch (instr->type) {
275   case nir_instr_type_alu:
276      return opt_shrink_vectors_alu(b, nir_instr_as_alu(instr));
277
278   case nir_instr_type_intrinsic:
279      return opt_shrink_vectors_intrinsic(b, nir_instr_as_intrinsic(instr), shrink_image_store);
280
281   case nir_instr_type_load_const:
282      return opt_shrink_vectors_load_const(nir_instr_as_load_const(instr));
283
284   case nir_instr_type_ssa_undef:
285      return opt_shrink_vectors_ssa_undef(nir_instr_as_ssa_undef(instr));
286
287   default:
288      return false;
289   }
290
291   return true;
292}
293
294bool
295nir_opt_shrink_vectors(nir_shader *shader, bool shrink_image_store)
296{
297   bool progress = false;
298
299   nir_foreach_function(function, shader) {
300      if (!function->impl)
301         continue;
302
303      nir_builder b;
304      nir_builder_init(&b, function->impl);
305
306      nir_foreach_block_reverse(block, function->impl) {
307         nir_foreach_instr_reverse(instr, block) {
308            progress |= opt_shrink_vectors_instr(&b, instr, shrink_image_store);
309         }
310      }
311
312      if (progress) {
313         nir_metadata_preserve(function->impl,
314                               nir_metadata_block_index |
315                               nir_metadata_dominance);
316      } else {
317         nir_metadata_preserve(function->impl, nir_metadata_all);
318      }
319   }
320
321   return progress;
322}
323