vtn_cfg.c revision 7e102996
/*
 * Copyright © 2015 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"
#include "nir/nir_vla.h"

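/* Loads the function parameter at param_idx as a vtn_pointer.  Image and
 * sampler parameters are not pointer types in SPIR-V, so they are wrapped
 * in a synthetic UniformConstant pointer type before converting the loaded
 * SSA value into a pointer.
 */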
static struct vtn_pointer *
vtn_load_param_pointer(struct vtn_builder *b,
                       struct vtn_type *param_type,
                       uint32_t param_idx)
{
   struct vtn_type *ptr_type = param_type;
   if (param_type->base_type != vtn_base_type_pointer) {
      assert(param_type->base_type == vtn_base_type_image ||
             param_type->base_type == vtn_base_type_sampler);
      ptr_type = rzalloc(b, struct vtn_type);
      ptr_type->base_type = vtn_base_type_pointer;
      ptr_type->deref = param_type;
      ptr_type->storage_class = SpvStorageClassUniformConstant;
   }

   return vtn_pointer_from_ssa(b, nir_load_param(&b->nb, param_idx), ptr_type);
}

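/* Returns the number of NIR function parameters a SPIR-V type expands to.
 * Composites are flattened to their scalar/vector leaves and a sampled
 * image counts as two parameters: one image and one sampler.
 */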
static unsigned
vtn_type_count_function_params(struct vtn_type *type)
{
   switch (type->base_type) {
   case vtn_base_type_array:
   case vtn_base_type_matrix:
      return type->length * vtn_type_count_function_params(type->array_element);

   case vtn_base_type_struct: {
      unsigned count = 0;
      for (unsigned i = 0; i < type->length; i++)
         count += vtn_type_count_function_params(type->members[i]);
      return count;
   }

   case vtn_base_type_sampled_image:
      return 2;

   default:
      return 1;
   }
}

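/* Appends the nir_parameter entries for a SPIR-V type to func->params,
 * advancing *param_idx.  The expansion must match what
 * vtn_type_count_function_params() counted above.
 */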
static void
vtn_type_add_to_function_params(struct vtn_type *type,
                                nir_function *func,
                                unsigned *param_idx)
{
   static const nir_parameter nir_deref_param = {
      .num_components = 1,
      .bit_size = 32,
   };

   switch (type->base_type) {
   case vtn_base_type_array:
   case vtn_base_type_matrix:
      for (unsigned i = 0; i < type->length; i++)
         vtn_type_add_to_function_params(type->array_element, func, param_idx);
      break;

   case vtn_base_type_struct:
      for (unsigned i = 0; i < type->length; i++)
         vtn_type_add_to_function_params(type->members[i], func, param_idx);
      break;

   case vtn_base_type_sampled_image:
      func->params[(*param_idx)++] = nir_deref_param;
      func->params[(*param_idx)++] = nir_deref_param;
      break;

   case vtn_base_type_image:
   case vtn_base_type_sampler:
      func->params[(*param_idx)++] = nir_deref_param;
      break;

   case vtn_base_type_pointer:
      if (type->type) {
         func->params[(*param_idx)++] = (nir_parameter) {
            .num_components = glsl_get_vector_elements(type->type),
            .bit_size = glsl_get_bit_size(type->type),
         };
      } else {
         func->params[(*param_idx)++] = nir_deref_param;
      }
      break;

   default:
      func->params[(*param_idx)++] = (nir_parameter) {
         .num_components = glsl_get_vector_elements(type->type),
         .bit_size = glsl_get_bit_size(type->type),
      };
   }
}

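/* Flattens a vtn_ssa_value into the sources of a call instruction, one
 * call parameter per scalar/vector leaf, advancing *param_idx.
 */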
static void
vtn_ssa_value_add_to_call_params(struct vtn_builder *b,
                                 struct vtn_ssa_value *value,
                                 struct vtn_type *type,
                                 nir_call_instr *call,
                                 unsigned *param_idx)
{
   switch (type->base_type) {
   case vtn_base_type_array:
   case vtn_base_type_matrix:
      for (unsigned i = 0; i < type->length; i++) {
         vtn_ssa_value_add_to_call_params(b, value->elems[i],
                                          type->array_element,
                                          call, param_idx);
      }
      break;

   case vtn_base_type_struct:
      for (unsigned i = 0; i < type->length; i++) {
         vtn_ssa_value_add_to_call_params(b, value->elems[i],
                                          type->members[i],
                                          call, param_idx);
      }
      break;

   default:
      call->params[(*param_idx)++] = nir_src_for_ssa(value->def);
      break;
   }
}

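/* Callee-side counterpart of vtn_ssa_value_add_to_call_params(): fills in
 * each leaf of a vtn_ssa_value by loading the corresponding parameter with
 * nir_load_param().
 */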
static void
vtn_ssa_value_load_function_param(struct vtn_builder *b,
                                  struct vtn_ssa_value *value,
                                  struct vtn_type *type,
                                  unsigned *param_idx)
{
   switch (type->base_type) {
   case vtn_base_type_array:
   case vtn_base_type_matrix:
      for (unsigned i = 0; i < type->length; i++) {
         vtn_ssa_value_load_function_param(b, value->elems[i],
                                           type->array_element, param_idx);
      }
      break;

   case vtn_base_type_struct:
      for (unsigned i = 0; i < type->length; i++) {
         vtn_ssa_value_load_function_param(b, value->elems[i],
                                           type->members[i], param_idx);
      }
      break;

   default:
      value->def = nir_load_param(&b->nb, (*param_idx)++);
      break;
   }
}

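/* Handles OpFunctionCall.  If the callee returns a value, a temporary is
 * created in the caller and a deref to it is passed as the first call
 * parameter; the SPIR-V arguments are then flattened to match the parameter
 * layout built by vtn_type_add_to_function_params().
 */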
void
vtn_handle_function_call(struct vtn_builder *b, SpvOp opcode,
                         const uint32_t *w, unsigned count)
{
   struct vtn_type *res_type = vtn_value(b, w[1], vtn_value_type_type)->type;
   struct vtn_function *vtn_callee =
      vtn_value(b, w[3], vtn_value_type_function)->func;
   struct nir_function *callee = vtn_callee->impl->function;

   vtn_callee->referenced = true;

   nir_call_instr *call = nir_call_instr_create(b->nb.shader, callee);

   unsigned param_idx = 0;

   nir_deref_instr *ret_deref = NULL;
   struct vtn_type *ret_type = vtn_callee->type->return_type;
   if (ret_type->base_type != vtn_base_type_void) {
      nir_variable *ret_tmp =
         nir_local_variable_create(b->nb.impl,
                                   glsl_get_bare_type(ret_type->type),
                                   "return_tmp");
      ret_deref = nir_build_deref_var(&b->nb, ret_tmp);
      call->params[param_idx++] = nir_src_for_ssa(&ret_deref->dest.ssa);
   }

   for (unsigned i = 0; i < vtn_callee->type->length; i++) {
      struct vtn_type *arg_type = vtn_callee->type->params[i];
      unsigned arg_id = w[4 + i];

      if (arg_type->base_type == vtn_base_type_sampled_image) {
         struct vtn_sampled_image *sampled_image =
            vtn_value(b, arg_id, vtn_value_type_sampled_image)->sampled_image;

         call->params[param_idx++] =
            nir_src_for_ssa(&sampled_image->image->deref->dest.ssa);
         call->params[param_idx++] =
            nir_src_for_ssa(&sampled_image->sampler->deref->dest.ssa);
      } else if (arg_type->base_type == vtn_base_type_pointer ||
                 arg_type->base_type == vtn_base_type_image ||
                 arg_type->base_type == vtn_base_type_sampler) {
         struct vtn_pointer *pointer =
            vtn_value(b, arg_id, vtn_value_type_pointer)->pointer;
         call->params[param_idx++] =
            nir_src_for_ssa(vtn_pointer_to_ssa(b, pointer));
      } else {
         vtn_ssa_value_add_to_call_params(b, vtn_ssa_value(b, arg_id),
                                          arg_type, call, &param_idx);
      }
   }
   assert(param_idx == call->num_params);

   nir_builder_instr_insert(&b->nb, &call->instr);

   if (ret_type->base_type == vtn_base_type_void) {
      vtn_push_value(b, w[2], vtn_value_type_undef);
   } else {
      vtn_push_ssa(b, w[2], res_type, vtn_local_load(b, ret_deref, 0));
   }
}

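/* First pass over a function's SPIR-V.  OpFunction creates the vtn_function
 * and the NIR function signature, OpFunctionParameter loads the incoming
 * parameters, and OpLabel plus the merge/branch opcodes only record pointers
 * into the SPIR-V for the later CFG walk; body instructions are not emitted
 * here.
 */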
static bool
vtn_cfg_handle_prepass_instruction(struct vtn_builder *b, SpvOp opcode,
                                   const uint32_t *w, unsigned count)
{
   switch (opcode) {
   case SpvOpFunction: {
      vtn_assert(b->func == NULL);
      b->func = rzalloc(b, struct vtn_function);

      list_inithead(&b->func->body);
      b->func->control = w[3];

      MAYBE_UNUSED const struct glsl_type *result_type =
         vtn_value(b, w[1], vtn_value_type_type)->type->type;
      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_function);
      val->func = b->func;

      b->func->type = vtn_value(b, w[4], vtn_value_type_type)->type;
      const struct vtn_type *func_type = b->func->type;

      vtn_assert(func_type->return_type->type == result_type);

      nir_function *func =
         nir_function_create(b->shader, ralloc_strdup(b->shader, val->name));

      unsigned num_params = 0;
      for (unsigned i = 0; i < func_type->length; i++)
         num_params += vtn_type_count_function_params(func_type->params[i]);

      /* Add one parameter for the function return value */
      if (func_type->return_type->base_type != vtn_base_type_void)
         num_params++;

      func->num_params = num_params;
      func->params = ralloc_array(b->shader, nir_parameter, num_params);

      unsigned idx = 0;
      if (func_type->return_type->base_type != vtn_base_type_void) {
         /* The return value is a regular pointer */
         func->params[idx++] = (nir_parameter) {
            .num_components = 1, .bit_size = 32,
         };
      }

      for (unsigned i = 0; i < func_type->length; i++)
         vtn_type_add_to_function_params(func_type->params[i], func, &idx);
      assert(idx == num_params);

      b->func->impl = nir_function_impl_create(func);
      nir_builder_init(&b->nb, func->impl);
      b->nb.cursor = nir_before_cf_list(&b->func->impl->body);
      b->nb.exact = b->exact;

      b->func_param_idx = 0;

      /* The return value is the first parameter */
      if (func_type->return_type->base_type != vtn_base_type_void)
         b->func_param_idx++;
      break;
   }

   case SpvOpFunctionEnd:
      b->func->end = w;
      b->func = NULL;
      break;

   case SpvOpFunctionParameter: {
      struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;

      vtn_assert(b->func_param_idx < b->func->impl->function->num_params);

      if (type->base_type == vtn_base_type_sampled_image) {
         /* Sampled images are actually two parameters.  The first is the
          * image and the second is the sampler.
          */
         struct vtn_value *val =
            vtn_push_value(b, w[2], vtn_value_type_sampled_image);

         val->sampled_image = ralloc(b, struct vtn_sampled_image);
         val->sampled_image->type = type;

         struct vtn_type *sampler_type = rzalloc(b, struct vtn_type);
         sampler_type->base_type = vtn_base_type_sampler;
         sampler_type->type = glsl_bare_sampler_type();

         val->sampled_image->image =
            vtn_load_param_pointer(b, type, b->func_param_idx++);
         val->sampled_image->sampler =
            vtn_load_param_pointer(b, sampler_type, b->func_param_idx++);
      } else if (type->base_type == vtn_base_type_pointer &&
                 type->type != NULL) {
         /* This is a pointer with an actual storage type */
         nir_ssa_def *ssa_ptr = nir_load_param(&b->nb, b->func_param_idx++);
         vtn_push_value_pointer(b, w[2], vtn_pointer_from_ssa(b, ssa_ptr, type));
      } else if (type->base_type == vtn_base_type_pointer ||
                 type->base_type == vtn_base_type_image ||
                 type->base_type == vtn_base_type_sampler) {
         vtn_push_value_pointer(b, w[2],
                                vtn_load_param_pointer(b, type,
                                                       b->func_param_idx++));
      } else {
         /* We're a regular SSA value. */
         struct vtn_ssa_value *value = vtn_create_ssa_value(b, type->type);
         vtn_ssa_value_load_function_param(b, value, type, &b->func_param_idx);
         vtn_push_ssa(b, w[2], type, value);
      }
      break;
   }

   case SpvOpLabel: {
      vtn_assert(b->block == NULL);
      b->block = rzalloc(b, struct vtn_block);
      b->block->node.type = vtn_cf_node_type_block;
      b->block->label = w;
      vtn_push_value(b, w[1], vtn_value_type_block)->block = b->block;

      if (b->func->start_block == NULL) {
         /* This is the first block encountered for this function.  In this
          * case, we set the start block and add it to the list of
          * implemented functions that we'll walk later.
          */
         b->func->start_block = b->block;
         exec_list_push_tail(&b->functions, &b->func->node);
      }
      break;
   }

   case SpvOpSelectionMerge:
   case SpvOpLoopMerge:
      vtn_assert(b->block && b->block->merge == NULL);
      b->block->merge = w;
      break;

   case SpvOpBranch:
   case SpvOpBranchConditional:
   case SpvOpSwitch:
   case SpvOpKill:
   case SpvOpReturn:
   case SpvOpReturnValue:
   case SpvOpUnreachable:
      vtn_assert(b->block && b->block->branch == NULL);
      b->block->branch = w;
      b->block = NULL;
      break;

   default:
      /* Continue on as per normal */
      return true;
   }

   return true;
}

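/* Records one OpSwitch target.  Targets that simply jump to the break block
 * are skipped; otherwise a vtn_case is created for the target block (if it
 * does not already have one) and the literal value or default flag is added
 * to it.
 */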
static void
vtn_add_case(struct vtn_builder *b, struct vtn_switch *swtch,
             struct vtn_block *break_block,
             uint32_t block_id, uint64_t val, bool is_default)
{
   struct vtn_block *case_block =
      vtn_value(b, block_id, vtn_value_type_block)->block;

   /* Don't create dummy cases that just break */
   if (case_block == break_block)
      return;

   if (case_block->switch_case == NULL) {
      struct vtn_case *c = ralloc(b, struct vtn_case);

      list_inithead(&c->body);
      c->start_block = case_block;
      c->fallthrough = NULL;
      util_dynarray_init(&c->values, b);
      c->is_default = false;
      c->visited = false;

      list_addtail(&c->link, &swtch->cases);

      case_block->switch_case = c;
   }

   if (is_default) {
      case_block->switch_case->is_default = true;
   } else {
      util_dynarray_append(&case_block->switch_case->values, uint64_t, val);
   }
}

/* This function performs a depth-first search of the cases and puts them
 * in fall-through order.
 */
static void
vtn_order_case(struct vtn_switch *swtch, struct vtn_case *cse)
{
   if (cse->visited)
      return;

   cse->visited = true;

   list_del(&cse->link);

   if (cse->fallthrough) {
      vtn_order_case(swtch, cse->fallthrough);

      /* If we have a fall-through, place this case right before the case it
       * falls through to.  This ensures that fallthroughs come one after
       * the other.  These two can never get separated because that would
       * imply something else falling through to the same case.  Also, this
       * can't break ordering because the DFS ensures that this case is
       * visited before anything that falls through to it.
       */
      list_addtail(&cse->link, &cse->fallthrough->link);
   } else {
      list_add(&cse->link, &swtch->cases);
   }
}

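/* Classifies a branch target relative to the enclosing constructs: a jump
 * into another case is a switch fallthrough (and is recorded on swcase), a
 * jump to a merge or continue block is the corresponding break/continue,
 * and anything else is vtn_branch_type_none.
 */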
static enum vtn_branch_type
vtn_get_branch_type(struct vtn_builder *b,
                    struct vtn_block *block,
                    struct vtn_case *swcase, struct vtn_block *switch_break,
                    struct vtn_block *loop_break, struct vtn_block *loop_cont)
{
   if (block->switch_case) {
      /* This branch is actually a fallthrough */
      vtn_assert(swcase->fallthrough == NULL ||
                 swcase->fallthrough == block->switch_case);
      swcase->fallthrough = block->switch_case;
      return vtn_branch_type_switch_fallthrough;
   } else if (block == loop_break) {
      return vtn_branch_type_loop_break;
   } else if (block == loop_cont) {
      return vtn_branch_type_loop_continue;
   } else if (block == switch_break) {
      return vtn_branch_type_switch_break;
   } else {
      return vtn_branch_type_none;
   }
}

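/* Walks the CFG starting at "start" and builds the structured control-flow
 * tree (blocks, ifs, loops, and switches) into cf_list.  The walk stops at
 * "end" or when the current block branches out of the enclosing construct
 * (break, continue, return, discard, or fallthrough).
 */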
static void
vtn_cfg_walk_blocks(struct vtn_builder *b, struct list_head *cf_list,
                    struct vtn_block *start, struct vtn_case *switch_case,
                    struct vtn_block *switch_break,
                    struct vtn_block *loop_break, struct vtn_block *loop_cont,
                    struct vtn_block *end)
{
   struct vtn_block *block = start;
   while (block != end) {
      if (block->merge && (*block->merge & SpvOpCodeMask) == SpvOpLoopMerge &&
          !block->loop) {
         struct vtn_loop *loop = ralloc(b, struct vtn_loop);

         loop->node.type = vtn_cf_node_type_loop;
         list_inithead(&loop->body);
         list_inithead(&loop->cont_body);
         loop->control = block->merge[3];

         list_addtail(&loop->node.link, cf_list);
         block->loop = loop;

         struct vtn_block *new_loop_break =
            vtn_value(b, block->merge[1], vtn_value_type_block)->block;
         struct vtn_block *new_loop_cont =
            vtn_value(b, block->merge[2], vtn_value_type_block)->block;

         /* Note: This recursive call will start with the current block as
          * its start block.  If we weren't careful, we would get here
          * again and end up in infinite recursion.  This is why we set
          * block->loop above and check for it before creating one.  This
          * way, we only create the loop once and the second call that
          * tries to handle this loop goes to the cases below and gets
          * handled as a regular block.
          *
          * Note: When we make the recursive walk calls, we pass NULL for
          * the switch break since you have to break out of the loop first.
          * We do, however, still pass the current switch case because it's
          * possible that the merge block for the loop is the start of
          * another case.
          */
         vtn_cfg_walk_blocks(b, &loop->body, block, switch_case, NULL,
                             new_loop_break, new_loop_cont, NULL);
         vtn_cfg_walk_blocks(b, &loop->cont_body, new_loop_cont, NULL, NULL,
                             new_loop_break, NULL, block);

         enum vtn_branch_type branch_type =
            vtn_get_branch_type(b, new_loop_break, switch_case, switch_break,
                                loop_break, loop_cont);

         if (branch_type != vtn_branch_type_none) {
            /* Stop walking through the CFG when this inner loop's break block
             * ends up as the same block as the outer loop's continue block
             * because we are already going to visit it.
             */
            vtn_assert(branch_type == vtn_branch_type_loop_continue);
            return;
         }

         block = new_loop_break;
         continue;
      }

      vtn_assert(block->node.link.next == NULL);
      list_addtail(&block->node.link, cf_list);

      switch (*block->branch & SpvOpCodeMask) {
      case SpvOpBranch: {
         struct vtn_block *branch_block =
            vtn_value(b, block->branch[1], vtn_value_type_block)->block;

         block->branch_type = vtn_get_branch_type(b, branch_block,
                                                  switch_case, switch_break,
                                                  loop_break, loop_cont);

         if (block->branch_type != vtn_branch_type_none)
            return;

         block = branch_block;
         continue;
      }

      case SpvOpReturn:
      case SpvOpReturnValue:
         block->branch_type = vtn_branch_type_return;
         return;

      case SpvOpKill:
         block->branch_type = vtn_branch_type_discard;
         return;

      case SpvOpBranchConditional: {
         struct vtn_block *then_block =
            vtn_value(b, block->branch[2], vtn_value_type_block)->block;
         struct vtn_block *else_block =
            vtn_value(b, block->branch[3], vtn_value_type_block)->block;

         struct vtn_if *if_stmt = ralloc(b, struct vtn_if);

         if_stmt->node.type = vtn_cf_node_type_if;
         if_stmt->condition = block->branch[1];
         list_inithead(&if_stmt->then_body);
         list_inithead(&if_stmt->else_body);

         list_addtail(&if_stmt->node.link, cf_list);

         if (block->merge &&
             (*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge) {
            if_stmt->control = block->merge[2];
         } else {
            if_stmt->control = SpvSelectionControlMaskNone;
         }

         if_stmt->then_type = vtn_get_branch_type(b, then_block,
                                                  switch_case, switch_break,
                                                  loop_break, loop_cont);
         if_stmt->else_type = vtn_get_branch_type(b, else_block,
                                                  switch_case, switch_break,
                                                  loop_break, loop_cont);

         if (then_block == else_block) {
            block->branch_type = if_stmt->then_type;
            if (block->branch_type == vtn_branch_type_none) {
               block = then_block;
               continue;
            } else {
               return;
            }
         } else if (if_stmt->then_type == vtn_branch_type_none &&
                    if_stmt->else_type == vtn_branch_type_none) {
            /* Neither side of the if is something we can short-circuit. */
            vtn_assert((*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge);
            struct vtn_block *merge_block =
               vtn_value(b, block->merge[1], vtn_value_type_block)->block;

            vtn_cfg_walk_blocks(b, &if_stmt->then_body, then_block,
                                switch_case, switch_break,
                                loop_break, loop_cont, merge_block);
            vtn_cfg_walk_blocks(b, &if_stmt->else_body, else_block,
                                switch_case, switch_break,
                                loop_break, loop_cont, merge_block);

            enum vtn_branch_type merge_type =
               vtn_get_branch_type(b, merge_block, switch_case, switch_break,
                                   loop_break, loop_cont);
            if (merge_type == vtn_branch_type_none) {
               block = merge_block;
               continue;
            } else {
               return;
            }
         } else if (if_stmt->then_type != vtn_branch_type_none &&
                    if_stmt->else_type != vtn_branch_type_none) {
            /* Both sides were short-circuited.  We're done here. */
            return;
         } else {
            /* Exactly one side of the branch could be short-circuited.
             * We set the branch up as a predicated break/continue and we
             * continue on with the other side as if it were what comes
             * after the if.
             */
            if (if_stmt->then_type == vtn_branch_type_none) {
               block = then_block;
            } else {
               block = else_block;
            }
            continue;
         }
         vtn_fail("Should have returned or continued");
      }

      case SpvOpSwitch: {
         vtn_assert((*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge);
         struct vtn_block *break_block =
            vtn_value(b, block->merge[1], vtn_value_type_block)->block;

         struct vtn_switch *swtch = ralloc(b, struct vtn_switch);

         swtch->node.type = vtn_cf_node_type_switch;
         swtch->selector = block->branch[1];
         list_inithead(&swtch->cases);

         list_addtail(&swtch->node.link, cf_list);

         /* First, we go through and record all of the cases. */
         const uint32_t *branch_end =
            block->branch + (block->branch[0] >> SpvWordCountShift);

         struct vtn_value *cond_val = vtn_untyped_value(b, block->branch[1]);
         vtn_fail_if(!cond_val->type ||
                     cond_val->type->base_type != vtn_base_type_scalar,
                     "Selector of OpSwitch must have a type of OpTypeInt");

         nir_alu_type cond_type =
            nir_get_nir_type_for_glsl_type(cond_val->type->type);
         vtn_fail_if(nir_alu_type_get_base_type(cond_type) != nir_type_int &&
                     nir_alu_type_get_base_type(cond_type) != nir_type_uint,
                     "Selector of OpSwitch must have a type of OpTypeInt");

         bool is_default = true;
         const unsigned bitsize = nir_alu_type_get_type_size(cond_type);
         for (const uint32_t *w = block->branch + 2; w < branch_end;) {
            uint64_t literal = 0;
            if (!is_default) {
               if (bitsize <= 32) {
                  literal = *(w++);
               } else {
                  assert(bitsize == 64);
                  literal = vtn_u64_literal(w);
                  w += 2;
               }
            }

            uint32_t block_id = *(w++);

            vtn_add_case(b, swtch, break_block, block_id, literal, is_default);
            is_default = false;
         }

         /* Now, we go through and walk the blocks.  While we walk through
          * the blocks, we also gather the much-needed fall-through
          * information.
          */
         list_for_each_entry(struct vtn_case, cse, &swtch->cases, link) {
            vtn_assert(cse->start_block != break_block);
            vtn_cfg_walk_blocks(b, &cse->body, cse->start_block, cse,
                                break_block, loop_break, loop_cont, NULL);
         }

         /* Finally, we walk over all of the cases one more time and put
          * them in fall-through order.
          */
         for (const uint32_t *w = block->branch + 2; w < branch_end;) {
            struct vtn_block *case_block =
               vtn_value(b, *w, vtn_value_type_block)->block;

            if (bitsize <= 32) {
               w += 2;
            } else {
               assert(bitsize == 64);
               w += 3;
            }

            if (case_block == break_block)
               continue;

            vtn_assert(case_block->switch_case);

            vtn_order_case(swtch, case_block->switch_case);
         }

         enum vtn_branch_type branch_type =
            vtn_get_branch_type(b, break_block, switch_case, NULL,
                                loop_break, loop_cont);

         if (branch_type != vtn_branch_type_none) {
            /* It is possible that the break is actually the continue block
             * for the containing loop.  In this case, we need to bail and let
             * the loop parsing code handle the continue properly.
             */
            vtn_assert(branch_type == vtn_branch_type_loop_continue);
            return;
         }

         block = break_block;
         continue;
      }

      case SpvOpUnreachable:
         return;

      default:
         vtn_fail("Unhandled opcode");
      }
   }
}

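/* Entry point for CFG construction: runs the prepass over the given range of
 * SPIR-V words and then builds the structured control-flow tree for every
 * implemented function.
 */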
void
vtn_build_cfg(struct vtn_builder *b, const uint32_t *words, const uint32_t *end)
{
   vtn_foreach_instruction(b, words, end,
                           vtn_cfg_handle_prepass_instruction);

   foreach_list_typed(struct vtn_function, func, node, &b->functions) {
      vtn_cfg_walk_blocks(b, &func->body, func->start_block,
                          NULL, NULL, NULL, NULL, NULL);
   }
}

static bool
vtn_handle_phis_first_pass(struct vtn_builder *b, SpvOp opcode,
                           const uint32_t *w, unsigned count)
{
   if (opcode == SpvOpLabel)
      return true; /* Nothing to do */

   /* If this isn't a phi node, stop. */
   if (opcode != SpvOpPhi)
      return false;

   /* For handling phi nodes, we do a poor-man's out-of-ssa on the spot.
    * For each phi, we create a variable with the appropriate type and
    * do a load from that variable.  Then, in a second pass, we add
    * stores to that variable to each of the predecessor blocks.
    *
    * We could do something more intelligent here.  However, in order to
    * handle loops and things properly, we really need dominance
    * information.  It would end up basically being the into-SSA
    * algorithm all over again.  It's easier if we just let
    * lower_vars_to_ssa do that for us instead of repeating it here.
    */
   struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
   nir_variable *phi_var =
      nir_local_variable_create(b->nb.impl, type->type, "phi");
   _mesa_hash_table_insert(b->phi_table, w, phi_var);

   vtn_push_ssa(b, w[2], type,
                vtn_local_load(b, nir_build_deref_var(&b->nb, phi_var), 0));

   return true;
}

static bool
vtn_handle_phi_second_pass(struct vtn_builder *b, SpvOp opcode,
                           const uint32_t *w, unsigned count)
{
   if (opcode != SpvOpPhi)
      return true;

   struct hash_entry *phi_entry = _mesa_hash_table_search(b->phi_table, w);
   vtn_assert(phi_entry);
   nir_variable *phi_var = phi_entry->data;

   for (unsigned i = 3; i < count; i += 2) {
      struct vtn_block *pred =
         vtn_value(b, w[i + 1], vtn_value_type_block)->block;

      b->nb.cursor = nir_after_instr(&pred->end_nop->instr);

      struct vtn_ssa_value *src = vtn_ssa_value(b, w[i]);

      vtn_local_store(b, src, nir_build_deref_var(&b->nb, phi_var), 0);
   }

   return true;
}

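/* Emits the NIR for a non-trivial branch type.  Switch breaks clear the
 * fall-through flag, loop breaks/continues and returns become NIR jump
 * instructions, and OpKill becomes a discard intrinsic.
 */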
static void
vtn_emit_branch(struct vtn_builder *b, enum vtn_branch_type branch_type,
                nir_variable *switch_fall_var, bool *has_switch_break)
{
   switch (branch_type) {
   case vtn_branch_type_switch_break:
      nir_store_var(&b->nb, switch_fall_var, nir_imm_false(&b->nb), 1);
      *has_switch_break = true;
      break;
   case vtn_branch_type_switch_fallthrough:
      break; /* Nothing to do */
   case vtn_branch_type_loop_break:
      nir_jump(&b->nb, nir_jump_break);
      break;
   case vtn_branch_type_loop_continue:
      nir_jump(&b->nb, nir_jump_continue);
      break;
   case vtn_branch_type_return:
      nir_jump(&b->nb, nir_jump_return);
      break;
   case vtn_branch_type_discard: {
      nir_intrinsic_instr *discard =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_discard);
      nir_builder_instr_insert(&b->nb, &discard->instr);
      break;
   }
   default:
      vtn_fail("Invalid branch type");
   }
}

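/* Builds the boolean condition for one case of a lowered switch: an OR of
 * equality tests against each of the case's literal values or, for the
 * default case, the negation of every other case's condition.
 */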
static nir_ssa_def *
vtn_switch_case_condition(struct vtn_builder *b, struct vtn_switch *swtch,
                          nir_ssa_def *sel, struct vtn_case *cse)
{
   if (cse->is_default) {
      nir_ssa_def *any = nir_imm_false(&b->nb);
      list_for_each_entry(struct vtn_case, other, &swtch->cases, link) {
         if (other->is_default)
            continue;

         any = nir_ior(&b->nb, any,
                       vtn_switch_case_condition(b, swtch, sel, other));
      }
      return nir_inot(&b->nb, any);
   } else {
      nir_ssa_def *cond = nir_imm_false(&b->nb);
      util_dynarray_foreach(&cse->values, uint64_t, val) {
         nir_ssa_def *imm = nir_imm_intN_t(&b->nb, *val, sel->bit_size);
         cond = nir_ior(&b->nb, cond, nir_ieq(&b->nb, sel, imm));
      }
      return cond;
   }
}

static nir_loop_control
vtn_loop_control(struct vtn_builder *b, struct vtn_loop *vtn_loop)
{
   if (vtn_loop->control == SpvLoopControlMaskNone)
      return nir_loop_control_none;
   else if (vtn_loop->control & SpvLoopControlDontUnrollMask)
      return nir_loop_control_dont_unroll;
   else if (vtn_loop->control & SpvLoopControlUnrollMask)
      return nir_loop_control_unroll;
   else if (vtn_loop->control & SpvLoopControlDependencyInfiniteMask ||
            vtn_loop->control & SpvLoopControlDependencyLengthMask) {
      /* We do not do anything special with these yet. */
      return nir_loop_control_none;
   } else {
      vtn_fail("Invalid loop control");
   }
}

static nir_selection_control
vtn_selection_control(struct vtn_builder *b, struct vtn_if *vtn_if)
{
   if (vtn_if->control == SpvSelectionControlMaskNone)
      return nir_selection_control_none;
   else if (vtn_if->control & SpvSelectionControlDontFlattenMask)
      return nir_selection_control_dont_flatten;
   else if (vtn_if->control & SpvSelectionControlFlattenMask)
      return nir_selection_control_flatten;
   else
      vtn_fail("Invalid selection control");
}

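/* Emits NIR for a structured control-flow list built by
 * vtn_cfg_walk_blocks().  switch_fall_var and has_switch_break are only
 * used while emitting the body of a switch case, where a break is lowered
 * to storing false in the fall-through variable.
 */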
static void
vtn_emit_cf_list(struct vtn_builder *b, struct list_head *cf_list,
                 nir_variable *switch_fall_var, bool *has_switch_break,
                 vtn_instruction_handler handler)
{
   list_for_each_entry(struct vtn_cf_node, node, cf_list, link) {
      switch (node->type) {
      case vtn_cf_node_type_block: {
         struct vtn_block *block = (struct vtn_block *)node;

         const uint32_t *block_start = block->label;
         const uint32_t *block_end = block->merge ? block->merge :
                                                    block->branch;

         block_start = vtn_foreach_instruction(b, block_start, block_end,
                                               vtn_handle_phis_first_pass);

         vtn_foreach_instruction(b, block_start, block_end, handler);

         block->end_nop = nir_intrinsic_instr_create(b->nb.shader,
                                                     nir_intrinsic_nop);
         nir_builder_instr_insert(&b->nb, &block->end_nop->instr);

         if ((*block->branch & SpvOpCodeMask) == SpvOpReturnValue) {
            vtn_fail_if(b->func->type->return_type->base_type ==
                        vtn_base_type_void,
                        "Return with a value from a function returning void");
            struct vtn_ssa_value *src = vtn_ssa_value(b, block->branch[1]);
            const struct glsl_type *ret_type =
               glsl_get_bare_type(b->func->type->return_type->type);
            nir_deref_instr *ret_deref =
               nir_build_deref_cast(&b->nb, nir_load_param(&b->nb, 0),
                                    nir_var_function_temp, ret_type, 0);
            vtn_local_store(b, src, ret_deref, 0);
         }

         if (block->branch_type != vtn_branch_type_none) {
            vtn_emit_branch(b, block->branch_type,
                            switch_fall_var, has_switch_break);
            return;
         }

         break;
      }

      case vtn_cf_node_type_if: {
         struct vtn_if *vtn_if = (struct vtn_if *)node;
         bool sw_break = false;

         nir_if *nif =
            nir_push_if(&b->nb, vtn_ssa_value(b, vtn_if->condition)->def);

         nif->control = vtn_selection_control(b, vtn_if);

         if (vtn_if->then_type == vtn_branch_type_none) {
            vtn_emit_cf_list(b, &vtn_if->then_body,
                             switch_fall_var, &sw_break, handler);
         } else {
            vtn_emit_branch(b, vtn_if->then_type, switch_fall_var, &sw_break);
         }

         nir_push_else(&b->nb, nif);
         if (vtn_if->else_type == vtn_branch_type_none) {
            vtn_emit_cf_list(b, &vtn_if->else_body,
                             switch_fall_var, &sw_break, handler);
         } else {
            vtn_emit_branch(b, vtn_if->else_type, switch_fall_var, &sw_break);
         }

         nir_pop_if(&b->nb, nif);

         /* If we encountered a switch break somewhere inside of the if,
          * then it would have been handled correctly by calling
          * emit_cf_list or emit_branch for the interior.  However, we
          * need to predicate everything following on whether or not we're
          * still going.
          */
         if (sw_break) {
            *has_switch_break = true;
            nir_push_if(&b->nb, nir_load_var(&b->nb, switch_fall_var));
         }
         break;
      }

      case vtn_cf_node_type_loop: {
         struct vtn_loop *vtn_loop = (struct vtn_loop *)node;

         nir_loop *loop = nir_push_loop(&b->nb);
         loop->control = vtn_loop_control(b, vtn_loop);

         vtn_emit_cf_list(b, &vtn_loop->body, NULL, NULL, handler);

         if (!list_empty(&vtn_loop->cont_body)) {
            /* If we have a non-trivial continue body then we need to put
             * it at the beginning of the loop with a flag to ensure that
             * it doesn't get executed in the first iteration.
             */
            nir_variable *do_cont =
               nir_local_variable_create(b->nb.impl, glsl_bool_type(), "cont");

            b->nb.cursor = nir_before_cf_node(&loop->cf_node);
            nir_store_var(&b->nb, do_cont, nir_imm_false(&b->nb), 1);

            b->nb.cursor = nir_before_cf_list(&loop->body);

            nir_if *cont_if =
               nir_push_if(&b->nb, nir_load_var(&b->nb, do_cont));

            vtn_emit_cf_list(b, &vtn_loop->cont_body, NULL, NULL, handler);

            nir_pop_if(&b->nb, cont_if);

            nir_store_var(&b->nb, do_cont, nir_imm_true(&b->nb), 1);

            b->has_loop_continue = true;
         }

         nir_pop_loop(&b->nb, loop);
         break;
      }

      case vtn_cf_node_type_switch: {
         struct vtn_switch *vtn_switch = (struct vtn_switch *)node;

         /* First, we create a variable to keep track of whether or not the
          * switch is still going at any given point.  Any switch breaks
          * will set this variable to false.
          */
         nir_variable *fall_var =
            nir_local_variable_create(b->nb.impl, glsl_bool_type(), "fall");
         nir_store_var(&b->nb, fall_var, nir_imm_false(&b->nb), 1);

         nir_ssa_def *sel = vtn_ssa_value(b, vtn_switch->selector)->def;

         /* Now we can walk the list of cases and actually emit code */
         list_for_each_entry(struct vtn_case, cse, &vtn_switch->cases, link) {
            /* Figure out the condition */
            nir_ssa_def *cond =
               vtn_switch_case_condition(b, vtn_switch, sel, cse);
            /* Take fallthrough into account */
            cond = nir_ior(&b->nb, cond, nir_load_var(&b->nb, fall_var));

            nir_if *case_if = nir_push_if(&b->nb, cond);

            bool has_break = false;
            nir_store_var(&b->nb, fall_var, nir_imm_true(&b->nb), 1);
            vtn_emit_cf_list(b, &cse->body, fall_var, &has_break, handler);
            (void)has_break; /* We don't care */

            nir_pop_if(&b->nb, case_if);
         }

         break;
      }

      default:
         vtn_fail("Invalid CF node type");
      }
   }
}

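/* Emits an entire function: walks its structured control-flow list, runs the
 * second phi pass to add the predecessor stores, and repairs derefs and SSA
 * (continue bodies are emitted at the top of their loop, so their uses of
 * loop-body values may need new phis).
 */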
void
vtn_function_emit(struct vtn_builder *b, struct vtn_function *func,
                  vtn_instruction_handler instruction_handler)
{
   nir_builder_init(&b->nb, func->impl);
   b->func = func;
   b->nb.cursor = nir_after_cf_list(&func->impl->body);
   b->nb.exact = b->exact;
   b->has_loop_continue = false;
   b->phi_table = _mesa_pointer_hash_table_create(b);

   vtn_emit_cf_list(b, &func->body, NULL, NULL, instruction_handler);

   vtn_foreach_instruction(b, func->start_block->label, func->end,
                           vtn_handle_phi_second_pass);

   nir_rematerialize_derefs_in_use_blocks_impl(func->impl);

   /* Continue blocks for loops get inserted before the body of the loop
    * but instructions in the continue may use SSA defs in the loop body.
    * Therefore, we need to repair SSA to insert the needed phi nodes.
    */
   if (b->has_loop_continue)
      nir_repair_ssa_impl(func->impl);

   func->emitted = true;
}