/* vtn_cfg.c (revision 01e04c3f) */
/*
 * Copyright © 2015 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"
#include "nir/nir_vla.h"

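/* Loads a function parameter as a vtn_pointer.  Image and sampler
 * parameters are not pointers in SPIR-V, so a UniformConstant pointer type
 * is synthesized for them on the fly so they can be handled uniformly.
 */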
static struct vtn_pointer *
vtn_load_param_pointer(struct vtn_builder *b,
                       struct vtn_type *param_type,
                       uint32_t param_idx)
{
   struct vtn_type *ptr_type = param_type;
   if (param_type->base_type != vtn_base_type_pointer) {
      assert(param_type->base_type == vtn_base_type_image ||
             param_type->base_type == vtn_base_type_sampler);
      ptr_type = rzalloc(b, struct vtn_type);
      ptr_type->base_type = vtn_base_type_pointer;
      ptr_type->deref = param_type;
      ptr_type->storage_class = SpvStorageClassUniformConstant;
   }

   return vtn_pointer_from_ssa(b, nir_load_param(&b->nb, param_idx), ptr_type);
}

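/* Returns the number of NIR parameters a SPIR-V type expands to: arrays,
 * matrices, and structs are flattened element by element, sampled images
 * take two parameters (image and sampler), and everything else takes one.
 */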
static unsigned
vtn_type_count_function_params(struct vtn_type *type)
{
   switch (type->base_type) {
   case vtn_base_type_array:
   case vtn_base_type_matrix:
      return type->length * vtn_type_count_function_params(type->array_element);

   case vtn_base_type_struct: {
      unsigned count = 0;
      for (unsigned i = 0; i < type->length; i++)
         count += vtn_type_count_function_params(type->members[i]);
      return count;
   }

   case vtn_base_type_sampled_image:
      return 2;

   default:
      return 1;
   }
}

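/* Appends the NIR parameters for one SPIR-V parameter type to func,
 * advancing *param_idx.  Composites are flattened recursively; images,
 * samplers, and pointers without an SSA-visible type are passed as a
 * single 32-bit deref handle, everything else is passed by value.
 */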
static void
vtn_type_add_to_function_params(struct vtn_type *type,
                                nir_function *func,
                                unsigned *param_idx)
{
   static const nir_parameter nir_deref_param = {
      .num_components = 1,
      .bit_size = 32,
   };

   switch (type->base_type) {
   case vtn_base_type_array:
   case vtn_base_type_matrix:
      for (unsigned i = 0; i < type->length; i++)
         vtn_type_add_to_function_params(type->array_element, func, param_idx);
      break;

   case vtn_base_type_struct:
      for (unsigned i = 0; i < type->length; i++)
         vtn_type_add_to_function_params(type->members[i], func, param_idx);
      break;

   case vtn_base_type_sampled_image:
      func->params[(*param_idx)++] = nir_deref_param;
      func->params[(*param_idx)++] = nir_deref_param;
      break;

   case vtn_base_type_image:
   case vtn_base_type_sampler:
      func->params[(*param_idx)++] = nir_deref_param;
      break;

   case vtn_base_type_pointer:
      if (type->type) {
         func->params[(*param_idx)++] = (nir_parameter) {
            .num_components = glsl_get_vector_elements(type->type),
            .bit_size = glsl_get_bit_size(type->type),
         };
      } else {
         func->params[(*param_idx)++] = nir_deref_param;
      }
      break;

   default:
      func->params[(*param_idx)++] = (nir_parameter) {
         .num_components = glsl_get_vector_elements(type->type),
         .bit_size = glsl_get_bit_size(type->type),
      };
   }
}

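/* Adds the SSA defs of a (possibly composite) argument to the call, in the
 * same flattened order used by vtn_type_add_to_function_params.
 */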
static void
vtn_ssa_value_add_to_call_params(struct vtn_builder *b,
                                 struct vtn_ssa_value *value,
                                 struct vtn_type *type,
                                 nir_call_instr *call,
                                 unsigned *param_idx)
{
   switch (type->base_type) {
   case vtn_base_type_array:
   case vtn_base_type_matrix:
      for (unsigned i = 0; i < type->length; i++) {
         vtn_ssa_value_add_to_call_params(b, value->elems[i],
                                          type->array_element,
                                          call, param_idx);
      }
      break;

   case vtn_base_type_struct:
      for (unsigned i = 0; i < type->length; i++) {
         vtn_ssa_value_add_to_call_params(b, value->elems[i],
                                          type->members[i],
                                          call, param_idx);
      }
      break;

   default:
      call->params[(*param_idx)++] = nir_src_for_ssa(value->def);
      break;
   }
}

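/* Fills in a (possibly composite) vtn_ssa_value by loading each leaf from
 * its nir_load_param, mirroring the flattening used for the parameters.
 */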
static void
vtn_ssa_value_load_function_param(struct vtn_builder *b,
                                  struct vtn_ssa_value *value,
                                  struct vtn_type *type,
                                  unsigned *param_idx)
{
   switch (type->base_type) {
   case vtn_base_type_array:
   case vtn_base_type_matrix:
      for (unsigned i = 0; i < type->length; i++) {
         vtn_ssa_value_load_function_param(b, value->elems[i],
                                           type->array_element, param_idx);
      }
      break;

   case vtn_base_type_struct:
      for (unsigned i = 0; i < type->length; i++) {
         vtn_ssa_value_load_function_param(b, value->elems[i],
                                           type->members[i], param_idx);
      }
      break;

   default:
      value->def = nir_load_param(&b->nb, (*param_idx)++);
      break;
   }
}

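/* Handles OpFunctionCall.  If the callee returns a value, a temporary local
 * variable is created and its deref is passed as the first call parameter;
 * the result is loaded back out of that temporary after the call.
 */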
void
vtn_handle_function_call(struct vtn_builder *b, SpvOp opcode,
                         const uint32_t *w, unsigned count)
{
   struct vtn_type *res_type = vtn_value(b, w[1], vtn_value_type_type)->type;
   struct vtn_function *vtn_callee =
      vtn_value(b, w[3], vtn_value_type_function)->func;
   struct nir_function *callee = vtn_callee->impl->function;

   vtn_callee->referenced = true;

   nir_call_instr *call = nir_call_instr_create(b->nb.shader, callee);

   unsigned param_idx = 0;

   nir_deref_instr *ret_deref = NULL;
   struct vtn_type *ret_type = vtn_callee->type->return_type;
   if (ret_type->base_type != vtn_base_type_void) {
      nir_variable *ret_tmp =
         nir_local_variable_create(b->nb.impl, ret_type->type, "return_tmp");
      ret_deref = nir_build_deref_var(&b->nb, ret_tmp);
      call->params[param_idx++] = nir_src_for_ssa(&ret_deref->dest.ssa);
   }

   for (unsigned i = 0; i < vtn_callee->type->length; i++) {
      struct vtn_type *arg_type = vtn_callee->type->params[i];
      unsigned arg_id = w[4 + i];

      if (arg_type->base_type == vtn_base_type_sampled_image) {
         struct vtn_sampled_image *sampled_image =
            vtn_value(b, arg_id, vtn_value_type_sampled_image)->sampled_image;

         call->params[param_idx++] =
            nir_src_for_ssa(&sampled_image->image->deref->dest.ssa);
         call->params[param_idx++] =
            nir_src_for_ssa(&sampled_image->sampler->deref->dest.ssa);
      } else if (arg_type->base_type == vtn_base_type_pointer ||
                 arg_type->base_type == vtn_base_type_image ||
                 arg_type->base_type == vtn_base_type_sampler) {
         struct vtn_pointer *pointer =
            vtn_value(b, arg_id, vtn_value_type_pointer)->pointer;
         call->params[param_idx++] =
            nir_src_for_ssa(vtn_pointer_to_ssa(b, pointer));
      } else {
         vtn_ssa_value_add_to_call_params(b, vtn_ssa_value(b, arg_id),
                                          arg_type, call, &param_idx);
      }
   }
   assert(param_idx == call->num_params);

   nir_builder_instr_insert(&b->nb, &call->instr);

   if (ret_type->base_type == vtn_base_type_void) {
      vtn_push_value(b, w[2], vtn_value_type_undef);
   } else {
      vtn_push_ssa(b, w[2], res_type, vtn_local_load(b, ret_deref));
   }
}

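/* First pass over a function body.  This builds the vtn_function and
 * vtn_block data structures, creates the nir_function with its flattened
 * parameter list, and records each block's label, merge, and branch words.
 * Function parameters are lowered to nir_load_param here; the rest of each
 * block body is left for the later emit pass.
 */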
static bool
vtn_cfg_handle_prepass_instruction(struct vtn_builder *b, SpvOp opcode,
                                   const uint32_t *w, unsigned count)
{
   switch (opcode) {
   case SpvOpFunction: {
      vtn_assert(b->func == NULL);
      b->func = rzalloc(b, struct vtn_function);

      list_inithead(&b->func->body);
      b->func->control = w[3];

      MAYBE_UNUSED const struct glsl_type *result_type =
         vtn_value(b, w[1], vtn_value_type_type)->type->type;
      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_function);
      val->func = b->func;

      b->func->type = vtn_value(b, w[4], vtn_value_type_type)->type;
      const struct vtn_type *func_type = b->func->type;

      vtn_assert(func_type->return_type->type == result_type);

      nir_function *func =
         nir_function_create(b->shader, ralloc_strdup(b->shader, val->name));

      unsigned num_params = 0;
      for (unsigned i = 0; i < func_type->length; i++)
         num_params += vtn_type_count_function_params(func_type->params[i]);

      /* Add one parameter for the function return value */
      if (func_type->return_type->base_type != vtn_base_type_void)
         num_params++;

      func->num_params = num_params;
      func->params = ralloc_array(b->shader, nir_parameter, num_params);

      unsigned idx = 0;
      if (func_type->return_type->base_type != vtn_base_type_void) {
         /* The return value is a regular pointer */
         func->params[idx++] = (nir_parameter) {
            .num_components = 1, .bit_size = 32,
         };
      }

      for (unsigned i = 0; i < func_type->length; i++)
         vtn_type_add_to_function_params(func_type->params[i], func, &idx);
      assert(idx == num_params);

      b->func->impl = nir_function_impl_create(func);
      nir_builder_init(&b->nb, func->impl);
      b->nb.cursor = nir_before_cf_list(&b->func->impl->body);

      b->func_param_idx = 0;

      /* The return value is the first parameter */
      if (func_type->return_type->base_type != vtn_base_type_void)
         b->func_param_idx++;
      break;
   }

   case SpvOpFunctionEnd:
      b->func->end = w;
      b->func = NULL;
      break;

   case SpvOpFunctionParameter: {
      struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;

      vtn_assert(b->func_param_idx < b->func->impl->function->num_params);

      if (type->base_type == vtn_base_type_sampled_image) {
         /* Sampled images are actually two parameters.  The first is the
          * image and the second is the sampler.
          */
         struct vtn_value *val =
            vtn_push_value(b, w[2], vtn_value_type_sampled_image);

         val->sampled_image = ralloc(b, struct vtn_sampled_image);
         val->sampled_image->type = type;

         struct vtn_type *sampler_type = rzalloc(b, struct vtn_type);
         sampler_type->base_type = vtn_base_type_sampler;
         sampler_type->type = glsl_bare_sampler_type();

         val->sampled_image->image =
            vtn_load_param_pointer(b, type, b->func_param_idx++);
         val->sampled_image->sampler =
            vtn_load_param_pointer(b, sampler_type, b->func_param_idx++);
      } else if (type->base_type == vtn_base_type_pointer &&
                 type->type != NULL) {
         /* This is a pointer with an actual storage type */
         struct vtn_value *val =
            vtn_push_value(b, w[2], vtn_value_type_pointer);
         nir_ssa_def *ssa_ptr = nir_load_param(&b->nb, b->func_param_idx++);
         val->pointer = vtn_pointer_from_ssa(b, ssa_ptr, type);
      } else if (type->base_type == vtn_base_type_pointer ||
                 type->base_type == vtn_base_type_image ||
                 type->base_type == vtn_base_type_sampler) {
         struct vtn_value *val =
            vtn_push_value(b, w[2], vtn_value_type_pointer);
         val->pointer =
            vtn_load_param_pointer(b, type, b->func_param_idx++);
      } else {
         /* We're a regular SSA value. */
         struct vtn_ssa_value *value = vtn_create_ssa_value(b, type->type);
         vtn_ssa_value_load_function_param(b, value, type, &b->func_param_idx);
         vtn_push_ssa(b, w[2], type, value);
      }
      break;
   }

   case SpvOpLabel: {
      vtn_assert(b->block == NULL);
      b->block = rzalloc(b, struct vtn_block);
      b->block->node.type = vtn_cf_node_type_block;
      b->block->label = w;
      vtn_push_value(b, w[1], vtn_value_type_block)->block = b->block;

      if (b->func->start_block == NULL) {
         /* This is the first block encountered for this function.  In this
          * case, we set the start block and add it to the list of
          * implemented functions that we'll walk later.
          */
         b->func->start_block = b->block;
         exec_list_push_tail(&b->functions, &b->func->node);
      }
      break;
   }

   case SpvOpSelectionMerge:
   case SpvOpLoopMerge:
      vtn_assert(b->block && b->block->merge == NULL);
      b->block->merge = w;
      break;

   case SpvOpBranch:
   case SpvOpBranchConditional:
   case SpvOpSwitch:
   case SpvOpKill:
   case SpvOpReturn:
   case SpvOpReturnValue:
   case SpvOpUnreachable:
      vtn_assert(b->block && b->block->branch == NULL);
      b->block->branch = w;
      b->block = NULL;
      break;

   default:
      /* Continue on as per normal */
      return true;
   }

   return true;
}

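/* Records one OpSwitch target.  Targets that simply jump to the break block
 * are skipped; otherwise a vtn_case is created for the target block if it
 * does not have one yet and the literal value (or the default flag) is
 * added to it.
 */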
static void
vtn_add_case(struct vtn_builder *b, struct vtn_switch *swtch,
             struct vtn_block *break_block,
             uint32_t block_id, uint64_t val, bool is_default)
{
   struct vtn_block *case_block =
      vtn_value(b, block_id, vtn_value_type_block)->block;

   /* Don't create dummy cases that just break */
   if (case_block == break_block)
      return;

   if (case_block->switch_case == NULL) {
      struct vtn_case *c = ralloc(b, struct vtn_case);

      list_inithead(&c->body);
      c->start_block = case_block;
      c->fallthrough = NULL;
      util_dynarray_init(&c->values, b);
      c->is_default = false;
      c->visited = false;

      list_addtail(&c->link, &swtch->cases);

      case_block->switch_case = c;
   }

   if (is_default) {
      case_block->switch_case->is_default = true;
   } else {
      util_dynarray_append(&case_block->switch_case->values, uint64_t, val);
   }
}

/* This function performs a depth-first search of the cases and puts them
 * in fall-through order.
 */
static void
vtn_order_case(struct vtn_switch *swtch, struct vtn_case *cse)
{
   if (cse->visited)
      return;

   cse->visited = true;

   list_del(&cse->link);

   if (cse->fallthrough) {
      vtn_order_case(swtch, cse->fallthrough);

      /* If we have a fall-through, place this case right before the case it
       * falls through to.  This ensures that fallthroughs come one after
       * the other.  These two can never get separated because that would
       * imply something else falling through to the same case.  Also, this
       * can't break ordering because the DFS ensures that this case is
       * visited before anything that falls through to it.
       */
      list_addtail(&cse->link, &cse->fallthrough->link);
   } else {
      list_add(&cse->link, &swtch->cases);
   }
}

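/* Classifies a branch to the given block relative to the enclosing switch
 * and loop: case fallthrough, loop break, loop continue, switch break, or a
 * plain branch (vtn_branch_type_none).  Fallthroughs are also recorded on
 * the current case.
 */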
static enum vtn_branch_type
vtn_get_branch_type(struct vtn_builder *b,
                    struct vtn_block *block,
                    struct vtn_case *swcase, struct vtn_block *switch_break,
                    struct vtn_block *loop_break, struct vtn_block *loop_cont)
{
   if (block->switch_case) {
      /* This branch is actually a fallthrough */
      vtn_assert(swcase->fallthrough == NULL ||
                 swcase->fallthrough == block->switch_case);
      swcase->fallthrough = block->switch_case;
      return vtn_branch_type_switch_fallthrough;
   } else if (block == loop_break) {
      return vtn_branch_type_loop_break;
   } else if (block == loop_cont) {
      return vtn_branch_type_loop_continue;
   } else if (block == switch_break) {
      return vtn_branch_type_switch_break;
   } else {
      return vtn_branch_type_none;
   }
}

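/* Walks the blocks from start up to (but not including) end and builds the
 * structured CF node list: loops come from OpLoopMerge, ifs from
 * OpBranchConditional, and switches from OpSwitch.  The walk stops early
 * when it reaches a block that is a break, continue, return, or
 * fallthrough target.
 */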
static void
vtn_cfg_walk_blocks(struct vtn_builder *b, struct list_head *cf_list,
                    struct vtn_block *start, struct vtn_case *switch_case,
                    struct vtn_block *switch_break,
                    struct vtn_block *loop_break, struct vtn_block *loop_cont,
                    struct vtn_block *end)
{
   struct vtn_block *block = start;
   while (block != end) {
      if (block->merge && (*block->merge & SpvOpCodeMask) == SpvOpLoopMerge &&
          !block->loop) {
         struct vtn_loop *loop = ralloc(b, struct vtn_loop);

         loop->node.type = vtn_cf_node_type_loop;
         list_inithead(&loop->body);
         list_inithead(&loop->cont_body);
         loop->control = block->merge[3];

         list_addtail(&loop->node.link, cf_list);
         block->loop = loop;

         struct vtn_block *new_loop_break =
            vtn_value(b, block->merge[1], vtn_value_type_block)->block;
         struct vtn_block *new_loop_cont =
            vtn_value(b, block->merge[2], vtn_value_type_block)->block;

         /* Note: This recursive call will start with the current block as
          * its start block.  If we weren't careful, we would get here
          * again and end up in infinite recursion.  This is why we set
          * block->loop above and check for it before creating one.  This
          * way, we only create the loop once and the second call that
          * tries to handle this loop goes to the cases below and gets
          * handled as a regular block.
          *
          * Note: When we make the recursive walk calls, we pass NULL for
          * the switch break since you have to break out of the loop first.
          * We do, however, still pass the current switch case because it's
          * possible that the merge block for the loop is the start of
          * another case.
          */
         vtn_cfg_walk_blocks(b, &loop->body, block, switch_case, NULL,
                             new_loop_break, new_loop_cont, NULL);
         vtn_cfg_walk_blocks(b, &loop->cont_body, new_loop_cont, NULL, NULL,
                             new_loop_break, NULL, block);

         enum vtn_branch_type branch_type =
            vtn_get_branch_type(b, new_loop_break, switch_case, switch_break,
                                loop_break, loop_cont);

         if (branch_type != vtn_branch_type_none) {
            /* Stop walking through the CFG when this inner loop's break block
             * ends up as the same block as the outer loop's continue block
             * because we are already going to visit it.
             */
            vtn_assert(branch_type == vtn_branch_type_loop_continue);
            return;
         }

         block = new_loop_break;
         continue;
      }

      vtn_assert(block->node.link.next == NULL);
      list_addtail(&block->node.link, cf_list);

      switch (*block->branch & SpvOpCodeMask) {
      case SpvOpBranch: {
         struct vtn_block *branch_block =
            vtn_value(b, block->branch[1], vtn_value_type_block)->block;

         block->branch_type = vtn_get_branch_type(b, branch_block,
                                                  switch_case, switch_break,
                                                  loop_break, loop_cont);

         if (block->branch_type != vtn_branch_type_none)
            return;

         block = branch_block;
         continue;
      }

      case SpvOpReturn:
      case SpvOpReturnValue:
         block->branch_type = vtn_branch_type_return;
         return;

      case SpvOpKill:
         block->branch_type = vtn_branch_type_discard;
         return;

      case SpvOpBranchConditional: {
         struct vtn_block *then_block =
            vtn_value(b, block->branch[2], vtn_value_type_block)->block;
         struct vtn_block *else_block =
            vtn_value(b, block->branch[3], vtn_value_type_block)->block;

         struct vtn_if *if_stmt = ralloc(b, struct vtn_if);

         if_stmt->node.type = vtn_cf_node_type_if;
         if_stmt->condition = block->branch[1];
         list_inithead(&if_stmt->then_body);
         list_inithead(&if_stmt->else_body);

         list_addtail(&if_stmt->node.link, cf_list);

         if (block->merge &&
             (*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge) {
            if_stmt->control = block->merge[2];
         }

         if_stmt->then_type = vtn_get_branch_type(b, then_block,
                                                  switch_case, switch_break,
                                                  loop_break, loop_cont);
         if_stmt->else_type = vtn_get_branch_type(b, else_block,
                                                  switch_case, switch_break,
                                                  loop_break, loop_cont);

         if (then_block == else_block) {
            block->branch_type = if_stmt->then_type;
            if (block->branch_type == vtn_branch_type_none) {
               block = then_block;
               continue;
            } else {
               return;
            }
         } else if (if_stmt->then_type == vtn_branch_type_none &&
                    if_stmt->else_type == vtn_branch_type_none) {
            /* Neither side of the if is something we can short-circuit. */
            vtn_assert((*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge);
            struct vtn_block *merge_block =
               vtn_value(b, block->merge[1], vtn_value_type_block)->block;

            vtn_cfg_walk_blocks(b, &if_stmt->then_body, then_block,
                                switch_case, switch_break,
                                loop_break, loop_cont, merge_block);
            vtn_cfg_walk_blocks(b, &if_stmt->else_body, else_block,
                                switch_case, switch_break,
                                loop_break, loop_cont, merge_block);

            enum vtn_branch_type merge_type =
               vtn_get_branch_type(b, merge_block, switch_case, switch_break,
                                   loop_break, loop_cont);
            if (merge_type == vtn_branch_type_none) {
               block = merge_block;
               continue;
            } else {
               return;
            }
         } else if (if_stmt->then_type != vtn_branch_type_none &&
                    if_stmt->else_type != vtn_branch_type_none) {
            /* Both sides were short-circuited.  We're done here. */
            return;
         } else {
            /* Exactly one side of the branch could be short-circuited.
             * We set the branch up as a predicated break/continue and we
             * continue on with the other side as if it were what comes
             * after the if.
             */
            if (if_stmt->then_type == vtn_branch_type_none) {
               block = then_block;
            } else {
               block = else_block;
            }
            continue;
         }
         vtn_fail("Should have returned or continued");
      }

      case SpvOpSwitch: {
         vtn_assert((*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge);
         struct vtn_block *break_block =
            vtn_value(b, block->merge[1], vtn_value_type_block)->block;

         struct vtn_switch *swtch = ralloc(b, struct vtn_switch);

         swtch->node.type = vtn_cf_node_type_switch;
         swtch->selector = block->branch[1];
         list_inithead(&swtch->cases);

         list_addtail(&swtch->node.link, cf_list);

         /* First, we go through and record all of the cases. */
         const uint32_t *branch_end =
            block->branch + (block->branch[0] >> SpvWordCountShift);

         struct vtn_value *cond_val = vtn_untyped_value(b, block->branch[1]);
         vtn_fail_if(!cond_val->type ||
                     cond_val->type->base_type != vtn_base_type_scalar,
                     "Selector of OpSwitch must have a type of OpTypeInt");

         nir_alu_type cond_type =
            nir_get_nir_type_for_glsl_type(cond_val->type->type);
         vtn_fail_if(nir_alu_type_get_base_type(cond_type) != nir_type_int &&
                     nir_alu_type_get_base_type(cond_type) != nir_type_uint,
                     "Selector of OpSwitch must have a type of OpTypeInt");

         bool is_default = true;
         const unsigned bitsize = nir_alu_type_get_type_size(cond_type);
         for (const uint32_t *w = block->branch + 2; w < branch_end;) {
            uint64_t literal = 0;
            if (!is_default) {
               if (bitsize <= 32) {
                  literal = *(w++);
               } else {
                  assert(bitsize == 64);
                  literal = vtn_u64_literal(w);
                  w += 2;
               }
            }

            uint32_t block_id = *(w++);

            vtn_add_case(b, swtch, break_block, block_id, literal, is_default);
            is_default = false;
         }

         /* Now, we go through and walk the blocks.  While we walk through
          * the blocks, we also gather the much-needed fall-through
          * information.
          */
         list_for_each_entry(struct vtn_case, cse, &swtch->cases, link) {
            vtn_assert(cse->start_block != break_block);
            vtn_cfg_walk_blocks(b, &cse->body, cse->start_block, cse,
                                break_block, loop_break, loop_cont, NULL);
         }

         /* Finally, we walk over all of the cases one more time and put
          * them in fall-through order.
          */
         for (const uint32_t *w = block->branch + 2; w < branch_end;) {
            struct vtn_block *case_block =
               vtn_value(b, *w, vtn_value_type_block)->block;

            if (bitsize <= 32) {
               w += 2;
            } else {
               assert(bitsize == 64);
               w += 3;
            }

            if (case_block == break_block)
               continue;

            vtn_assert(case_block->switch_case);

            vtn_order_case(swtch, case_block->switch_case);
         }

         enum vtn_branch_type branch_type =
            vtn_get_branch_type(b, break_block, switch_case, NULL,
                                loop_break, loop_cont);

         if (branch_type != vtn_branch_type_none) {
            /* It is possible that the break is actually the continue block
             * for the containing loop.  In this case, we need to bail and let
             * the loop parsing code handle the continue properly.
             */
            vtn_assert(branch_type == vtn_branch_type_loop_continue);
            return;
         }

         block = break_block;
         continue;
      }

      case SpvOpUnreachable:
         return;

      default:
         vtn_fail("Unhandled opcode");
      }
   }
}

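/* Entry point for CFG construction: runs the prepass over all of the
 * function bodies and then structurizes the blocks of each implemented
 * function.
 */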
void
vtn_build_cfg(struct vtn_builder *b, const uint32_t *words, const uint32_t *end)
{
   vtn_foreach_instruction(b, words, end,
                           vtn_cfg_handle_prepass_instruction);

   foreach_list_typed(struct vtn_function, func, node, &b->functions) {
      vtn_cfg_walk_blocks(b, &func->body, func->start_block,
                          NULL, NULL, NULL, NULL, NULL);
   }
}

static bool
vtn_handle_phis_first_pass(struct vtn_builder *b, SpvOp opcode,
                           const uint32_t *w, unsigned count)
{
   if (opcode == SpvOpLabel)
      return true; /* Nothing to do */

   /* If this isn't a phi node, stop. */
   if (opcode != SpvOpPhi)
      return false;

   /* For handling phi nodes, we do a poor-man's out-of-ssa on the spot.
    * For each phi, we create a variable with the appropriate type and
    * do a load from that variable.  Then, in a second pass, we add
    * stores to that variable to each of the predecessor blocks.
    *
    * We could do something more intelligent here.  However, in order to
    * handle loops and things properly, we really need dominance
    * information.  It would end up basically being the into-SSA
    * algorithm all over again.  It's easier if we just let
    * lower_vars_to_ssa do that for us instead of repeating it here.
    */
   struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
   nir_variable *phi_var =
      nir_local_variable_create(b->nb.impl, type->type, "phi");
   _mesa_hash_table_insert(b->phi_table, w, phi_var);

   vtn_push_ssa(b, w[2], type,
                vtn_local_load(b, nir_build_deref_var(&b->nb, phi_var)));

   return true;
}

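/* Second phi pass: for each OpPhi, stores every source value into the phi's
 * temporary variable at the end of the corresponding predecessor block,
 * right after that block's end_nop marker.
 */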
static bool
vtn_handle_phi_second_pass(struct vtn_builder *b, SpvOp opcode,
                           const uint32_t *w, unsigned count)
{
   if (opcode != SpvOpPhi)
      return true;

   struct hash_entry *phi_entry = _mesa_hash_table_search(b->phi_table, w);
   vtn_assert(phi_entry);
   nir_variable *phi_var = phi_entry->data;

   for (unsigned i = 3; i < count; i += 2) {
      struct vtn_block *pred =
         vtn_value(b, w[i + 1], vtn_value_type_block)->block;

      b->nb.cursor = nir_after_instr(&pred->end_nop->instr);

      struct vtn_ssa_value *src = vtn_ssa_value(b, w[i]);

      vtn_local_store(b, src, nir_build_deref_var(&b->nb, phi_var));
   }

   return true;
}

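/* Emits the NIR for a non-trivial branch type: a switch break clears the
 * fall-through variable, loop breaks/continues and returns become NIR
 * jumps, and OpKill becomes a discard intrinsic.
 */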
static void
vtn_emit_branch(struct vtn_builder *b, enum vtn_branch_type branch_type,
                nir_variable *switch_fall_var, bool *has_switch_break)
{
   switch (branch_type) {
   case vtn_branch_type_switch_break:
      nir_store_var(&b->nb, switch_fall_var, nir_imm_false(&b->nb), 1);
      *has_switch_break = true;
      break;
   case vtn_branch_type_switch_fallthrough:
      break; /* Nothing to do */
   case vtn_branch_type_loop_break:
      nir_jump(&b->nb, nir_jump_break);
      break;
   case vtn_branch_type_loop_continue:
      nir_jump(&b->nb, nir_jump_continue);
      break;
   case vtn_branch_type_return:
      nir_jump(&b->nb, nir_jump_return);
      break;
   case vtn_branch_type_discard: {
      nir_intrinsic_instr *discard =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_discard);
      nir_builder_instr_insert(&b->nb, &discard->instr);
      break;
   }
   default:
      vtn_fail("Invalid branch type");
   }
}

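/* Emits NIR for a structured CF list.  Blocks run the phi prepass and then
 * the per-instruction handler, ifs and loops map onto nir_if and nir_loop,
 * and switches are lowered to a chain of ifs guarded by per-case conditions
 * plus a fall-through variable.
 */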
static void
vtn_emit_cf_list(struct vtn_builder *b, struct list_head *cf_list,
                 nir_variable *switch_fall_var, bool *has_switch_break,
                 vtn_instruction_handler handler)
{
   list_for_each_entry(struct vtn_cf_node, node, cf_list, link) {
      switch (node->type) {
      case vtn_cf_node_type_block: {
         struct vtn_block *block = (struct vtn_block *)node;

         const uint32_t *block_start = block->label;
         const uint32_t *block_end = block->merge ? block->merge :
                                                    block->branch;

         block_start = vtn_foreach_instruction(b, block_start, block_end,
                                               vtn_handle_phis_first_pass);

         vtn_foreach_instruction(b, block_start, block_end, handler);

         block->end_nop = nir_intrinsic_instr_create(b->nb.shader,
                                                     nir_intrinsic_nop);
         nir_builder_instr_insert(&b->nb, &block->end_nop->instr);

         if ((*block->branch & SpvOpCodeMask) == SpvOpReturnValue) {
            vtn_fail_if(b->func->type->return_type->base_type ==
                        vtn_base_type_void,
                        "Return with a value from a function returning void");
            struct vtn_ssa_value *src = vtn_ssa_value(b, block->branch[1]);
            nir_deref_instr *ret_deref =
               nir_build_deref_cast(&b->nb, nir_load_param(&b->nb, 0),
                                    nir_var_local, src->type);
            vtn_local_store(b, src, ret_deref);
         }

         if (block->branch_type != vtn_branch_type_none) {
            vtn_emit_branch(b, block->branch_type,
                            switch_fall_var, has_switch_break);
         }

         break;
      }

      case vtn_cf_node_type_if: {
         struct vtn_if *vtn_if = (struct vtn_if *)node;
         bool sw_break = false;

         nir_if *nif =
            nir_push_if(&b->nb, vtn_ssa_value(b, vtn_if->condition)->def);
         if (vtn_if->then_type == vtn_branch_type_none) {
            vtn_emit_cf_list(b, &vtn_if->then_body,
                             switch_fall_var, &sw_break, handler);
         } else {
            vtn_emit_branch(b, vtn_if->then_type, switch_fall_var, &sw_break);
         }

         nir_push_else(&b->nb, nif);
         if (vtn_if->else_type == vtn_branch_type_none) {
            vtn_emit_cf_list(b, &vtn_if->else_body,
                             switch_fall_var, &sw_break, handler);
         } else {
            vtn_emit_branch(b, vtn_if->else_type, switch_fall_var, &sw_break);
         }

         nir_pop_if(&b->nb, nif);

         /* If we encountered a switch break somewhere inside of the if,
          * then it would have been handled correctly by calling
          * emit_cf_list or emit_branch for the interior.  However, we
          * need to predicate everything following on whether or not we're
          * still going.
          */
         if (sw_break) {
            *has_switch_break = true;
            nir_push_if(&b->nb, nir_load_var(&b->nb, switch_fall_var));
         }
         break;
      }

      case vtn_cf_node_type_loop: {
         struct vtn_loop *vtn_loop = (struct vtn_loop *)node;

         nir_loop *loop = nir_push_loop(&b->nb);
         vtn_emit_cf_list(b, &vtn_loop->body, NULL, NULL, handler);

         if (!list_empty(&vtn_loop->cont_body)) {
            /* If we have a non-trivial continue body then we need to put
             * it at the beginning of the loop with a flag to ensure that
             * it doesn't get executed in the first iteration.
             */
            nir_variable *do_cont =
               nir_local_variable_create(b->nb.impl, glsl_bool_type(), "cont");

            b->nb.cursor = nir_before_cf_node(&loop->cf_node);
            nir_store_var(&b->nb, do_cont, nir_imm_false(&b->nb), 1);

            b->nb.cursor = nir_before_cf_list(&loop->body);

            nir_if *cont_if =
               nir_push_if(&b->nb, nir_load_var(&b->nb, do_cont));

            vtn_emit_cf_list(b, &vtn_loop->cont_body, NULL, NULL, handler);

            nir_pop_if(&b->nb, cont_if);

            nir_store_var(&b->nb, do_cont, nir_imm_true(&b->nb), 1);

            b->has_loop_continue = true;
         }

         nir_pop_loop(&b->nb, loop);
         break;
      }

      case vtn_cf_node_type_switch: {
         struct vtn_switch *vtn_switch = (struct vtn_switch *)node;

         /* First, we create a variable to keep track of whether or not the
          * switch is still going at any given point.  Any switch breaks
          * will set this variable to false.
          */
         nir_variable *fall_var =
            nir_local_variable_create(b->nb.impl, glsl_bool_type(), "fall");
         nir_store_var(&b->nb, fall_var, nir_imm_false(&b->nb), 1);

         /* Next, we gather up all of the conditions.  We have to do this
          * up-front because we also need to build an "any" condition so
          * that we can use !any for default.
          */
         const int num_cases = list_length(&vtn_switch->cases);
         NIR_VLA(nir_ssa_def *, conditions, num_cases);

         nir_ssa_def *sel = vtn_ssa_value(b, vtn_switch->selector)->def;
         /* An accumulation of all conditions.  Used for the default */
         nir_ssa_def *any = NULL;

         int i = 0;
         list_for_each_entry(struct vtn_case, cse, &vtn_switch->cases, link) {
            if (cse->is_default) {
               conditions[i++] = NULL;
               continue;
            }

            nir_ssa_def *cond = NULL;
            util_dynarray_foreach(&cse->values, uint64_t, val) {
               nir_ssa_def *imm = nir_imm_intN_t(&b->nb, *val, sel->bit_size);
               nir_ssa_def *is_val = nir_ieq(&b->nb, sel, imm);

               cond = cond ? nir_ior(&b->nb, cond, is_val) : is_val;
            }

            any = any ? nir_ior(&b->nb, any, cond) : cond;
            conditions[i++] = cond;
         }
         vtn_assert(i == num_cases);

         /* Now we can walk the list of cases and actually emit code */
         i = 0;
         list_for_each_entry(struct vtn_case, cse, &vtn_switch->cases, link) {
            /* Figure out the condition */
            nir_ssa_def *cond = conditions[i++];
            if (cse->is_default) {
               vtn_assert(cond == NULL);
               cond = nir_inot(&b->nb, any);
            }
            /* Take fallthrough into account */
            cond = nir_ior(&b->nb, cond, nir_load_var(&b->nb, fall_var));

            nir_if *case_if = nir_push_if(&b->nb, cond);

            bool has_break = false;
            nir_store_var(&b->nb, fall_var, nir_imm_true(&b->nb), 1);
            vtn_emit_cf_list(b, &cse->body, fall_var, &has_break, handler);
            (void)has_break; /* We don't care */

            nir_pop_if(&b->nb, case_if);
         }
         vtn_assert(i == num_cases);

         break;
      }

      default:
         vtn_fail("Invalid CF node type");
      }
   }
}

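/* Emits the NIR for one SPIR-V function: walks its structured CF list,
 * resolves phis in a second pass, and repairs SSA if any loop continue
 * body was moved in front of its loop body.
 */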
void
vtn_function_emit(struct vtn_builder *b, struct vtn_function *func,
                  vtn_instruction_handler instruction_handler)
{
   nir_builder_init(&b->nb, func->impl);
   b->func = func;
   b->nb.cursor = nir_after_cf_list(&func->impl->body);
   b->has_loop_continue = false;
   b->phi_table = _mesa_hash_table_create(b, _mesa_hash_pointer,
                                          _mesa_key_pointer_equal);

   vtn_emit_cf_list(b, &func->body, NULL, NULL, instruction_handler);

   vtn_foreach_instruction(b, func->start_block->label, func->end,
                           vtn_handle_phi_second_pass);

   /* Continue blocks for loops get inserted before the body of the loop
    * but instructions in the continue may use SSA defs in the loop body.
    * Therefore, we need to repair SSA to insert the needed phi nodes.
    */
   if (b->has_loop_continue)
      nir_repair_ssa_impl(func->impl);

   func->emitted = true;
}