1/*
2 * Copyright © 2015 Intel Corporation
3 * Copyright © 2019 Valve Corporation
4 *
5 * Permission is hereby granted, free of charge, to any person obtaining a
6 * copy of this software and associated documentation files (the "Software"),
7 * to deal in the Software without restriction, including without limitation
8 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
9 * and/or sell copies of the Software, and to permit persons to whom the
10 * Software is furnished to do so, subject to the following conditions:
11 *
12 * The above copyright notice and this permission notice (including the next
13 * paragraph) shall be included in all copies or substantial portions of the
14 * Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22 * IN THE SOFTWARE.
23 *
24 * Authors:
25 *    Jason Ekstrand (jason@jlekstrand.net)
 *    Samuel Pitoiset (samuel.pitoiset@gmail.com)
27 */
28
29#include "nir.h"
30#include "nir_builder.h"
31
32static nir_ssa_def *
33lower_frexp_sig(nir_builder *b, nir_ssa_def *x)
34{
35   nir_ssa_def *abs_x = nir_fabs(b, x);
36   nir_ssa_def *zero = nir_imm_floatN_t(b, 0, x->bit_size);
37   nir_ssa_def *sign_mantissa_mask, *exponent_value;
38   nir_ssa_def *is_not_zero = nir_fne(b, abs_x, zero);
39
40   switch (x->bit_size) {
41   case 16:
42      /* Half-precision floating-point values are stored as
43       *   1 sign bit;
44       *   5 exponent bits;
45       *   10 mantissa bits.
46       *
47       * An exponent shift of 10 will shift the mantissa out, leaving only the
48       * exponent and sign bit (which itself may be zero, if the absolute value
49       * was taken before the bitcast and shift).
50       */
51      sign_mantissa_mask = nir_imm_intN_t(b, 0x83ffu, 16);
52      /* Exponent of floating-point values in the range [0.5, 1.0). */
53      exponent_value = nir_imm_intN_t(b, 0x3800u, 16);
54      break;
55   case 32:
56      /* Single-precision floating-point values are stored as
57       *   1 sign bit;
58       *   8 exponent bits;
59       *   23 mantissa bits.
60       *
61       * An exponent shift of 23 will shift the mantissa out, leaving only the
62       * exponent and sign bit (which itself may be zero, if the absolute value
63       * was taken before the bitcast and shift.
64       */
65      sign_mantissa_mask = nir_imm_int(b, 0x807fffffu);
66      /* Exponent of floating-point values in the range [0.5, 1.0). */
67      exponent_value = nir_imm_int(b, 0x3f000000u);
68      break;
69   case 64:
70      /* Double-precision floating-point values are stored as
71       *   1 sign bit;
72       *   11 exponent bits;
73       *   52 mantissa bits.
74       *
75       * An exponent shift of 20 will shift the remaining mantissa bits out,
76       * leaving only the exponent and sign bit (which itself may be zero, if
77       * the absolute value was taken before the bitcast and shift.
78       */
79      sign_mantissa_mask = nir_imm_int(b, 0x800fffffu);
80      /* Exponent of floating-point values in the range [0.5, 1.0). */
81      exponent_value = nir_imm_int(b, 0x3fe00000u);
82      break;
83   default:
84      unreachable("Invalid bitsize");
85   }
86
87   if (x->bit_size == 64) {
88      /* We only need to deal with the exponent so first we extract the upper
89       * 32 bits using nir_unpack_64_2x32_split_y.
90       */
91      nir_ssa_def *upper_x = nir_unpack_64_2x32_split_y(b, x);
92      nir_ssa_def *zero32 = nir_imm_int(b, 0);
93
94      nir_ssa_def *new_upper =
95         nir_ior(b, nir_iand(b, upper_x, sign_mantissa_mask),
96                    nir_bcsel(b, is_not_zero, exponent_value, zero32));
97
98      nir_ssa_def *lower_x = nir_unpack_64_2x32_split_x(b, x);
99
100      return nir_pack_64_2x32_split(b, lower_x, new_upper);
101   } else {
102      return nir_ior(b, nir_iand(b, x, sign_mantissa_mask),
103                        nir_bcsel(b, is_not_zero, exponent_value, zero));
104   }
105}
106
107static nir_ssa_def *
108lower_frexp_exp(nir_builder *b, nir_ssa_def *x)
109{
110   nir_ssa_def *abs_x = nir_fabs(b, x);
111   nir_ssa_def *zero = nir_imm_floatN_t(b, 0, x->bit_size);
112   nir_ssa_def *is_not_zero = nir_fne(b, abs_x, zero);
113   nir_ssa_def *exponent;
114
115   switch (x->bit_size) {
116   case 16: {
117      nir_ssa_def *exponent_shift = nir_imm_int(b, 10);
118      nir_ssa_def *exponent_bias = nir_imm_intN_t(b, -14, 16);
119
120      /* Significand return must be of the same type as the input, but the
121       * exponent must be a 32-bit integer.
122       */
123      exponent = nir_i2i32(b, nir_iadd(b, nir_ushr(b, abs_x, exponent_shift),
124                              nir_bcsel(b, is_not_zero, exponent_bias, zero)));
125      break;
126   }
127   case 32: {
128      nir_ssa_def *exponent_shift = nir_imm_int(b, 23);
129      nir_ssa_def *exponent_bias = nir_imm_int(b, -126);
130
131      exponent = nir_iadd(b, nir_ushr(b, abs_x, exponent_shift),
132                             nir_bcsel(b, is_not_zero, exponent_bias, zero));
133      break;
134   }
135   case 64: {
136      nir_ssa_def *exponent_shift = nir_imm_int(b, 20);
137      nir_ssa_def *exponent_bias = nir_imm_int(b, -1022);
138
139      nir_ssa_def *zero32 = nir_imm_int(b, 0);
140      nir_ssa_def *abs_upper_x = nir_unpack_64_2x32_split_y(b, abs_x);
141
142      exponent = nir_iadd(b, nir_ushr(b, abs_upper_x, exponent_shift),
143                             nir_bcsel(b, is_not_zero, exponent_bias, zero32));
144      break;
145   }
146   default:
147      unreachable("Invalid bitsize");
148   }
149
150   return exponent;
151}
152
153static bool
154lower_frexp_impl(nir_function_impl *impl)
155{
156   bool progress = false;
157
158   nir_builder b;
159   nir_builder_init(&b, impl);
160
161   nir_foreach_block(block, impl) {
162      nir_foreach_instr_safe(instr, block) {
163         if (instr->type != nir_instr_type_alu)
164            continue;
165
166         nir_alu_instr *alu_instr = nir_instr_as_alu(instr);
167         nir_ssa_def *lower;
168
169         b.cursor = nir_before_instr(instr);
170
171         switch (alu_instr->op) {
172         case nir_op_frexp_sig:
173            lower = lower_frexp_sig(&b, nir_ssa_for_alu_src(&b, alu_instr, 0));
174            break;
175         case nir_op_frexp_exp:
176            lower = lower_frexp_exp(&b, nir_ssa_for_alu_src(&b, alu_instr, 0));
177            break;
178         default:
179            continue;
180         }
181
182         nir_ssa_def_rewrite_uses(&alu_instr->dest.dest.ssa,
183                                  nir_src_for_ssa(lower));
184         nir_instr_remove(instr);
185         progress = true;
186      }
187   }
188
189   if (progress) {
190      nir_metadata_preserve(impl, nir_metadata_block_index |
191                                  nir_metadata_dominance);
192   }
193
194   return progress;
195}
196
197bool
198nir_lower_frexp(nir_shader *shader)
199{
200   bool progress = false;
201
202   nir_foreach_function(function, shader) {
203      if (function->impl)
204         progress |= lower_frexp_impl(function->impl);
205   }
206
207   return progress;
208}
209