1/*
2 * Copyright © 2018 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24#include "nir.h"
25#include "nir_builder.h"
26#include "nir_deref.h"
27#include "nir_vla.h"
28
29#include "util/set.h"
30#include "util/u_math.h"
31
static struct set *
get_complex_used_vars(nir_shader *shader, void *mem_ctx)
{
   /* Collect every variable that is referenced by at least one deref which
    * nir_deref_instr_has_complex_use() flags.  Such variables cannot safely
    * be split, so later passes consult this set before splitting.  The set
    * is allocated out of mem_ctx; the caller owns it.
    */
   struct set *complex_vars = _mesa_pointer_set_create(mem_ctx);

   nir_foreach_function(function, shader) {
      if (!function->impl)
         continue;

      nir_foreach_block(block, function->impl) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_deref)
               continue;

            nir_deref_instr *deref = nir_instr_as_deref(instr);

            /* We only need to consider var derefs because
             * nir_deref_instr_has_complex_use is recursive.
             */
            if (deref->deref_type == nir_deref_type_var &&
                nir_deref_instr_has_complex_use(deref))
               _mesa_set_add(complex_vars, deref->var);
         }
      }
   }

   return complex_vars;
}
60
struct split_var_state {
   /* Allocation context for temporary data built while splitting */
   void *mem_ctx;

   nir_shader *shader;
   nir_function_impl *impl;

   /* The original (pre-split) variable currently being processed */
   nir_variable *base_var;
};

/* One node in a tree mirroring a (possibly nested) struct type.  Interior
 * nodes correspond to struct/interface members; leaves hold the new
 * variable created for a non-struct member.
 */
struct field {
   struct field *parent;

   const struct glsl_type *type;

   /* For struct types: per-member children (leaves have num_fields == 0,
    * since the whole struct is zero-initialized on creation).
    */
   unsigned num_fields;
   struct field *fields;

   /* For leaf (non-struct) types: the newly created split variable */
   nir_variable *var;
};
80
static const struct glsl_type *
wrap_type_in_array(const struct glsl_type *type,
                   const struct glsl_type *array_type)
{
   /* Rewrap `type` in the same nested array dimensions as `array_type`. */
   if (glsl_type_is_array(array_type)) {
      assert(glsl_get_explicit_stride(array_type) == 0);
      const struct glsl_type *wrapped_elem =
         wrap_type_in_array(type, glsl_get_array_element(array_type));
      return glsl_array_type(wrapped_elem, glsl_get_length(array_type), 0);
   }

   /* Not an array: nothing to wrap */
   return type;
}
93
static int
num_array_levels_in_array_of_vector_type(const struct glsl_type *type)
{
   /* Count how many array (or matrix) levels wrap a vector/scalar type.
    * Returns -1 if the innermost type is not a vector or scalar, i.e. the
    * type is not an "array of vectors" at all.
    */
   int num_levels = 0;
   while (glsl_type_is_array_or_matrix(type)) {
      num_levels++;
      type = glsl_get_array_element(type);
   }

   return glsl_type_is_vector_or_scalar(type) ? num_levels : -1;
}
110
111static void
112init_field_for_type(struct field *field, struct field *parent,
113                    const struct glsl_type *type,
114                    const char *name,
115                    struct split_var_state *state)
116{
117   *field = (struct field) {
118      .parent = parent,
119      .type = type,
120   };
121
122   const struct glsl_type *struct_type = glsl_without_array(type);
123   if (glsl_type_is_struct_or_ifc(struct_type)) {
124      field->num_fields = glsl_get_length(struct_type),
125      field->fields = ralloc_array(state->mem_ctx, struct field,
126                                   field->num_fields);
127      for (unsigned i = 0; i < field->num_fields; i++) {
128         char *field_name = NULL;
129         if (name) {
130            field_name = ralloc_asprintf(state->mem_ctx, "%s_%s", name,
131                                         glsl_get_struct_elem_name(struct_type, i));
132         } else {
133            field_name = ralloc_asprintf(state->mem_ctx, "{unnamed %s}_%s",
134                                         glsl_get_type_name(struct_type),
135                                         glsl_get_struct_elem_name(struct_type, i));
136         }
137         init_field_for_type(&field->fields[i], field,
138                             glsl_get_struct_field(struct_type, i),
139                             field_name, state);
140      }
141   } else {
142      const struct glsl_type *var_type = type;
143      for (struct field *f = field->parent; f; f = f->parent)
144         var_type = wrap_type_in_array(var_type, f->type);
145
146      nir_variable_mode mode = state->base_var->data.mode;
147      if (mode == nir_var_function_temp) {
148         field->var = nir_local_variable_create(state->impl, var_type, name);
149      } else {
150         field->var = nir_variable_create(state->shader, mode, var_type, name);
151      }
152   }
153}
154
static bool
split_var_list_structs(nir_shader *shader,
                       nir_function_impl *impl,
                       struct exec_list *vars,
                       nir_variable_mode mode,
                       struct hash_table *var_field_map,
                       struct set **complex_vars,
                       void *mem_ctx)
{
   /* Find every struct-typed variable of the given mode in `vars`, build a
    * field tree for it (creating the new split variables) and record the
    * tree in var_field_map.  Returns true if anything was split.
    *
    * *complex_vars is computed lazily on first need and cached across
    * calls so we only walk the shader once.
    */
   struct split_var_state state = {
      .mem_ctx = mem_ctx,
      .shader = shader,
      .impl = impl,
   };

   struct exec_list split_vars;
   exec_list_make_empty(&split_vars);

   /* To avoid list confusion (we'll be adding things as we split variables),
    * pull all of the variables we plan to split off of the list
    */
   nir_foreach_variable_in_list_safe(var, vars) {
      if (var->data.mode != mode)
         continue;

      if (!glsl_type_is_struct_or_ifc(glsl_without_array(var->type)))
         continue;

      if (*complex_vars == NULL)
         *complex_vars = get_complex_used_vars(shader, mem_ctx);

      /* We can't split a variable that's referenced with deref that has any
       * sort of complex usage.
       */
      if (_mesa_set_search(*complex_vars, var))
         continue;

      exec_node_remove(&var->node);
      exec_list_push_tail(&split_vars, &var->node);
   }

   nir_foreach_variable_in_list(var, &split_vars) {
      state.base_var = var;

      struct field *root_field = ralloc(mem_ctx, struct field);
      init_field_for_type(root_field, NULL, var->type, var->name, &state);
      _mesa_hash_table_insert(var_field_map, var, root_field);
   }

   return !exec_list_is_empty(&split_vars);
}
206
static void
split_struct_derefs_impl(nir_function_impl *impl,
                         struct hash_table *var_field_map,
                         nir_variable_mode modes,
                         void *mem_ctx)
{
   /* Rewrite every vector/scalar deref into a split variable so that it
    * points at the corresponding new leaf variable, dropping the struct
    * links in the deref chain.
    */
   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_deref)
            continue;

         nir_deref_instr *deref = nir_instr_as_deref(instr);
         if (!nir_deref_mode_may_be(deref, modes))
            continue;

         /* Clean up any dead derefs we find lying around.  They may refer to
          * variables we're planning to split.
          */
         if (nir_deref_instr_remove_if_unused(deref))
            continue;

         /* Only the final vector/scalar deref of a chain gets rewritten;
          * intermediate derefs are handled via the path walk below.
          */
         if (!glsl_type_is_vector_or_scalar(deref->type))
            continue;

         nir_variable *base_var = nir_deref_instr_get_variable(deref);
         /* If we can't chase back to the variable, then we're a complex use.
          * This should have been detected by get_complex_used_vars() and the
          * variable should not have been split.  However, we have no way of
          * knowing that here, so we just have to trust it.
          */
         if (base_var == NULL)
            continue;

         struct hash_entry *entry =
            _mesa_hash_table_search(var_field_map, base_var);
         if (!entry)
            continue;

         struct field *root_field = entry->data;

         nir_deref_path path;
         nir_deref_path_init(&path, deref, mem_ctx);

         /* Follow the struct derefs in the path to find the leaf field
          * (and thus the split variable) this deref lands on.
          */
         struct field *tail_field = root_field;
         for (unsigned i = 0; path.path[i]; i++) {
            if (path.path[i]->deref_type != nir_deref_type_struct)
               continue;

            assert(i > 0);
            assert(glsl_type_is_struct_or_ifc(path.path[i - 1]->type));
            assert(path.path[i - 1]->type ==
                   glsl_without_array(tail_field->type));

            tail_field = &tail_field->fields[path.path[i]->strct.index];
         }
         nir_variable *split_var = tail_field->var;

         /* Rebuild the deref chain on top of the split variable.  Struct
          * derefs disappear; array derefs are re-emitted right after the
          * originals they follow.
          */
         nir_deref_instr *new_deref = NULL;
         for (unsigned i = 0; path.path[i]; i++) {
            nir_deref_instr *p = path.path[i];
            b.cursor = nir_after_instr(&p->instr);

            switch (p->deref_type) {
            case nir_deref_type_var:
               assert(new_deref == NULL);
               new_deref = nir_build_deref_var(&b, split_var);
               break;

            case nir_deref_type_array:
            case nir_deref_type_array_wildcard:
               new_deref = nir_build_deref_follower(&b, new_deref, p);
               break;

            case nir_deref_type_struct:
               /* Nothing to do; we're splitting structs */
               break;

            default:
               unreachable("Invalid deref type in path");
            }
         }

         assert(new_deref->type == deref->type);
         nir_ssa_def_rewrite_uses(&deref->dest.ssa,
                                  &new_deref->dest.ssa);
         nir_deref_instr_remove_if_unused(deref);
      }
   }
}
299
300/** A pass for splitting structs into multiple variables
301 *
302 * This pass splits arrays of structs into multiple variables, one for each
303 * (possibly nested) structure member.  After this pass completes, no
304 * variables of the given mode will contain a struct type.
305 */
bool
nir_split_struct_vars(nir_shader *shader, nir_variable_mode modes)
{
   /* All bookkeeping lives in one ralloc context freed at the end */
   void *mem_ctx = ralloc_context(NULL);
   struct hash_table *var_field_map =
      _mesa_pointer_hash_table_create(mem_ctx);
   struct set *complex_vars = NULL;

   /* Only shader- and function-temporary variables are supported */
   assert((modes & (nir_var_shader_temp | nir_var_function_temp)) == modes);

   bool has_global_splits = false;
   if (modes & nir_var_shader_temp) {
      has_global_splits = split_var_list_structs(shader, NULL,
                                                 &shader->variables,
                                                 nir_var_shader_temp,
                                                 var_field_map,
                                                 &complex_vars,
                                                 mem_ctx);
   }

   bool progress = false;
   nir_foreach_function(function, shader) {
      if (!function->impl)
         continue;

      bool has_local_splits = false;
      if (modes & nir_var_function_temp) {
         has_local_splits = split_var_list_structs(shader, function->impl,
                                                   &function->impl->locals,
                                                   nir_var_function_temp,
                                                   var_field_map,
                                                   &complex_vars,
                                                   mem_ctx);
      }

      /* Global splits may be referenced from any function, so rewrite
       * derefs in this impl even when it had no local splits of its own.
       */
      if (has_global_splits || has_local_splits) {
         split_struct_derefs_impl(function->impl, var_field_map,
                                  modes, mem_ctx);

         nir_metadata_preserve(function->impl, nir_metadata_block_index |
                                               nir_metadata_dominance);
         progress = true;
      } else {
         nir_metadata_preserve(function->impl, nir_metadata_all);
      }
   }

   ralloc_free(mem_ctx);

   return progress;
}
357
358struct array_level_info {
359   unsigned array_len;
360   bool split;
361};
362
363struct array_split {
364   /* Only set if this is the tail end of the splitting */
365   nir_variable *var;
366
367   unsigned num_splits;
368   struct array_split *splits;
369};
370
371struct array_var_info {
372   nir_variable *base_var;
373
374   const struct glsl_type *split_var_type;
375
376   bool split_var;
377   struct array_split root_split;
378
379   unsigned num_levels;
380   struct array_level_info levels[0];
381};
382
static bool
init_var_list_array_infos(nir_shader *shader,
                          struct exec_list *vars,
                          nir_variable_mode mode,
                          struct hash_table *var_info_map,
                          struct set **complex_vars,
                          void *mem_ctx)
{
   /* Allocate an array_var_info for every candidate array-of-vectors
    * variable of the given mode and store it in var_info_map.  Every array
    * level starts out marked split; mark_array_deref_used() later clears
    * the flag for levels that are indexed indirectly.  Returns true if any
    * candidate was found.
    */
   bool has_array = false;

   nir_foreach_variable_in_list(var, vars) {
      if (var->data.mode != mode)
         continue;

      int num_levels = num_array_levels_in_array_of_vector_type(var->type);
      if (num_levels <= 0)
         continue;

      if (*complex_vars == NULL)
         *complex_vars = get_complex_used_vars(shader, mem_ctx);

      /* We can't split a variable that's referenced with deref that has any
       * sort of complex usage.
       */
      if (_mesa_set_search(*complex_vars, var))
         continue;

      /* Trailing space for the per-level flexible array */
      struct array_var_info *info =
         rzalloc_size(mem_ctx, sizeof(*info) +
                               num_levels * sizeof(info->levels[0]));

      info->base_var = var;
      info->num_levels = num_levels;

      const struct glsl_type *type = var->type;
      for (int i = 0; i < num_levels; i++) {
         info->levels[i].array_len = glsl_get_length(type);
         type = glsl_get_array_element(type);

         /* All levels start out initially as split */
         info->levels[i].split = true;
      }

      _mesa_hash_table_insert(var_info_map, var, info);
      has_array = true;
   }

   return has_array;
}
432
433static struct array_var_info *
434get_array_var_info(nir_variable *var,
435                   struct hash_table *var_info_map)
436{
437   struct hash_entry *entry =
438      _mesa_hash_table_search(var_info_map, var);
439   return entry ? entry->data : NULL;
440}
441
442static struct array_var_info *
443get_array_deref_info(nir_deref_instr *deref,
444                     struct hash_table *var_info_map,
445                     nir_variable_mode modes)
446{
447   if (!nir_deref_mode_may_be(deref, modes))
448      return NULL;
449
450   nir_variable *var = nir_deref_instr_get_variable(deref);
451   if (var == NULL)
452      return NULL;
453
454   return get_array_var_info(var, var_info_map);
455}
456
static void
mark_array_deref_used(nir_deref_instr *deref,
                      struct hash_table *var_info_map,
                      nir_variable_mode modes,
                      void *mem_ctx)
{
   /* Inspect one deref into a candidate variable and disable splitting for
    * any array level accessed with a non-constant index.
    */
   struct array_var_info *info =
      get_array_deref_info(deref, var_info_map, modes);
   if (!info)
      return;

   nir_deref_path path;
   nir_deref_path_init(&path, deref, mem_ctx);

   /* Walk the path and look for indirects.  If we have an array deref with an
    * indirect, mark the given level as not being split.
    */
   for (unsigned i = 0; i < info->num_levels; i++) {
      nir_deref_instr *p = path.path[i + 1];
      if (p->deref_type == nir_deref_type_array &&
          !nir_src_is_const(p->arr.index))
         info->levels[i].split = false;
   }
}
481
static void
mark_array_usage_impl(nir_function_impl *impl,
                      struct hash_table *var_info_map,
                      nir_variable_mode modes,
                      void *mem_ctx)
{
   /* Scan all load/store/copy deref intrinsics and un-mark array levels
    * that are indexed indirectly (see mark_array_deref_used).
    */
   nir_foreach_block(block, impl) {
      nir_foreach_instr(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         switch (intrin->intrinsic) {
         case nir_intrinsic_copy_deref:
            /* Copies have a second (source) deref in src[1]; handle it
             * here, then fall through for the destination in src[0].
             */
            mark_array_deref_used(nir_src_as_deref(intrin->src[1]),
                                  var_info_map, modes, mem_ctx);
            FALLTHROUGH;

         case nir_intrinsic_load_deref:
         case nir_intrinsic_store_deref:
            mark_array_deref_used(nir_src_as_deref(intrin->src[0]),
                                  var_info_map, modes, mem_ctx);
            break;

         default:
            break;
         }
      }
   }
}
512
static void
create_split_array_vars(struct array_var_info *var_info,
                        unsigned level,
                        struct array_split *split,
                        const char *name,
                        nir_shader *shader,
                        nir_function_impl *impl,
                        void *mem_ctx)
{
   /* Recursively build the array_split tree for var_info, creating one new
    * variable (of type split_var_type) per combination of indices at the
    * split levels.  Non-split levels remain part of the variable's type and
    * show up as "[*]" in the generated names.
    */
   while (level < var_info->num_levels && !var_info->levels[level].split) {
      name = ralloc_asprintf(mem_ctx, "%s[*]", name);
      level++;
   }

   if (level == var_info->num_levels) {
      /* We add parens to the variable name so it looks like "(foo[2][*])" so
       * that further derefs will look like "(foo[2][*])[ssa_6]"
       */
      name = ralloc_asprintf(mem_ctx, "(%s)", name);

      nir_variable_mode mode = var_info->base_var->data.mode;
      if (mode == nir_var_function_temp) {
         split->var = nir_local_variable_create(impl,
                                                var_info->split_var_type, name);
      } else {
         split->var = nir_variable_create(shader, mode,
                                          var_info->split_var_type, name);
      }
   } else {
      /* Split level: one child subtree per array element */
      assert(var_info->levels[level].split);
      split->num_splits = var_info->levels[level].array_len;
      split->splits = rzalloc_array(mem_ctx, struct array_split,
                                    split->num_splits);
      for (unsigned i = 0; i < split->num_splits; i++) {
         create_split_array_vars(var_info, level + 1, &split->splits[i],
                                 ralloc_asprintf(mem_ctx, "%s[%d]", name, i),
                                 shader, impl, mem_ctx);
      }
   }
}
553
static bool
split_var_list_arrays(nir_shader *shader,
                      nir_function_impl *impl,
                      struct exec_list *vars,
                      nir_variable_mode mode,
                      struct hash_table *var_info_map,
                      void *mem_ctx)
{
   /* Compute the post-split type of each candidate variable (the type left
    * after removing all split levels), create the new variables, and drop
    * candidates that ended up with no split levels at all.  Returns true
    * if anything was split.
    */
   struct exec_list split_vars;
   exec_list_make_empty(&split_vars);

   nir_foreach_variable_in_list_safe(var, vars) {
      if (var->data.mode != mode)
         continue;

      struct array_var_info *info = get_array_var_info(var, var_info_map);
      if (!info)
         continue;

      /* Rebuild the variable's type from the innermost level outward,
       * keeping only the levels that are NOT being split.
       */
      bool has_split = false;
      const struct glsl_type *split_type =
         glsl_without_array_or_matrix(var->type);
      for (int i = info->num_levels - 1; i >= 0; i--) {
         if (info->levels[i].split) {
            has_split = true;
            continue;
         }

         /* If the original type was a matrix type, we'd like to keep that so
          * we don't convert matrices into arrays.
          */
         if (i == info->num_levels - 1 &&
             glsl_type_is_matrix(glsl_without_array(var->type))) {
            split_type = glsl_matrix_type(glsl_get_base_type(split_type),
                                          glsl_get_components(split_type),
                                          info->levels[i].array_len);
         } else {
            split_type = glsl_array_type(split_type, info->levels[i].array_len, 0);
         }
      }

      if (has_split) {
         info->split_var_type = split_type;
         /* To avoid list confusion (we'll be adding things as we split
          * variables), pull all of the variables we plan to split off of the
          * main variable list.
          */
         exec_node_remove(&var->node);
         exec_list_push_tail(&split_vars, &var->node);
      } else {
         assert(split_type == glsl_get_bare_type(var->type));
         /* If we're not modifying this variable, delete the info so we skip
          * it faster in later passes.
          */
         _mesa_hash_table_remove_key(var_info_map, var);
      }
   }

   nir_foreach_variable_in_list(var, &split_vars) {
      struct array_var_info *info = get_array_var_info(var, var_info_map);
      create_split_array_vars(info, 0, &info->root_split, var->name,
                              shader, impl, mem_ctx);
   }

   return !exec_list_is_empty(&split_vars);
}
620
621static bool
622deref_has_split_wildcard(nir_deref_path *path,
623                         struct array_var_info *info)
624{
625   if (info == NULL)
626      return false;
627
628   assert(path->path[0]->var == info->base_var);
629   for (unsigned i = 0; i < info->num_levels; i++) {
630      if (path->path[i + 1]->deref_type == nir_deref_type_array_wildcard &&
631          info->levels[i].split)
632         return true;
633   }
634
635   return false;
636}
637
638static bool
639array_path_is_out_of_bounds(nir_deref_path *path,
640                            struct array_var_info *info)
641{
642   if (info == NULL)
643      return false;
644
645   assert(path->path[0]->var == info->base_var);
646   for (unsigned i = 0; i < info->num_levels; i++) {
647      nir_deref_instr *p = path->path[i + 1];
648      if (p->deref_type == nir_deref_type_array_wildcard)
649         continue;
650
651      if (nir_src_is_const(p->arr.index) &&
652          nir_src_as_uint(p->arr.index) >= info->levels[i].array_len)
653         return true;
654   }
655
656   return false;
657}
658
static void
emit_split_copies(nir_builder *b,
                  struct array_var_info *dst_info, nir_deref_path *dst_path,
                  unsigned dst_level, nir_deref_instr *dst,
                  struct array_var_info *src_info, nir_deref_path *src_path,
                  unsigned src_level, nir_deref_instr *src)
{
   /* Recursively lower a copy whose paths contain wildcards.  At each
    * wildcard level that is split on either side the copy is unrolled into
    * per-element copies; wildcard levels split on neither side stay
    * wildcards.  Wildcards must occur at matching depths on both paths
    * (asserted below).
    */
   nir_deref_instr *dst_p, *src_p;

   /* Advance dst over non-wildcard path entries, rebuilding the deref
    * chain as we go.
    */
   while ((dst_p = dst_path->path[dst_level + 1])) {
      if (dst_p->deref_type == nir_deref_type_array_wildcard)
         break;

      dst = nir_build_deref_follower(b, dst, dst_p);
      dst_level++;
   }

   /* Same for src */
   while ((src_p = src_path->path[src_level + 1])) {
      if (src_p->deref_type == nir_deref_type_array_wildcard)
         break;

      src = nir_build_deref_follower(b, src, src_p);
      src_level++;
   }

   if (src_p == NULL || dst_p == NULL) {
      /* Both paths exhausted: emit the actual copy */
      assert(src_p == NULL && dst_p == NULL);
      nir_copy_deref(b, dst, src);
   } else {
      assert(dst_p->deref_type == nir_deref_type_array_wildcard &&
             src_p->deref_type == nir_deref_type_array_wildcard);

      if ((dst_info && dst_info->levels[dst_level].split) ||
          (src_info && src_info->levels[src_level].split)) {
         /* There are no indirects at this level on one of the source or the
          * destination so we are lowering it.
          */
         assert(glsl_get_length(dst_path->path[dst_level]->type) ==
                glsl_get_length(src_path->path[src_level]->type));
         unsigned len = glsl_get_length(dst_path->path[dst_level]->type);
         for (unsigned i = 0; i < len; i++) {
            emit_split_copies(b, dst_info, dst_path, dst_level + 1,
                              nir_build_deref_array_imm(b, dst, i),
                              src_info, src_path, src_level + 1,
                              nir_build_deref_array_imm(b, src, i));
         }
      } else {
         /* Neither side is being split so we just keep going */
         emit_split_copies(b, dst_info, dst_path, dst_level + 1,
                           nir_build_deref_array_wildcard(b, dst),
                           src_info, src_path, src_level + 1,
                           nir_build_deref_array_wildcard(b, src));
      }
   }
}
714
static void
split_array_copies_impl(nir_function_impl *impl,
                        struct hash_table *var_info_map,
                        nir_variable_mode modes,
                        void *mem_ctx)
{
   /* Replace copy_deref instructions that use a wildcard over a split array
    * level with equivalent unrolled copies (see emit_split_copies).
    */
   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *copy = nir_instr_as_intrinsic(instr);
         if (copy->intrinsic != nir_intrinsic_copy_deref)
            continue;

         nir_deref_instr *dst_deref = nir_src_as_deref(copy->src[0]);
         nir_deref_instr *src_deref = nir_src_as_deref(copy->src[1]);

         struct array_var_info *dst_info =
            get_array_deref_info(dst_deref, var_info_map, modes);
         struct array_var_info *src_info =
            get_array_deref_info(src_deref, var_info_map, modes);

         /* Neither side touches a variable we're splitting */
         if (!src_info && !dst_info)
            continue;

         nir_deref_path dst_path, src_path;
         nir_deref_path_init(&dst_path, dst_deref, mem_ctx);
         nir_deref_path_init(&src_path, src_deref, mem_ctx);

         /* Only copies with a wildcard over a split level need unrolling
          * here; all other accesses are rewritten by
          * split_array_access_impl().
          */
         if (!deref_has_split_wildcard(&dst_path, dst_info) &&
             !deref_has_split_wildcard(&src_path, src_info))
            continue;

         b.cursor = nir_instr_remove(&copy->instr);

         emit_split_copies(&b, dst_info, &dst_path, 0, dst_path.path[0],
                               src_info, &src_path, 0, src_path.path[0]);
      }
   }
}
759
static void
split_array_access_impl(nir_function_impl *impl,
                        struct hash_table *var_info_map,
                        nir_variable_mode modes,
                        void *mem_ctx)
{
   /* Rewrite every load/store/copy deref into a split variable so that it
    * references the appropriate split variable (chosen by the constant
    * indices at split levels).  Out-of-bounds constant accesses are
    * deleted instead.
    */
   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type == nir_instr_type_deref) {
            /* Clean up any dead derefs we find lying around.  They may refer
             * to variables we're planning to split.
             */
            nir_deref_instr *deref = nir_instr_as_deref(instr);
            if (nir_deref_mode_may_be(deref, modes))
               nir_deref_instr_remove_if_unused(deref);
            continue;
         }

         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         if (intrin->intrinsic != nir_intrinsic_load_deref &&
             intrin->intrinsic != nir_intrinsic_store_deref &&
             intrin->intrinsic != nir_intrinsic_copy_deref)
            continue;

         /* copy_deref has two deref sources; load/store have one */
         const unsigned num_derefs =
            intrin->intrinsic == nir_intrinsic_copy_deref ? 2 : 1;

         for (unsigned d = 0; d < num_derefs; d++) {
            nir_deref_instr *deref = nir_src_as_deref(intrin->src[d]);

            struct array_var_info *info =
               get_array_deref_info(deref, var_info_map, modes);
            if (!info)
               continue;

            nir_deref_path path;
            nir_deref_path_init(&path, deref, mem_ctx);

            b.cursor = nir_before_instr(&intrin->instr);

            if (array_path_is_out_of_bounds(&path, info)) {
               /* If one of the derefs is out-of-bounds, we just delete the
                * instruction.  If a destination is out of bounds, then it may
                * have been in-bounds prior to shrinking so we don't want to
                * accidentally stomp something.  However, we've already proven
                * that it will never be read so it's safe to delete.  If a
                * source is out of bounds then it is loading random garbage.
                * For loads, we replace their uses with an undef instruction
                * and for copies we just delete the copy since it was writing
                * undefined garbage anyway and we may as well leave the random
                * garbage in the destination alone.
                */
               if (intrin->intrinsic == nir_intrinsic_load_deref) {
                  nir_ssa_def *u =
                     nir_ssa_undef(&b, intrin->dest.ssa.num_components,
                                       intrin->dest.ssa.bit_size);
                  nir_ssa_def_rewrite_uses(&intrin->dest.ssa,
                                           u);
               }
               nir_instr_remove(&intrin->instr);
               for (unsigned i = 0; i < num_derefs; i++)
                  nir_deref_instr_remove_if_unused(nir_src_as_deref(intrin->src[i]));
               break;
            }

            /* Walk the split tree using the constant indices at the split
             * levels to find the leaf variable for this access.
             */
            struct array_split *split = &info->root_split;
            for (unsigned i = 0; i < info->num_levels; i++) {
               if (info->levels[i].split) {
                  nir_deref_instr *p = path.path[i + 1];
                  unsigned index = nir_src_as_uint(p->arr.index);
                  assert(index < info->levels[i].array_len);
                  split = &split->splits[index];
               }
            }
            assert(!split->splits && split->var);

            /* Rebuild the deref chain, keeping only the non-split levels */
            nir_deref_instr *new_deref = nir_build_deref_var(&b, split->var);
            for (unsigned i = 0; i < info->num_levels; i++) {
               if (!info->levels[i].split) {
                  new_deref = nir_build_deref_follower(&b, new_deref,
                                                       path.path[i + 1]);
               }
            }
            assert(new_deref->type == deref->type);

            /* Rewrite the deref source to point to the split one */
            nir_instr_rewrite_src(&intrin->instr, &intrin->src[d],
                                  nir_src_for_ssa(&new_deref->dest.ssa));
            nir_deref_instr_remove_if_unused(deref);
         }
      }
   }
}
859
860/** A pass for splitting arrays of vectors into multiple variables
861 *
862 * This pass looks at arrays (possibly multiple levels) of vectors (not
863 * structures or other types) and tries to split them into piles of variables,
864 * one for each array element.  The heuristic used is simple: If a given array
865 * level is never used with an indirect, that array level will get split.
866 *
 * This pass probably could handle structures easily enough but making a pass
868 * that could see through an array of structures of arrays would be difficult
869 * so it's best to just run nir_split_struct_vars first.
870 */
871bool
872nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
873{
874   void *mem_ctx = ralloc_context(NULL);
875   struct hash_table *var_info_map = _mesa_pointer_hash_table_create(mem_ctx);
876   struct set *complex_vars = NULL;
877
878   assert((modes & (nir_var_shader_temp | nir_var_function_temp)) == modes);
879
880   bool has_global_array = false;
881   if (modes & nir_var_shader_temp) {
882      has_global_array = init_var_list_array_infos(shader,
883                                                   &shader->variables,
884                                                   nir_var_shader_temp,
885                                                   var_info_map,
886                                                   &complex_vars,
887                                                   mem_ctx);
888   }
889
890   bool has_any_array = false;
891   nir_foreach_function(function, shader) {
892      if (!function->impl)
893         continue;
894
895      bool has_local_array = false;
896      if (modes & nir_var_function_temp) {
897         has_local_array = init_var_list_array_infos(shader,
898                                                     &function->impl->locals,
899                                                     nir_var_function_temp,
900                                                     var_info_map,
901                                                     &complex_vars,
902                                                     mem_ctx);
903      }
904
905      if (has_global_array || has_local_array) {
906         has_any_array = true;
907         mark_array_usage_impl(function->impl, var_info_map, modes, mem_ctx);
908      }
909   }
910
911   /* If we failed to find any arrays of arrays, bail early. */
912   if (!has_any_array) {
913      ralloc_free(mem_ctx);
914      nir_shader_preserve_all_metadata(shader);
915      return false;
916   }
917
918   bool has_global_splits = false;
919   if (modes & nir_var_shader_temp) {
920      has_global_splits = split_var_list_arrays(shader, NULL,
921                                                &shader->variables,
922                                                nir_var_shader_temp,
923                                                var_info_map, mem_ctx);
924   }
925
926   bool progress = false;
927   nir_foreach_function(function, shader) {
928      if (!function->impl)
929         continue;
930
931      bool has_local_splits = false;
932      if (modes & nir_var_function_temp) {
933         has_local_splits = split_var_list_arrays(shader, function->impl,
934                                                  &function->impl->locals,
935                                                  nir_var_function_temp,
936                                                  var_info_map, mem_ctx);
937      }
938
939      if (has_global_splits || has_local_splits) {
940         split_array_copies_impl(function->impl, var_info_map, modes, mem_ctx);
941         split_array_access_impl(function->impl, var_info_map, modes, mem_ctx);
942
943         nir_metadata_preserve(function->impl, nir_metadata_block_index |
944                                               nir_metadata_dominance);
945         progress = true;
946      } else {
947         nir_metadata_preserve(function->impl, nir_metadata_all);
948      }
949   }
950
951   ralloc_free(mem_ctx);
952
953   return progress;
954}
955
/* Per-array-level usage information gathered for the vector/array shrinking
 * pass (nir_shrink_vec_array_vars).  One of these exists for each array
 * dimension of a candidate variable; see vec_var_usage::levels.
 */
struct array_level_usage {
   /* Length of the array at this level; may be shortened by
    * shrink_vec_var_list based on max_read/max_written.
    */
   unsigned array_len;

   /* The value UINT_MAX will be used to indicate an indirect */
   unsigned max_read;
   unsigned max_written;

   /* True if there is a copy that isn't to/from a shrinkable array */
   bool has_external_copy;

   /* Set of array_level_usage levels this level is wildcard-copied to/from
    * (populated by mark_deref_used); NULL if no such copies were seen.
    */
   struct set *levels_copied;
};
967
/* Per-variable usage information for nir_shrink_vec_array_vars.  Allocated
 * with trailing storage for num_levels array_level_usage entries (see the
 * levels[] member below and get_vec_var_usage).
 */
struct vec_var_usage {
   /* Convenience set of all components this variable has */
   nir_component_mask_t all_comps;

   /* Components observed read/written anywhere in the shader */
   nir_component_mask_t comps_read;
   nir_component_mask_t comps_written;

   /* Components we decided to keep; computed in shrink_vec_var_list */
   nir_component_mask_t comps_kept;

   /* True if there is a copy that isn't to/from a shrinkable vector */
   bool has_external_copy;

   /* True if the variable has a use nir_deref_instr_has_complex_use flags */
   bool has_complex_use;

   /* Set of vec_var_usage structs this variable is copied to/from; NULL if
    * no tracked copies were seen.
    */
   struct set *vars_copied;

   /* Number of array levels, i.e. the length of levels[] */
   unsigned num_levels;
   struct array_level_usage levels[0];
};
985
986static struct vec_var_usage *
987get_vec_var_usage(nir_variable *var,
988                  struct hash_table *var_usage_map,
989                  bool add_usage_entry, void *mem_ctx)
990{
991   struct hash_entry *entry = _mesa_hash_table_search(var_usage_map, var);
992   if (entry)
993      return entry->data;
994
995   if (!add_usage_entry)
996      return NULL;
997
998   /* Check to make sure that we are working with an array of vectors.  We
999    * don't bother to shrink single vectors because we figure that we can
1000    * clean it up better with SSA than by inserting piles of vecN instructions
1001    * to compact results.
1002    */
1003   int num_levels = num_array_levels_in_array_of_vector_type(var->type);
1004   if (num_levels < 1)
1005      return NULL; /* Not an array of vectors */
1006
1007   struct vec_var_usage *usage =
1008      rzalloc_size(mem_ctx, sizeof(*usage) +
1009                            num_levels * sizeof(usage->levels[0]));
1010
1011   usage->num_levels = num_levels;
1012   const struct glsl_type *type = var->type;
1013   for (unsigned i = 0; i < num_levels; i++) {
1014      usage->levels[i].array_len = glsl_get_length(type);
1015      type = glsl_get_array_element(type);
1016   }
1017   assert(glsl_type_is_vector_or_scalar(type));
1018
1019   usage->all_comps = (1 << glsl_get_components(type)) - 1;
1020
1021   _mesa_hash_table_insert(var_usage_map, var, usage);
1022
1023   return usage;
1024}
1025
1026static struct vec_var_usage *
1027get_vec_deref_usage(nir_deref_instr *deref,
1028                    struct hash_table *var_usage_map,
1029                    nir_variable_mode modes,
1030                    bool add_usage_entry, void *mem_ctx)
1031{
1032   if (!nir_deref_mode_may_be(deref, modes))
1033      return NULL;
1034
1035   nir_variable *var = nir_deref_instr_get_variable(deref);
1036   if (var == NULL)
1037      return NULL;
1038
1039   return get_vec_var_usage(nir_deref_instr_get_variable(deref),
1040                            var_usage_map, add_usage_entry, mem_ctx);
1041}
1042
1043static void
1044mark_deref_if_complex(nir_deref_instr *deref,
1045                      struct hash_table *var_usage_map,
1046                      nir_variable_mode modes,
1047                      void *mem_ctx)
1048{
1049   /* Only bother with var derefs because nir_deref_instr_has_complex_use is
1050    * recursive.
1051    */
1052   if (deref->deref_type != nir_deref_type_var)
1053      return;
1054
1055   if (!(deref->var->data.mode & modes))
1056      return;
1057
1058   if (!nir_deref_instr_has_complex_use(deref))
1059      return;
1060
1061   struct vec_var_usage *usage =
1062      get_vec_var_usage(deref->var, var_usage_map, true, mem_ctx);
1063   if (!usage)
1064      return;
1065
1066   usage->has_complex_use = true;
1067}
1068
/* Records that the given components of the variable behind `deref` were
 * read and/or written.
 *
 * If the access comes from a copy_deref, `copy_deref` is the deref on the
 * other side of the copy (NULL otherwise); the two variables' usage records
 * are then linked so shrink_vec_var_list can keep their shrunk types
 * compatible.  Per-array-level max indices are updated as well.
 */
static void
mark_deref_used(nir_deref_instr *deref,
                nir_component_mask_t comps_read,
                nir_component_mask_t comps_written,
                nir_deref_instr *copy_deref,
                struct hash_table *var_usage_map,
                nir_variable_mode modes,
                void *mem_ctx)
{
   /* Ignore derefs that cannot be in a mode we're shrinking */
   if (!nir_deref_mode_may_be(deref, modes))
      return;

   nir_variable *var = nir_deref_instr_get_variable(deref);
   if (var == NULL)
      return;

   struct vec_var_usage *usage =
      get_vec_var_usage(var, var_usage_map, true, mem_ctx);
   if (!usage)
      return;

   /* Mask off components the variable doesn't actually have */
   usage->comps_read |= comps_read & usage->all_comps;
   usage->comps_written |= comps_written & usage->all_comps;

   struct vec_var_usage *copy_usage = NULL;
   if (copy_deref) {
      copy_usage = get_vec_deref_usage(copy_deref, var_usage_map, modes,
                                       true, mem_ctx);
      if (copy_usage) {
         /* Link the two usage records so the fixed-point pass in
          * shrink_vec_var_list can keep their comps_kept in sync.
          */
         if (usage->vars_copied == NULL) {
            usage->vars_copied = _mesa_pointer_set_create(mem_ctx);
         }
         _mesa_set_add(usage->vars_copied, copy_usage);
      } else {
         /* Copy to/from something we aren't tracking; can't shrink. */
         usage->has_external_copy = true;
      }
   }

   nir_deref_path path;
   nir_deref_path_init(&path, deref, mem_ctx);

   nir_deref_path copy_path;
   if (copy_usage)
      nir_deref_path_init(&copy_path, copy_deref, mem_ctx);

   /* Walk each array level of the deref and update per-level usage */
   unsigned copy_i = 0;
   for (unsigned i = 0; i < usage->num_levels; i++) {
      struct array_level_usage *level = &usage->levels[i];
      /* NOTE: shadows the `deref` parameter; this is the deref at level i */
      nir_deref_instr *deref = path.path[i + 1];
      assert(deref->deref_type == nir_deref_type_array ||
             deref->deref_type == nir_deref_type_array_wildcard);

      unsigned max_used;
      if (deref->deref_type == nir_deref_type_array) {
         /* UINT_MAX marks an indirect (non-constant) index */
         max_used = nir_src_is_const(deref->arr.index) ?
                    nir_src_as_uint(deref->arr.index) : UINT_MAX;
      } else {
         /* For wildcards, we read or wrote the whole thing. */
         assert(deref->deref_type == nir_deref_type_array_wildcard);
         max_used = level->array_len - 1;

         if (copy_usage) {
            /* Match each wildcard level with the level on copy_usage */
            for (; copy_path.path[copy_i + 1]; copy_i++) {
               if (copy_path.path[copy_i + 1]->deref_type ==
                   nir_deref_type_array_wildcard)
                  break;
            }
            struct array_level_usage *copy_level =
               &copy_usage->levels[copy_i++];

            if (level->levels_copied == NULL) {
               level->levels_copied = _mesa_pointer_set_create(mem_ctx);
            }
            _mesa_set_add(level->levels_copied, copy_level);
         } else {
            /* We have a wildcard and it comes from a variable we aren't
             * tracking; flag it and we'll know to not shorten this array.
             */
            level->has_external_copy = true;
         }
      }

      /* Track the largest index written/read at this level */
      if (comps_written)
         level->max_written = MAX2(level->max_written, max_used);
      if (comps_read)
         level->max_read = MAX2(level->max_read, max_used);
   }
}
1158
1159static bool
1160src_is_load_deref(nir_src src, nir_src deref_src)
1161{
1162   nir_intrinsic_instr *load = nir_src_as_intrinsic(src);
1163   if (load == NULL || load->intrinsic != nir_intrinsic_load_deref)
1164      return false;
1165
1166   assert(load->src[0].is_ssa);
1167
1168   return load->src[0].ssa == deref_src.ssa;
1169}
1170
1171/* Returns all non-self-referential components of a store instruction.  A
1172 * component is self-referential if it comes from the same component of a load
1173 * instruction on the same deref.  If the only data in a particular component
1174 * of a variable came directly from that component then it's undefined.  The
1175 * only way to get defined data into a component of a variable is for it to
1176 * get written there by something outside or from a different component.
1177 *
1178 * This is a fairly common pattern in shaders that come from either GLSL IR or
1179 * GLSLang because both glsl_to_nir and GLSLang implement write-masking with
1180 * load-vec-store.
1181 */
1182static nir_component_mask_t
1183get_non_self_referential_store_comps(nir_intrinsic_instr *store)
1184{
1185   nir_component_mask_t comps = nir_intrinsic_write_mask(store);
1186
1187   assert(store->src[1].is_ssa);
1188   nir_instr *src_instr = store->src[1].ssa->parent_instr;
1189   if (src_instr->type != nir_instr_type_alu)
1190      return comps;
1191
1192   nir_alu_instr *src_alu = nir_instr_as_alu(src_instr);
1193
1194   if (src_alu->op == nir_op_mov) {
1195      /* If it's just a swizzle of a load from the same deref, discount any
1196       * channels that don't move in the swizzle.
1197       */
1198      if (src_is_load_deref(src_alu->src[0].src, store->src[0])) {
1199         for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
1200            if (src_alu->src[0].swizzle[i] == i)
1201               comps &= ~(1u << i);
1202         }
1203      }
1204   } else if (nir_op_is_vec(src_alu->op)) {
1205      /* If it's a vec, discount any channels that are just loads from the
1206       * same deref put in the same spot.
1207       */
1208      for (unsigned i = 0; i < nir_op_infos[src_alu->op].num_inputs; i++) {
1209         if (src_is_load_deref(src_alu->src[i].src, store->src[0]) &&
1210             src_alu->src[i].swizzle[0] == i)
1211            comps &= ~(1u << i);
1212      }
1213   }
1214
1215   return comps;
1216}
1217
1218static void
1219find_used_components_impl(nir_function_impl *impl,
1220                          struct hash_table *var_usage_map,
1221                          nir_variable_mode modes,
1222                          void *mem_ctx)
1223{
1224   nir_foreach_block(block, impl) {
1225      nir_foreach_instr(instr, block) {
1226         if (instr->type == nir_instr_type_deref) {
1227            mark_deref_if_complex(nir_instr_as_deref(instr),
1228                                  var_usage_map, modes, mem_ctx);
1229         }
1230
1231         if (instr->type != nir_instr_type_intrinsic)
1232            continue;
1233
1234         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1235         switch (intrin->intrinsic) {
1236         case nir_intrinsic_load_deref:
1237            mark_deref_used(nir_src_as_deref(intrin->src[0]),
1238                            nir_ssa_def_components_read(&intrin->dest.ssa), 0,
1239                            NULL, var_usage_map, modes, mem_ctx);
1240            break;
1241
1242         case nir_intrinsic_store_deref:
1243            mark_deref_used(nir_src_as_deref(intrin->src[0]),
1244                            0, get_non_self_referential_store_comps(intrin),
1245                            NULL, var_usage_map, modes, mem_ctx);
1246            break;
1247
1248         case nir_intrinsic_copy_deref: {
1249            /* Just mark everything used for copies. */
1250            nir_deref_instr *dst = nir_src_as_deref(intrin->src[0]);
1251            nir_deref_instr *src = nir_src_as_deref(intrin->src[1]);
1252            mark_deref_used(dst, 0, ~0, src, var_usage_map, modes, mem_ctx);
1253            mark_deref_used(src, ~0, 0, dst, var_usage_map, modes, mem_ctx);
1254            break;
1255         }
1256
1257         default:
1258            break;
1259         }
1260      }
1261   }
1262}
1263
/* Shrinks the types of all shrinkable variables in `vars` with mode `mode`
 * based on the usage recorded in var_usage_map.
 *
 * Dead variables (no component both read and written) are removed from the
 * list entirely.  Variables needing no change are dropped from
 * var_usage_map so that later rewrite passes ignore them.  Returns true if
 * any variable was removed or re-typed.
 */
static bool
shrink_vec_var_list(struct exec_list *vars,
                    nir_variable_mode mode,
                    struct hash_table *var_usage_map)
{
   /* Initialize the components kept field of each variable.  This is the
    * AND of the components written and components read.  If a component is
    * written but never read, it's dead.  If it is read but never written,
    * then all values read are undefined garbage and we may as well not read
    * them.
    *
    * The same logic applies to the array length.  We make the array length
    * the minimum needed required length between read and write and plan to
    * discard any OOB access.  The one exception here is indirect writes
    * because we don't know where they will land and we can't shrink an array
    * with indirect writes because previously in-bounds writes may become
    * out-of-bounds and have undefined behavior.
    *
    * Also, if we have a copy that to/from something we can't shrink, we need
    * to leave components and array_len of any wildcards alone.
    */
   nir_foreach_variable_in_list(var, vars) {
      if (var->data.mode != mode)
         continue;

      struct vec_var_usage *usage =
         get_vec_var_usage(var, var_usage_map, false, NULL);
      if (!usage)
         continue;

      assert(usage->comps_kept == 0);
      if (usage->has_external_copy || usage->has_complex_use)
         usage->comps_kept = usage->all_comps;
      else
         usage->comps_kept = usage->comps_read & usage->comps_written;

      for (unsigned i = 0; i < usage->num_levels; i++) {
         struct array_level_usage *level = &usage->levels[i];
         assert(level->array_len > 0);

         if (level->max_written == UINT_MAX || level->has_external_copy ||
             usage->has_complex_use)
            continue; /* Can't shrink */

         /* Clamp the new length to at most the declared length */
         unsigned max_used = MIN2(level->max_read, level->max_written);
         level->array_len = MIN2(max_used, level->array_len - 1) + 1;
      }
   }

   /* In order for variable copies to work, we have to have the same data type
    * on the source and the destination.  In order to satisfy this, we run a
    * little fixed-point algorithm to transitively ensure that we get enough
    * components and array elements for this to hold for all copies.
    */
   bool fp_progress;
   do {
      fp_progress = false;
      nir_foreach_variable_in_list(var, vars) {
         if (var->data.mode != mode)
            continue;

         struct vec_var_usage *var_usage =
            get_vec_var_usage(var, var_usage_map, false, NULL);
         if (!var_usage || !var_usage->vars_copied)
            continue;

         /* Union comps_kept across all copy partners */
         set_foreach(var_usage->vars_copied, copy_entry) {
            struct vec_var_usage *copy_usage = (void *)copy_entry->key;
            if (copy_usage->comps_kept != var_usage->comps_kept) {
               nir_component_mask_t comps_kept =
                  (var_usage->comps_kept | copy_usage->comps_kept);
               var_usage->comps_kept = comps_kept;
               copy_usage->comps_kept = comps_kept;
               fp_progress = true;
            }
         }

         /* Take the max array_len across all wildcard-copied levels */
         for (unsigned i = 0; i < var_usage->num_levels; i++) {
            struct array_level_usage *var_level = &var_usage->levels[i];
            if (!var_level->levels_copied)
               continue;

            set_foreach(var_level->levels_copied, copy_entry) {
               struct array_level_usage *copy_level = (void *)copy_entry->key;
               if (var_level->array_len != copy_level->array_len) {
                  unsigned array_len =
                     MAX2(var_level->array_len, copy_level->array_len);
                  var_level->array_len = array_len;
                  copy_level->array_len = array_len;
                  fp_progress = true;
               }
            }
         }
      }
   } while (fp_progress);

   /* Finally, apply the computed shrinkage to the variable declarations */
   bool vars_shrunk = false;
   nir_foreach_variable_in_list_safe(var, vars) {
      if (var->data.mode != mode)
         continue;

      struct vec_var_usage *usage =
         get_vec_var_usage(var, var_usage_map, false, NULL);
      if (!usage)
         continue;

      bool shrunk = false;
      const struct glsl_type *vec_type = var->type;
      for (unsigned i = 0; i < usage->num_levels; i++) {
         /* If we've reduced the array to zero elements at some level, just
          * set comps_kept to 0 and delete the variable.
          */
         if (usage->levels[i].array_len == 0) {
            usage->comps_kept = 0;
            break;
         }

         assert(usage->levels[i].array_len <= glsl_get_length(vec_type));
         if (usage->levels[i].array_len < glsl_get_length(vec_type))
            shrunk = true;
         vec_type = glsl_get_array_element(vec_type);
      }
      assert(glsl_type_is_vector_or_scalar(vec_type));

      assert(usage->comps_kept == (usage->comps_kept & usage->all_comps));
      if (usage->comps_kept != usage->all_comps)
         shrunk = true;

      if (usage->comps_kept == 0) {
         /* This variable is dead, remove it */
         vars_shrunk = true;
         exec_node_remove(&var->node);
         continue;
      }

      if (!shrunk) {
         /* This variable doesn't need to be shrunk.  Remove it from the
          * hash table so later steps will ignore it.
          */
         _mesa_hash_table_remove_key(var_usage_map, var);
         continue;
      }

      /* Build the new var type */
      unsigned new_num_comps = util_bitcount(usage->comps_kept);
      const struct glsl_type *new_type =
         glsl_vector_type(glsl_get_base_type(vec_type), new_num_comps);
      for (int i = usage->num_levels - 1; i >= 0; i--) {
         assert(usage->levels[i].array_len > 0);
         /* If the original type was a matrix type, we'd like to keep that so
          * we don't convert matrices into arrays.
          */
         if (i == usage->num_levels - 1 &&
             glsl_type_is_matrix(glsl_without_array(var->type)) &&
             new_num_comps > 1 && usage->levels[i].array_len > 1) {
            new_type = glsl_matrix_type(glsl_get_base_type(new_type),
                                        new_num_comps,
                                        usage->levels[i].array_len);
         } else {
            new_type = glsl_array_type(new_type, usage->levels[i].array_len, 0);
         }
      }
      var->type = new_type;

      vars_shrunk = true;
   }

   return vars_shrunk;
}
1433
1434static bool
1435vec_deref_is_oob(nir_deref_instr *deref,
1436                 struct vec_var_usage *usage)
1437{
1438   nir_deref_path path;
1439   nir_deref_path_init(&path, deref, NULL);
1440
1441   bool oob = false;
1442   for (unsigned i = 0; i < usage->num_levels; i++) {
1443      nir_deref_instr *p = path.path[i + 1];
1444      if (p->deref_type == nir_deref_type_array_wildcard)
1445         continue;
1446
1447      if (nir_src_is_const(p->arr.index) &&
1448          nir_src_as_uint(p->arr.index) >= usage->levels[i].array_len) {
1449         oob = true;
1450         break;
1451      }
1452   }
1453
1454   nir_deref_path_finish(&path);
1455
1456   return oob;
1457}
1458
1459static bool
1460vec_deref_is_dead_or_oob(nir_deref_instr *deref,
1461                         struct hash_table *var_usage_map,
1462                         nir_variable_mode modes)
1463{
1464   struct vec_var_usage *usage =
1465      get_vec_deref_usage(deref, var_usage_map, modes, false, NULL);
1466   if (!usage)
1467      return false;
1468
1469   return usage->comps_kept == 0 || vec_deref_is_oob(deref, usage);
1470}
1471
/* Rewrites all derefs, loads, stores, and copies in `impl` to match the
 * shrunk variable types.
 *
 * Accesses through dead or provably out-of-bounds derefs are deleted (loads
 * are first replaced with undefs).  For vectors with dropped components,
 * loads are re-expanded with a vecN of their kept channels and stores are
 * compacted with a swizzle and an updated write mask.
 */
static void
shrink_vec_var_access_impl(nir_function_impl *impl,
                           struct hash_table *var_usage_map,
                           nir_variable_mode modes)
{
   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         switch (instr->type) {
         case nir_instr_type_deref: {
            nir_deref_instr *deref = nir_instr_as_deref(instr);
            if (!nir_deref_mode_may_be(deref, modes))
               break;

            /* Clean up any dead derefs we find lying around.  They may refer
             * to variables we've deleted.
             */
            if (nir_deref_instr_remove_if_unused(deref))
               break;

            /* Update the type in the deref to keep the types consistent as
             * you walk down the chain.  We don't need to check if this is one
             * of the derefs we're shrinking because this is a no-op if it
             * isn't.  The worst that could happen is that we accidentally fix
             * an invalid deref.
             */
            if (deref->deref_type == nir_deref_type_var) {
               deref->type = deref->var->type;
            } else if (deref->deref_type == nir_deref_type_array ||
                       deref->deref_type == nir_deref_type_array_wildcard) {
               nir_deref_instr *parent = nir_deref_instr_parent(deref);
               assert(glsl_type_is_array(parent->type) ||
                      glsl_type_is_matrix(parent->type));
               deref->type = glsl_get_array_element(parent->type);
            }
            break;
         }

         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

            /* If we have a copy whose source or destination has been deleted
             * because we determined the variable was dead, then we just
             * delete the copy instruction.  If the source variable was dead
             * then it was writing undefined garbage anyway and if it's the
             * destination variable that's dead then the write isn't needed.
             */
            if (intrin->intrinsic == nir_intrinsic_copy_deref) {
               nir_deref_instr *dst = nir_src_as_deref(intrin->src[0]);
               nir_deref_instr *src = nir_src_as_deref(intrin->src[1]);
               if (vec_deref_is_dead_or_oob(dst, var_usage_map, modes) ||
                   vec_deref_is_dead_or_oob(src, var_usage_map, modes)) {
                  nir_instr_remove(&intrin->instr);
                  nir_deref_instr_remove_if_unused(dst);
                  nir_deref_instr_remove_if_unused(src);
               }
               continue;
            }

            if (intrin->intrinsic != nir_intrinsic_load_deref &&
                intrin->intrinsic != nir_intrinsic_store_deref)
               continue;

            nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
            if (!nir_deref_mode_may_be(deref, modes))
               continue;

            struct vec_var_usage *usage =
               get_vec_deref_usage(deref, var_usage_map, modes, false, NULL);
            if (!usage)
               continue;

            /* Dead or out-of-bounds access: loads become undef, stores are
             * simply dropped.
             */
            if (usage->comps_kept == 0 || vec_deref_is_oob(deref, usage)) {
               if (intrin->intrinsic == nir_intrinsic_load_deref) {
                  nir_ssa_def *u =
                     nir_ssa_undef(&b, intrin->dest.ssa.num_components,
                                       intrin->dest.ssa.bit_size);
                  nir_ssa_def_rewrite_uses(&intrin->dest.ssa,
                                           u);
               }
               nir_instr_remove(&intrin->instr);
               nir_deref_instr_remove_if_unused(deref);
               continue;
            }

            /* If we're not dropping any components, there's no need to
             * compact vectors.
             */
            if (usage->comps_kept == usage->all_comps)
               continue;

            if (intrin->intrinsic == nir_intrinsic_load_deref) {
               b.cursor = nir_after_instr(&intrin->instr);

               /* Re-expand the packed load back to its original width:
                * kept channels come from the load, dropped ones are undef.
                */
               nir_ssa_def *undef =
                  nir_ssa_undef(&b, 1, intrin->dest.ssa.bit_size);
               nir_ssa_def *vec_srcs[NIR_MAX_VEC_COMPONENTS];
               unsigned c = 0;
               for (unsigned i = 0; i < intrin->num_components; i++) {
                  if (usage->comps_kept & (1u << i))
                     vec_srcs[i] = nir_channel(&b, &intrin->dest.ssa, c++);
                  else
                     vec_srcs[i] = undef;
               }
               nir_ssa_def *vec = nir_vec(&b, vec_srcs, intrin->num_components);

               nir_ssa_def_rewrite_uses_after(&intrin->dest.ssa,
                                              vec,
                                              vec->parent_instr);

               /* The SSA def is now only used by the swizzle.  It's safe to
                * shrink the number of components.
                */
               assert(list_length(&intrin->dest.ssa.uses) == c);
               intrin->num_components = c;
               intrin->dest.ssa.num_components = c;
            } else {
               /* Compact the stored value down to just the kept channels
                * and rebuild the write mask accordingly.
                */
               nir_component_mask_t write_mask =
                  nir_intrinsic_write_mask(intrin);

               unsigned swizzle[NIR_MAX_VEC_COMPONENTS];
               nir_component_mask_t new_write_mask = 0;
               unsigned c = 0;
               for (unsigned i = 0; i < intrin->num_components; i++) {
                  if (usage->comps_kept & (1u << i)) {
                     swizzle[c] = i;
                     if (write_mask & (1u << i))
                        new_write_mask |= 1u << c;
                     c++;
                  }
               }

               b.cursor = nir_before_instr(&intrin->instr);

               nir_ssa_def *swizzled =
                  nir_swizzle(&b, intrin->src[1].ssa, swizzle, c);

               /* Rewrite to use the compacted source */
               nir_instr_rewrite_src(&intrin->instr, &intrin->src[1],
                                     nir_src_for_ssa(swizzled));
               nir_intrinsic_set_write_mask(intrin, new_write_mask);
               intrin->num_components = c;
            }
            break;
         }

         default:
            break;
         }
      }
   }
}
1626
1627static bool
1628function_impl_has_vars_with_modes(nir_function_impl *impl,
1629                                  nir_variable_mode modes)
1630{
1631   nir_shader *shader = impl->function->shader;
1632
1633   if (modes & ~nir_var_function_temp) {
1634      nir_foreach_variable_with_modes(var, shader,
1635                                      modes & ~nir_var_function_temp)
1636         return true;
1637   }
1638
1639   if ((modes & nir_var_function_temp) && !exec_list_is_empty(&impl->locals))
1640      return true;
1641
1642   return false;
1643}
1644
1645/** Attempt to shrink arrays of vectors
1646 *
1647 * This pass looks at variables which contain a vector or an array (possibly
1648 * multiple dimensions) of vectors and attempts to lower to a smaller vector
1649 * or array.  If the pass can prove that a component of a vector (or array of
1650 * vectors) is never really used, then that component will be removed.
1651 * Similarly, the pass attempts to shorten arrays based on what elements it
1652 * can prove are never read or never contain valid data.
1653 */
1654bool
1655nir_shrink_vec_array_vars(nir_shader *shader, nir_variable_mode modes)
1656{
1657   assert((modes & (nir_var_shader_temp | nir_var_function_temp)) == modes);
1658
1659   void *mem_ctx = ralloc_context(NULL);
1660
1661   struct hash_table *var_usage_map =
1662      _mesa_pointer_hash_table_create(mem_ctx);
1663
1664   bool has_vars_to_shrink = false;
1665   nir_foreach_function(function, shader) {
1666      if (!function->impl)
1667         continue;
1668
1669      /* Don't even bother crawling the IR if we don't have any variables.
1670       * Given that this pass deletes any unused variables, it's likely that
1671       * we will be in this scenario eventually.
1672       */
1673      if (function_impl_has_vars_with_modes(function->impl, modes)) {
1674         has_vars_to_shrink = true;
1675         find_used_components_impl(function->impl, var_usage_map,
1676                                   modes, mem_ctx);
1677      }
1678   }
1679   if (!has_vars_to_shrink) {
1680      ralloc_free(mem_ctx);
1681      nir_shader_preserve_all_metadata(shader);
1682      return false;
1683   }
1684
1685   bool globals_shrunk = false;
1686   if (modes & nir_var_shader_temp) {
1687      globals_shrunk = shrink_vec_var_list(&shader->variables,
1688                                           nir_var_shader_temp,
1689                                           var_usage_map);
1690   }
1691
1692   bool progress = false;
1693   nir_foreach_function(function, shader) {
1694      if (!function->impl)
1695         continue;
1696
1697      bool locals_shrunk = false;
1698      if (modes & nir_var_function_temp) {
1699         locals_shrunk = shrink_vec_var_list(&function->impl->locals,
1700                                             nir_var_function_temp,
1701                                             var_usage_map);
1702      }
1703
1704      if (globals_shrunk || locals_shrunk) {
1705         shrink_vec_var_access_impl(function->impl, var_usage_map, modes);
1706
1707         nir_metadata_preserve(function->impl, nir_metadata_block_index |
1708                                               nir_metadata_dominance);
1709         progress = true;
1710      } else {
1711         nir_metadata_preserve(function->impl, nir_metadata_all);
1712      }
1713   }
1714
1715   ralloc_free(mem_ctx);
1716
1717   return progress;
1718}
1719