1/*
2 * Copyright © 2016 Broadcom
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24#include "nir.h"
25#include "nir_builder.h"
26#include "nir_deref.h"
27
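/** @file nir_lower_io_to_scalar.c
 *
 * Replaces shader I/O operations that have num_components != 1 with
 * individual per-channel operations.
 *
 * nir_lower_io_to_scalar() splits nir_intrinsic_load_input and
 * nir_intrinsic_store_output intrinsics.  nir_lower_io_to_scalar_early(),
 * meant to run before nir_lower_io(), instead splits load_deref, store_deref
 * and interp_deref_at_* accesses to shader in/out variables, replacing each
 * split variable with one variable per channel.
 */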
33
34static void
35lower_load_input_to_scalar(nir_builder *b, nir_intrinsic_instr *intr)
36{
37   b->cursor = nir_before_instr(&intr->instr);
38
39   assert(intr->dest.is_ssa);
40
41   nir_ssa_def *loads[NIR_MAX_VEC_COMPONENTS];
42
43   for (unsigned i = 0; i < intr->num_components; i++) {
44      nir_intrinsic_instr *chan_intr =
45         nir_intrinsic_instr_create(b->shader, intr->intrinsic);
46      nir_ssa_dest_init(&chan_intr->instr, &chan_intr->dest,
47                        1, intr->dest.ssa.bit_size, NULL);
48      chan_intr->num_components = 1;
49
50      nir_intrinsic_set_base(chan_intr, nir_intrinsic_base(intr));
51      nir_intrinsic_set_component(chan_intr, nir_intrinsic_component(intr) + i);
52      nir_intrinsic_set_dest_type(chan_intr, nir_intrinsic_dest_type(intr));
53      nir_intrinsic_set_io_semantics(chan_intr, nir_intrinsic_io_semantics(intr));
54      /* offset */
55      nir_src_copy(&chan_intr->src[0], &intr->src[0]);
56
57      nir_builder_instr_insert(b, &chan_intr->instr);
58
59      loads[i] = &chan_intr->dest.ssa;
60   }
61
62   nir_ssa_def_rewrite_uses(&intr->dest.ssa,
63                            nir_vec(b, loads, intr->num_components));
64   nir_instr_remove(&intr->instr);
65}
66
67static void
68lower_store_output_to_scalar(nir_builder *b, nir_intrinsic_instr *intr)
69{
70   b->cursor = nir_before_instr(&intr->instr);
71
72   nir_ssa_def *value = nir_ssa_for_src(b, intr->src[0], intr->num_components);
73
74   for (unsigned i = 0; i < intr->num_components; i++) {
75      if (!(nir_intrinsic_write_mask(intr) & (1 << i)))
76         continue;
77
78      nir_intrinsic_instr *chan_intr =
79         nir_intrinsic_instr_create(b->shader, intr->intrinsic);
80      chan_intr->num_components = 1;
81
82      nir_intrinsic_set_base(chan_intr, nir_intrinsic_base(intr));
83      nir_intrinsic_set_write_mask(chan_intr, 0x1);
84      nir_intrinsic_set_component(chan_intr, nir_intrinsic_component(intr) + i);
85      nir_intrinsic_set_src_type(chan_intr, nir_intrinsic_src_type(intr));
86      nir_intrinsic_set_io_semantics(chan_intr, nir_intrinsic_io_semantics(intr));
87
88      /* value */
89      chan_intr->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
90      /* offset */
91      nir_src_copy(&chan_intr->src[1], &intr->src[1]);
92
93      nir_builder_instr_insert(b, &chan_intr->instr);
94   }
95
96   nir_instr_remove(&intr->instr);
97}
98
99static bool
100nir_lower_io_to_scalar_instr(nir_builder *b, nir_instr *instr, void *data)
101{
102   nir_variable_mode mask = *(nir_variable_mode *)data;
103
104   if (instr->type != nir_instr_type_intrinsic)
105      return false;
106
107   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
108
109   if (intr->num_components == 1)
110      return false;
111
112   if (intr->intrinsic == nir_intrinsic_load_input &&
113       (mask & nir_var_shader_in)) {
114      lower_load_input_to_scalar(b, intr);
115      return true;
116   }
117
118   if (intr->intrinsic == nir_intrinsic_store_output &&
119       mask & nir_var_shader_out) {
120      lower_store_output_to_scalar(b, intr);
121      return true;
122   }
123
124   return false;
125}
126
127void
128nir_lower_io_to_scalar(nir_shader *shader, nir_variable_mode mask)
129{
130   nir_shader_instructions_pass(shader,
131                                nir_lower_io_to_scalar_instr,
132                                nir_metadata_block_index |
133                                nir_metadata_dominance,
134                                &mask);
135}
136
137static nir_variable **
138get_channel_variables(struct hash_table *ht, nir_variable *var)
139{
140   nir_variable **chan_vars;
141   struct hash_entry *entry = _mesa_hash_table_search(ht, var);
142   if (!entry) {
143      chan_vars = (nir_variable **) calloc(4, sizeof(nir_variable *));
144      _mesa_hash_table_insert(ht, var, chan_vars);
145   } else {
146      chan_vars = (nir_variable **) entry->data;
147   }
148
149   return chan_vars;
150}
151
152/*
153 * Note that the src deref that we are cloning is the head of the
154 * chain of deref instructions from the original intrinsic, but
155 * the dst we are cloning to is the tail (because chains of deref
156 * instructions are created back to front)
157 */
158
159static nir_deref_instr *
160clone_deref_array(nir_builder *b, nir_deref_instr *dst_tail,
161                  const nir_deref_instr *src_head)
162{
163   const nir_deref_instr *parent = nir_deref_instr_parent(src_head);
164
165   if (!parent)
166      return dst_tail;
167
168   assert(src_head->deref_type == nir_deref_type_array);
169
170   dst_tail = clone_deref_array(b, dst_tail, parent);
171
172   return nir_build_deref_array(b, dst_tail,
173                                nir_ssa_for_src(b, src_head->arr.index, 1));
174}
175
176static void
177lower_load_to_scalar_early(nir_builder *b, nir_intrinsic_instr *intr,
178                           nir_variable *var, struct hash_table *split_inputs,
179                           struct hash_table *split_outputs)
180{
181   b->cursor = nir_before_instr(&intr->instr);
182
183   assert(intr->dest.is_ssa);
184
185   nir_ssa_def *loads[NIR_MAX_VEC_COMPONENTS];
186
187   nir_variable **chan_vars;
188   if (var->data.mode == nir_var_shader_in) {
189      chan_vars = get_channel_variables(split_inputs, var);
190   } else {
191      chan_vars = get_channel_variables(split_outputs, var);
192   }
193
194   for (unsigned i = 0; i < intr->num_components; i++) {
195      nir_variable *chan_var = chan_vars[var->data.location_frac + i];
196      if (!chan_vars[var->data.location_frac + i]) {
197         chan_var = nir_variable_clone(var, b->shader);
198         chan_var->data.location_frac =  var->data.location_frac + i;
199         chan_var->type = glsl_channel_type(chan_var->type);
200         if (var->data.explicit_offset) {
201            unsigned comp_size = glsl_get_bit_size(chan_var->type) / 8;
202            chan_var->data.offset = var->data.offset + i * comp_size;
203         }
204
205         chan_vars[var->data.location_frac + i] = chan_var;
206
207         nir_shader_add_variable(b->shader, chan_var);
208      }
209
210      nir_intrinsic_instr *chan_intr =
211         nir_intrinsic_instr_create(b->shader, intr->intrinsic);
212      nir_ssa_dest_init(&chan_intr->instr, &chan_intr->dest,
213                        1, intr->dest.ssa.bit_size, NULL);
214      chan_intr->num_components = 1;
215
216      nir_deref_instr *deref = nir_build_deref_var(b, chan_var);
217
218      deref = clone_deref_array(b, deref, nir_src_as_deref(intr->src[0]));
219
220      chan_intr->src[0] = nir_src_for_ssa(&deref->dest.ssa);
221
222      if (intr->intrinsic == nir_intrinsic_interp_deref_at_offset ||
223          intr->intrinsic == nir_intrinsic_interp_deref_at_sample ||
224          intr->intrinsic == nir_intrinsic_interp_deref_at_vertex)
225         nir_src_copy(&chan_intr->src[1], &intr->src[1]);
226
227      nir_builder_instr_insert(b, &chan_intr->instr);
228
229      loads[i] = &chan_intr->dest.ssa;
230   }
231
232   nir_ssa_def_rewrite_uses(&intr->dest.ssa,
233                            nir_vec(b, loads, intr->num_components));
234
235   /* Remove the old load intrinsic */
236   nir_instr_remove(&intr->instr);
237}
238
239static void
240lower_store_output_to_scalar_early(nir_builder *b, nir_intrinsic_instr *intr,
241                                   nir_variable *var,
242                                   struct hash_table *split_outputs)
243{
244   b->cursor = nir_before_instr(&intr->instr);
245
246   nir_ssa_def *value = nir_ssa_for_src(b, intr->src[1], intr->num_components);
247
248   nir_variable **chan_vars = get_channel_variables(split_outputs, var);
249   for (unsigned i = 0; i < intr->num_components; i++) {
250      if (!(nir_intrinsic_write_mask(intr) & (1 << i)))
251         continue;
252
253      nir_variable *chan_var = chan_vars[var->data.location_frac + i];
254      if (!chan_vars[var->data.location_frac + i]) {
255         chan_var = nir_variable_clone(var, b->shader);
256         chan_var->data.location_frac =  var->data.location_frac + i;
257         chan_var->type = glsl_channel_type(chan_var->type);
258         if (var->data.explicit_offset) {
259            unsigned comp_size = glsl_get_bit_size(chan_var->type) / 8;
260            chan_var->data.offset = var->data.offset + i * comp_size;
261         }
262
263         chan_vars[var->data.location_frac + i] = chan_var;
264
265         nir_shader_add_variable(b->shader, chan_var);
266      }
267
268      nir_intrinsic_instr *chan_intr =
269         nir_intrinsic_instr_create(b->shader, intr->intrinsic);
270      chan_intr->num_components = 1;
271
272      nir_intrinsic_set_write_mask(chan_intr, 0x1);
273
274      nir_deref_instr *deref = nir_build_deref_var(b, chan_var);
275
276      deref = clone_deref_array(b, deref, nir_src_as_deref(intr->src[0]));
277
278      chan_intr->src[0] = nir_src_for_ssa(&deref->dest.ssa);
279      chan_intr->src[1] = nir_src_for_ssa(nir_channel(b, value, i));
280
281      nir_builder_instr_insert(b, &chan_intr->instr);
282   }
283
284   /* Remove the old store intrinsic */
285   nir_instr_remove(&intr->instr);
286}
287
288struct io_to_scalar_early_state {
289   struct hash_table *split_inputs, *split_outputs;
290   nir_variable_mode mask;
291};
292
293static bool
294nir_lower_io_to_scalar_early_instr(nir_builder *b, nir_instr *instr, void *data)
295{
296   struct io_to_scalar_early_state *state = data;
297
298   if (instr->type != nir_instr_type_intrinsic)
299      return false;
300
301   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
302
303   if (intr->num_components == 1)
304      return false;
305
306   if (intr->intrinsic != nir_intrinsic_load_deref &&
307       intr->intrinsic != nir_intrinsic_store_deref &&
308       intr->intrinsic != nir_intrinsic_interp_deref_at_centroid &&
309       intr->intrinsic != nir_intrinsic_interp_deref_at_sample &&
310       intr->intrinsic != nir_intrinsic_interp_deref_at_offset &&
311       intr->intrinsic != nir_intrinsic_interp_deref_at_vertex)
312      return false;
313
314   nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
315   if (!nir_deref_mode_is_one_of(deref, state->mask))
316      return false;
317
318   nir_variable *var = nir_deref_instr_get_variable(deref);
319   nir_variable_mode mode = var->data.mode;
320
321   /* TODO: add patch support */
322   if (var->data.patch)
323      return false;
324
325   /* TODO: add doubles support */
326   if (glsl_type_is_64bit(glsl_without_array(var->type)))
327      return false;
328
329   if (!(b->shader->info.stage == MESA_SHADER_VERTEX &&
330         mode == nir_var_shader_in) &&
331       var->data.location < VARYING_SLOT_VAR0 &&
332       var->data.location >= 0)
333      return false;
334
335   /* Don't bother splitting if we can't opt away any unused
336    * components.
337    */
338   if (var->data.always_active_io)
339      return false;
340
341   /* Skip types we cannot split */
342   if (glsl_type_is_matrix(glsl_without_array(var->type)) ||
343       glsl_type_is_struct_or_ifc(glsl_without_array(var->type)))
344      return false;
345
346   switch (intr->intrinsic) {
347   case nir_intrinsic_interp_deref_at_centroid:
348   case nir_intrinsic_interp_deref_at_sample:
349   case nir_intrinsic_interp_deref_at_offset:
350   case nir_intrinsic_interp_deref_at_vertex:
351   case nir_intrinsic_load_deref:
352      if ((state->mask & nir_var_shader_in && mode == nir_var_shader_in) ||
353          (state->mask & nir_var_shader_out && mode == nir_var_shader_out)) {
354         lower_load_to_scalar_early(b, intr, var, state->split_inputs,
355                                    state->split_outputs);
356         return true;
357      }
358      break;
359   case nir_intrinsic_store_deref:
360      if (state->mask & nir_var_shader_out &&
361          mode == nir_var_shader_out) {
362         lower_store_output_to_scalar_early(b, intr, var, state->split_outputs);
363         return true;
364      }
365      break;
366   default:
367      break;
368   }
369
370   return false;
371}
372
373/*
374 * This function is intended to be called earlier than nir_lower_io_to_scalar()
375 * i.e. before nir_lower_io() is called.
376 */
377bool
378nir_lower_io_to_scalar_early(nir_shader *shader, nir_variable_mode mask)
379{
380   struct io_to_scalar_early_state state = {
381      .split_inputs = _mesa_pointer_hash_table_create(NULL),
382      .split_outputs = _mesa_pointer_hash_table_create(NULL),
383      .mask = mask
384   };
385
386   bool progress = nir_shader_instructions_pass(shader,
387                                                nir_lower_io_to_scalar_early_instr,
388                                                nir_metadata_block_index |
389                                                nir_metadata_dominance,
390                                                &state);
391
392   /* Remove old input from the shaders inputs list */
393   hash_table_foreach(state.split_inputs, entry) {
394      nir_variable *var = (nir_variable *) entry->key;
395      exec_node_remove(&var->node);
396
397      free(entry->data);
398   }
399
400   /* Remove old output from the shaders outputs list */
401   hash_table_foreach(state.split_outputs, entry) {
402      nir_variable *var = (nir_variable *) entry->key;
403      exec_node_remove(&var->node);
404
405      free(entry->data);
406   }
407
408   _mesa_hash_table_destroy(state.split_inputs, NULL);
409   _mesa_hash_table_destroy(state.split_outputs, NULL);
410
411   nir_remove_dead_derefs(shader);
412
413   return progress;
414}
415