/*
 * Copyright © 2018 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */
23
24#include "nir.h"
25#include "nir_builder.h"
26#include "nir_deref.h"
27
28static bool
29index_ssa_def_cb(nir_ssa_def *def, void *state)
30{
31   unsigned *index = (unsigned *) state;
32   def->index = (*index)++;
33
34   return true;
35}
36
37static nir_deref_instr *
38get_deref_for_load_src(nir_src src, unsigned first_valid_load)
39{
40   nir_intrinsic_instr *load = nir_src_as_intrinsic(src);
41   if (load == NULL || load->intrinsic != nir_intrinsic_load_deref)
42      return NULL;
43
44   if (load->dest.ssa.index < first_valid_load)
45      return NULL;
46
47   return nir_src_as_deref(load->src[0]);
48}
49
/* State machine used to detect a sequence of element-by-element array
 * copies (dst[0] = src[0]; dst[1] = src[1]; ...) within a block.
 */
struct match_state {
   /* Number of consecutive array elements matched so far; 0 means there is
    * no copy in progress.  This is also the array index we expect the next
    * element-wise copy to use.
    */
   unsigned next_array_idx;

   /* Length of the array we're copying */
   unsigned array_len;

   /* Index into the deref path to the array we think is being copied,
    * or -1 if we have not yet identified which entry of the path it is.
    */
   int src_deref_array_idx;
   int dst_deref_array_idx;

   /* Deref paths of the first load/store pair or copy.  Only initialized
    * while next_array_idx > 0 (see match_state_finish).
    */
   nir_deref_path first_src_path;
   nir_deref_path first_dst_path;
};
65
66static void
67match_state_init(struct match_state *state)
68{
69   state->next_array_idx = 0;
70   state->array_len = 0;
71   state->src_deref_array_idx = -1;
72   state->dst_deref_array_idx = -1;
73}
74
75static void
76match_state_finish(struct match_state *state)
77{
78   if (state->next_array_idx > 0) {
79      nir_deref_path_finish(&state->first_src_path);
80      nir_deref_path_finish(&state->first_dst_path);
81   }
82}
83
/* Abandon any in-progress match: release the recorded deref paths (if any)
 * and return the state machine to its initial state.
 */
static void
match_state_reset(struct match_state *state)
{
   match_state_finish(state);
   match_state_init(state);
}
90
91static bool
92try_match_deref(nir_deref_path *base_path, int *path_array_idx,
93                nir_deref_instr *deref, int arr_idx, void *mem_ctx)
94{
95   nir_deref_path deref_path;
96   nir_deref_path_init(&deref_path, deref, mem_ctx);
97
98   bool found = false;
99   for (int i = 0; ; i++) {
100      nir_deref_instr *b = base_path->path[i];
101      nir_deref_instr *d = deref_path.path[i];
102      /* They have to be the same length */
103      if ((b == NULL) != (d == NULL))
104         goto fail;
105
106      if (b == NULL)
107         break;
108
109      /* This can happen if one is a deref_array and the other a wildcard */
110      if (b->deref_type != d->deref_type)
111         goto fail;
112
113      switch (b->deref_type) {
114      case nir_deref_type_var:
115         if (b->var != d->var)
116            goto fail;
117         continue;
118
119      case nir_deref_type_array:
120         assert(b->arr.index.is_ssa && d->arr.index.is_ssa);
121         const bool const_b_idx = nir_src_is_const(b->arr.index);
122         const bool const_d_idx = nir_src_is_const(d->arr.index);
123         const unsigned b_idx = const_b_idx ? nir_src_as_uint(b->arr.index) : 0;
124         const unsigned d_idx = const_d_idx ? nir_src_as_uint(d->arr.index) : 0;
125
126         /* If we don't have an index into the path yet or if this entry in
127          * the path is at the array index, see if this is a candidate.  We're
128          * looking for an index which is zero in the base deref and arr_idx
129          * in the search deref.
130          */
131         if ((*path_array_idx < 0 || *path_array_idx == i) &&
132             const_b_idx && b_idx == 0 &&
133             const_d_idx && d_idx == arr_idx) {
134            *path_array_idx = i;
135            continue;
136         }
137
138         /* We're at the array index but not a candidate */
139         if (*path_array_idx == i)
140            goto fail;
141
142         /* If we're not the path array index, we must match exactly.  We
143          * could probably just compare SSA values and trust in copy
144          * propagation but doing it ourselves means this pass can run a bit
145          * earlier.
146          */
147         if (b->arr.index.ssa == d->arr.index.ssa ||
148             (const_b_idx && const_d_idx && b_idx == d_idx))
149            continue;
150
151         goto fail;
152
153      case nir_deref_type_array_wildcard:
154         continue;
155
156      case nir_deref_type_struct:
157         if (b->strct.index != d->strct.index)
158            goto fail;
159         continue;
160
161      default:
162         unreachable("Invalid deref type in a path");
163      }
164   }
165
166   /* If we got here without failing, we've matched.  However, it isn't an
167    * array match unless we found an altered array index.
168    */
169   found = *path_array_idx > 0;
170
171fail:
172   nir_deref_path_finish(&deref_path);
173   return found;
174}
175
176static nir_deref_instr *
177build_wildcard_deref(nir_builder *b, nir_deref_path *path,
178                     unsigned wildcard_idx)
179{
180   assert(path->path[wildcard_idx]->deref_type == nir_deref_type_array);
181
182   nir_deref_instr *tail =
183      nir_build_deref_array_wildcard(b, path->path[wildcard_idx - 1]);
184
185   for (unsigned i = wildcard_idx + 1; path->path[i]; i++)
186      tail = nir_build_deref_follower(b, tail, path->path[i]);
187
188   return tail;
189}
190
/* Scan a single block for a run of element-wise copies (store_deref fed by
 * a load_deref, or copy_deref) that together copy an entire array, and
 * replace each completed run with a single wildcard copy_deref emitted
 * after the last element copy.  SSA defs are (re-)indexed as we walk so
 * that load "age" can be compared against the last unrelated write.
 * Returns true if at least one whole-array copy was emitted.
 */
static bool
opt_find_array_copies_block(nir_builder *b, nir_block *block,
                            unsigned *num_ssa_defs, void *mem_ctx)
{
   bool progress = false;

   struct match_state s;
   match_state_init(&s);

   nir_variable *dst_var = NULL;
   unsigned prev_dst_var_last_write = *num_ssa_defs;
   unsigned dst_var_last_write = *num_ssa_defs;

   nir_foreach_instr(instr, block) {
      /* Index the SSA defs before we do anything else. */
      nir_foreach_ssa_def(instr, index_ssa_def_cb, num_ssa_defs);

      if (instr->type != nir_instr_type_intrinsic)
         continue;

      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
      if (intrin->intrinsic != nir_intrinsic_copy_deref &&
          intrin->intrinsic != nir_intrinsic_store_deref)
         continue;

      nir_deref_instr *dst_deref = nir_src_as_deref(intrin->src[0]);

      /* The destination must be local.  If we see a non-local store, we
       * continue on because it won't affect local stores or read-only
       * variables.
       */
      if (dst_deref->mode != nir_var_function_temp)
         continue;

      /* We keep track of the SSA indices where the two last-written
       * variables are written.  The prev_dst_var_last_write tells us when
       * the last store_deref to something other than dst happened.  If the
       * SSA def index from a load is greater than or equal to this number
       * then we know it happened afterwards and no writes to anything other
       * than dst occur between the load and the current instruction.
       */
      if (nir_deref_instr_get_variable(dst_deref) != dst_var) {
         prev_dst_var_last_write = dst_var_last_write;
         dst_var = nir_deref_instr_get_variable(dst_deref);
      }
      dst_var_last_write = *num_ssa_defs;

      /* If it's a full variable store or copy, reset.  This will trigger
       * eventually because we'll fail to match an array element.  However,
       * it's a cheap early-exit.
       */
      if (dst_deref->deref_type == nir_deref_type_var)
         goto reset;

      nir_deref_instr *src_deref;
      if (intrin->intrinsic == nir_intrinsic_copy_deref) {
         src_deref = nir_src_as_deref(intrin->src[1]);
      } else {
         assert(intrin->intrinsic == nir_intrinsic_store_deref);
         /* Only trust loads that happened after the last write to any
          * variable other than dst; older loads may be stale.
          */
         src_deref = get_deref_for_load_src(intrin->src[1],
                                            prev_dst_var_last_write);

         /* We can only handle full writes */
         if (nir_intrinsic_write_mask(intrin) !=
             (1 << glsl_get_components(dst_deref->type)) - 1)
            goto reset;
      }

      /* If we didn't find a valid src, then we have an unknown store and it
       * could mess things up.
       */
      if (src_deref == NULL)
         goto reset;

      /* The source must be either local or something that's guaranteed to be
       * read-only.
       */
      const nir_variable_mode read_only_modes =
         nir_var_shader_in | nir_var_uniform | nir_var_system_value;
      if (!(src_deref->mode & (nir_var_function_temp | read_only_modes)))
         goto reset;

      /* If we don't yet have an active copy, then make this instruction the
       * active copy.
       */
      if (s.next_array_idx == 0) {
         /* We can't combine a copy if there is any chance the source and
          * destination will end up aliasing.  Just bail if they're the same
          * variable.
          */
         if (nir_deref_instr_get_variable(src_deref) == dst_var)
            goto reset;

         /* The load/store pair is enough to guarantee the same bit size and
          * number of components but a copy_var requires the actual types to
          * match.
          */
         if (dst_deref->type != src_deref->type)
            continue;

         /* The first time we see a store, we don't know which array in the
          * deref path is the one being copied so we just record the paths
          * as-is and continue.  On the next iteration, it will try to match
          * based on which array index changed.
          */
         nir_deref_path_init(&s.first_dst_path, dst_deref, mem_ctx);
         nir_deref_path_init(&s.first_src_path, src_deref, mem_ctx);
         s.next_array_idx = 1;
         continue;
      }

      /* Both paths must match the recorded ones with the array index
       * advanced to next_array_idx, otherwise the sequence is broken.
       */
      if (!try_match_deref(&s.first_dst_path, &s.dst_deref_array_idx,
                           dst_deref, s.next_array_idx, mem_ctx) ||
          !try_match_deref(&s.first_src_path, &s.src_deref_array_idx,
                           src_deref, s.next_array_idx, mem_ctx))
         goto reset;

      if (s.next_array_idx == 1) {
         /* This is our first non-trivial match.  We now have indices into
          * the search paths so we can do a couple more checks.
          */
         assert(s.dst_deref_array_idx > 0 && s.src_deref_array_idx > 0);
         const struct glsl_type *dst_arr_type =
            s.first_dst_path.path[s.dst_deref_array_idx - 1]->type;
         const struct glsl_type *src_arr_type =
            s.first_src_path.path[s.src_deref_array_idx - 1]->type;

         assert(glsl_type_is_array(dst_arr_type) ||
                glsl_type_is_matrix(dst_arr_type));
         assert(glsl_type_is_array(src_arr_type) ||
                glsl_type_is_matrix(src_arr_type));

         /* They must be the same length */
         s.array_len = glsl_get_length(dst_arr_type);
         if (s.array_len != glsl_get_length(src_arr_type))
            goto reset;
      }

      s.next_array_idx++;

      if (s.next_array_idx == s.array_len) {
         /* Hooray, We found a copy! */
         b->cursor = nir_after_instr(instr);
         nir_copy_deref(b, build_wildcard_deref(b, &s.first_dst_path,
                                                s.dst_deref_array_idx),
                           build_wildcard_deref(b, &s.first_src_path,
                                                s.src_deref_array_idx));
         match_state_reset(&s);
         progress = true;
      }

      continue;

   reset:
      match_state_reset(&s);
   }

   return progress;
}
350
351static bool
352opt_find_array_copies_impl(nir_function_impl *impl)
353{
354   nir_builder b;
355   nir_builder_init(&b, impl);
356
357   bool progress = false;
358
359   void *mem_ctx = ralloc_context(NULL);
360
361   /* We re-index the SSA defs as we go; it makes it easier to handle
362    * resetting the state machine.
363    */
364   unsigned num_ssa_defs = 0;
365
366   nir_foreach_block(block, impl) {
367      if (opt_find_array_copies_block(&b, block, &num_ssa_defs, mem_ctx))
368         progress = true;
369   }
370
371   impl->ssa_alloc = num_ssa_defs;
372
373   ralloc_free(mem_ctx);
374
375   if (progress) {
376      nir_metadata_preserve(impl, nir_metadata_block_index |
377                                  nir_metadata_dominance);
378   }
379
380   return progress;
381}
382
383/**
384 * This peephole optimization looks for a series of load/store_deref or
385 * copy_deref instructions that copy an array from one variable to another and
386 * turns it into a copy_deref that copies the entire array.  The pattern it
387 * looks for is extremely specific but it's good enough to pick up on the
388 * input array copies in DXVK and should also be able to pick up the sequence
389 * generated by spirv_to_nir for a OpLoad of a large composite followed by
390 * OpStore.
391 *
392 * TODO: Use a hash table approach to support out-of-order and interleaved
393 * copies.
394 */
395bool
396nir_opt_find_array_copies(nir_shader *shader)
397{
398   bool progress = false;
399
400   nir_foreach_function(function, shader) {
401      if (function->impl && opt_find_array_copies_impl(function->impl))
402         progress = true;
403   }
404
405   return progress;
406}
407