1/*
2 * Copyright © 2018 Red Hat
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 *    Rob Clark (robdclark@gmail.com)
25 */
26
27#include "math.h"
28
29#include "nir/nir_builtin_builder.h"
30
31#include "vtn_private.h"
32#include "OpenCL.std.h"
33
34typedef nir_ssa_def *(*nir_handler)(struct vtn_builder *b, enum OpenCLstd opcode,
35                                    unsigned num_srcs, nir_ssa_def **srcs,
36                                    const struct glsl_type *dest_type);
37
38static void
39handle_instr(struct vtn_builder *b, enum OpenCLstd opcode, const uint32_t *w,
40             unsigned count, nir_handler handler)
41{
42   const struct glsl_type *dest_type =
43      vtn_value(b, w[1], vtn_value_type_type)->type->type;
44
45   unsigned num_srcs = count - 5;
46   nir_ssa_def *srcs[3] = { NULL };
47   vtn_assert(num_srcs <= ARRAY_SIZE(srcs));
48   for (unsigned i = 0; i < num_srcs; i++) {
49      srcs[i] = vtn_ssa_value(b, w[i + 5])->def;
50   }
51
52   nir_ssa_def *result = handler(b, opcode, num_srcs, srcs, dest_type);
53   if (result) {
54      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
55      val->ssa = vtn_create_ssa_value(b, dest_type);
56      val->ssa->def = result;
57   } else {
58      vtn_assert(dest_type == glsl_void_type());
59   }
60}
61
62static nir_op
63nir_alu_op_for_opencl_opcode(struct vtn_builder *b, enum OpenCLstd opcode)
64{
65   switch (opcode) {
66   case Fabs: return nir_op_fabs;
67   case SAbs: return nir_op_iabs;
68   case SAdd_sat: return nir_op_iadd_sat;
69   case UAdd_sat: return nir_op_uadd_sat;
70   case Ceil: return nir_op_fceil;
71   case Cos: return nir_op_fcos;
72   case Exp2: return nir_op_fexp2;
73   case Log2: return nir_op_flog2;
74   case Floor: return nir_op_ffloor;
75   case SHadd: return nir_op_ihadd;
76   case UHadd: return nir_op_uhadd;
77   case Fma: return nir_op_ffma;
78   case Fmax: return nir_op_fmax;
79   case SMax: return nir_op_imax;
80   case UMax: return nir_op_umax;
81   case Fmin: return nir_op_fmin;
82   case SMin: return nir_op_imin;
83   case UMin: return nir_op_umin;
84   case Fmod: return nir_op_fmod;
85   case Mix: return nir_op_flrp;
86   case SMul_hi: return nir_op_imul_high;
87   case UMul_hi: return nir_op_umul_high;
88   case Popcount: return nir_op_bit_count;
89   case Pow: return nir_op_fpow;
90   case Remainder: return nir_op_frem;
91   case SRhadd: return nir_op_irhadd;
92   case URhadd: return nir_op_urhadd;
93   case Rsqrt: return nir_op_frsq;
94   case Sign: return nir_op_fsign;
95   case Sin: return nir_op_fsin;
96   case Sqrt: return nir_op_fsqrt;
97   case SSub_sat: return nir_op_isub_sat;
98   case USub_sat: return nir_op_usub_sat;
99   case Trunc: return nir_op_ftrunc;
100   /* uhm... */
101   case UAbs: return nir_op_imov;
102   default:
103      vtn_fail("No NIR equivalent");
104   }
105}
106
107static nir_ssa_def *
108handle_alu(struct vtn_builder *b, enum OpenCLstd opcode, unsigned num_srcs,
109           nir_ssa_def **srcs, const struct glsl_type *dest_type)
110{
111   return nir_build_alu(&b->nb, nir_alu_op_for_opencl_opcode(b, opcode),
112                        srcs[0], srcs[1], srcs[2], NULL);
113}
114
115static nir_ssa_def *
116handle_special(struct vtn_builder *b, enum OpenCLstd opcode, unsigned num_srcs,
117               nir_ssa_def **srcs, const struct glsl_type *dest_type)
118{
119   nir_builder *nb = &b->nb;
120
121   switch (opcode) {
122   case SAbs_diff:
123      return nir_iabs_diff(nb, srcs[0], srcs[1]);
124   case UAbs_diff:
125      return nir_uabs_diff(nb, srcs[0], srcs[1]);
126   case Bitselect:
127      return nir_bitselect(nb, srcs[0], srcs[1], srcs[2]);
128   case FClamp:
129      return nir_fclamp(nb, srcs[0], srcs[1], srcs[2]);
130   case SClamp:
131      return nir_iclamp(nb, srcs[0], srcs[1], srcs[2]);
132   case UClamp:
133      return nir_uclamp(nb, srcs[0], srcs[1], srcs[2]);
134   case Copysign:
135      return nir_copysign(nb, srcs[0], srcs[1]);
136   case Cross:
137      if (glsl_get_components(dest_type) == 4)
138         return nir_cross4(nb, srcs[0], srcs[1]);
139      return nir_cross3(nb, srcs[0], srcs[1]);
140   case Degrees:
141      return nir_degrees(nb, srcs[0]);
142   case Fdim:
143      return nir_fdim(nb, srcs[0], srcs[1]);
144   case Distance:
145      return nir_distance(nb, srcs[0], srcs[1]);
146   case Fast_distance:
147      return nir_fast_distance(nb, srcs[0], srcs[1]);
148   case Fast_length:
149      return nir_fast_length(nb, srcs[0]);
150   case Fast_normalize:
151      return nir_fast_normalize(nb, srcs[0]);
152   case Length:
153      return nir_length(nb, srcs[0]);
154   case Mad:
155      return nir_fmad(nb, srcs[0], srcs[1], srcs[2]);
156   case Maxmag:
157      return nir_maxmag(nb, srcs[0], srcs[1]);
158   case Minmag:
159      return nir_minmag(nb, srcs[0], srcs[1]);
160   case Nan:
161      return nir_nan(nb, srcs[0]);
162   case Nextafter:
163      return nir_nextafter(nb, srcs[0], srcs[1]);
164   case Normalize:
165      return nir_normalize(nb, srcs[0]);
166   case Radians:
167      return nir_radians(nb, srcs[0]);
168   case Rotate:
169      return nir_rotate(nb, srcs[0], srcs[1]);
170   case Smoothstep:
171      return nir_smoothstep(nb, srcs[0], srcs[1], srcs[2]);
172   case Select:
173      return nir_select(nb, srcs[0], srcs[1], srcs[2]);
174   case Step:
175      return nir_sge(nb, srcs[1], srcs[0]);
176   case S_Upsample:
177   case U_Upsample:
178      return nir_upsample(nb, srcs[0], srcs[1]);
179   default:
180      vtn_fail("No NIR equivalent");
181      return NULL;
182   }
183}
184
185static void
186_handle_v_load_store(struct vtn_builder *b, enum OpenCLstd opcode,
187                     const uint32_t *w, unsigned count, bool load)
188{
189   struct vtn_type *type;
190   if (load)
191      type = vtn_value(b, w[1], vtn_value_type_type)->type;
192   else
193      type = vtn_untyped_value(b, w[5])->type;
194   unsigned a = load ? 0 : 1;
195
196   const struct glsl_type *dest_type = type->type;
197   unsigned components = glsl_get_vector_elements(dest_type);
198   unsigned stride = components * glsl_get_bit_size(dest_type) / 8;
199
200   nir_ssa_def *offset = vtn_ssa_value(b, w[5 + a])->def;
201   struct vtn_value *p = vtn_value(b, w[6 + a], vtn_value_type_pointer);
202
203   nir_deref_instr *deref = vtn_pointer_to_deref(b, p->pointer);
204
205   /* 1. cast to vec type with adjusted stride */
206   deref = nir_build_deref_cast(&b->nb, &deref->dest.ssa, deref->mode,
207                                dest_type, stride);
208   /* 2. deref ptr_as_array */
209   deref = nir_build_deref_ptr_as_array(&b->nb, deref, offset);
210
211   if (load) {
212      struct vtn_ssa_value *val = vtn_local_load(b, deref, p->type->access);
213      vtn_push_ssa(b, w[2], type, val);
214   } else {
215      struct vtn_ssa_value *val = vtn_ssa_value(b, w[5]);
216      vtn_local_store(b, val, deref, p->type->access);
217   }
218}
219
220static void
221vtn_handle_opencl_vload(struct vtn_builder *b, enum OpenCLstd opcode,
222                        const uint32_t *w, unsigned count)
223{
224   _handle_v_load_store(b, opcode, w, count, true);
225}
226
227static void
228vtn_handle_opencl_vstore(struct vtn_builder *b, enum OpenCLstd opcode,
229                         const uint32_t *w, unsigned count)
230{
231   _handle_v_load_store(b, opcode, w, count, false);
232}
233
234static nir_ssa_def *
235handle_printf(struct vtn_builder *b, enum OpenCLstd opcode, unsigned num_srcs,
236              nir_ssa_def **srcs, const struct glsl_type *dest_type)
237{
238   /* hahah, yeah, right.. */
239   return nir_imm_int(&b->nb, -1);
240}
241
242bool
243vtn_handle_opencl_instruction(struct vtn_builder *b, uint32_t ext_opcode,
244                              const uint32_t *w, unsigned count)
245{
246   switch (ext_opcode) {
247   case Fabs:
248   case SAbs:
249   case UAbs:
250   case SAdd_sat:
251   case UAdd_sat:
252   case Ceil:
253   case Cos:
254   case Exp2:
255   case Log2:
256   case Floor:
257   case Fma:
258   case Fmax:
259   case SHadd:
260   case UHadd:
261   case SMax:
262   case UMax:
263   case Fmin:
264   case SMin:
265   case UMin:
266   case Mix:
267   case Fmod:
268   case SMul_hi:
269   case UMul_hi:
270   case Popcount:
271   case Pow:
272   case Remainder:
273   case SRhadd:
274   case URhadd:
275   case Rsqrt:
276   case Sign:
277   case Sin:
278   case Sqrt:
279   case SSub_sat:
280   case USub_sat:
281   case Trunc:
282      handle_instr(b, ext_opcode, w, count, handle_alu);
283      return true;
284   case SAbs_diff:
285   case UAbs_diff:
286   case Bitselect:
287   case FClamp:
288   case SClamp:
289   case UClamp:
290   case Copysign:
291   case Cross:
292   case Degrees:
293   case Fdim:
294   case Distance:
295   case Fast_distance:
296   case Fast_length:
297   case Fast_normalize:
298   case Length:
299   case Mad:
300   case Maxmag:
301   case Minmag:
302   case Nan:
303   case Nextafter:
304   case Normalize:
305   case Radians:
306   case Rotate:
307   case Select:
308   case Step:
309   case Smoothstep:
310   case S_Upsample:
311   case U_Upsample:
312      handle_instr(b, ext_opcode, w, count, handle_special);
313      return true;
314   case Vloadn:
315      vtn_handle_opencl_vload(b, ext_opcode, w, count);
316      return true;
317   case Vstoren:
318      vtn_handle_opencl_vstore(b, ext_opcode, w, count);
319      return true;
320   case Printf:
321      handle_instr(b, ext_opcode, w, count, handle_printf);
322      return true;
323   case Prefetch:
324      /* TODO maybe add a nir instruction for this? */
325      return true;
326   default:
327      vtn_fail("unhandled opencl opc: %u\n", ext_opcode);
328      return false;
329   }
330}
331