1b8e80941Smrg#
2b8e80941Smrg# Copyright (C) 2014 Intel Corporation
3b8e80941Smrg#
4b8e80941Smrg# Permission is hereby granted, free of charge, to any person obtaining a
5b8e80941Smrg# copy of this software and associated documentation files (the "Software"),
6b8e80941Smrg# to deal in the Software without restriction, including without limitation
7b8e80941Smrg# the rights to use, copy, modify, merge, publish, distribute, sublicense,
8b8e80941Smrg# and/or sell copies of the Software, and to permit persons to whom the
9b8e80941Smrg# Software is furnished to do so, subject to the following conditions:
10b8e80941Smrg#
11b8e80941Smrg# The above copyright notice and this permission notice (including the next
12b8e80941Smrg# paragraph) shall be included in all copies or substantial portions of the
13b8e80941Smrg# Software.
14b8e80941Smrg#
15b8e80941Smrg# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16b8e80941Smrg# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17b8e80941Smrg# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18b8e80941Smrg# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19b8e80941Smrg# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20b8e80941Smrg# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21b8e80941Smrg# IN THE SOFTWARE.
22b8e80941Smrg#
23b8e80941Smrg# Authors:
24b8e80941Smrg#    Jason Ekstrand (jason@jlekstrand.net)
25b8e80941Smrg
26b8e80941Smrgfrom __future__ import print_function
27b8e80941Smrgimport ast
28b8e80941Smrgfrom collections import defaultdict
29b8e80941Smrgimport itertools
30b8e80941Smrgimport struct
31b8e80941Smrgimport sys
32b8e80941Smrgimport mako.template
33b8e80941Smrgimport re
34b8e80941Smrgimport traceback
35b8e80941Smrg
36b8e80941Smrgfrom nir_opcodes import opcodes, type_sizes
37b8e80941Smrg
# Generic, size-agnostic conversion opcodes.  They exist only inside
# nir_search (not as real NIR opcodes); each maps to the base type of its
# destination.
conv_opcode_types = {
    'i2f': 'float',
    'u2f': 'float',
    'f2f': 'float',
    'f2u': 'uint',
    'f2i': 'int',
    'u2u': 'uint',
    'i2i': 'int',
    'b2f': 'float',
    'b2i': 'int',
    'i2b': 'bool',
    'f2b': 'bool',
}

def get_c_opcode(op):
   """Return the C enumerator name used for opcode *op* in generated code.

   Opcodes listed in conv_opcode_types get the ``nir_search_op_`` prefix;
   every other opcode gets the regular ``nir_op_`` prefix.
   """
   prefix = 'nir_search_op_' if op in conv_opcode_types else 'nir_op_'
   return prefix + op
59b8e80941Smrg
60b8e80941Smrg
# Python 2/3 compatibility: the integer types accepted as constants and the
# text type that rule strings are normalized to.
if sys.version_info >= (3, 0):
    integer_types = (int, )
    string_type = str
else:
    integer_types = (int, long)
    string_type = unicode
68b8e80941Smrg
# Splits a NIR type name such as "float32" into its base type and an
# optional explicit bit size.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   """Return the explicit bit size in NIR type name *type_str*, or 0.

   ``'float32'`` gives 32, while an unsized name such as ``'int'`` gives 0.
   The name must begin with int, uint, bool or float; this is checked with
   an assert.
   """
   match = _type_re.match(type_str)
   assert match.group('type')

   bits = match.group('bits')
   return int(bits) if bits is not None else 0
79b8e80941Smrg
class VarSet(object):
   """A set of named variables, each assigned a unique integer id.

   Ids are handed out in first-use order, starting at 0.  After lock() has
   been called, looking up a name that was never seen fails an assertion
   instead of allocating a new id.
   """

   def __init__(self):
      self.names = {}                 # variable name -> id
      self.ids = itertools.count()    # source of fresh ids
      self.immutable = False          # set by lock()

   def __getitem__(self, name):
      try:
         return self.names[name]
      except KeyError:
         pass

      assert not self.immutable, "Unknown replacement variable: " + name
      new_id = next(self.ids)
      self.names[name] = new_id
      return new_id

   def lock(self):
      """Freeze the set so that lookups of unknown names fail."""
      self.immutable = True
96b8e80941Smrg
class Value(object):
   """Base class for a node of a nir_search search or replace expression.

   Subclasses (Constant, Variable, Expression) set ``self._bit_size`` in their
   constructors.  Value implements the union-find used for bit-size inference
   (get_bit_size / set_bit_size) and renders a node as a static C
   initializer for the matching ``nir_search_*`` struct.
   """

   @staticmethod
   def create(val, name_base, varset):
      """Build the Value subclass that represents the Python object *val*.

      * tuple              -> Expression
      * Expression         -> returned unchanged
      * string             -> Variable (``bytes`` are decoded as UTF-8 first)
      * bool/float/integer -> Constant

      :param name_base: C identifier used for the generated struct.
      :param varset: VarSet that variable names are registered in.
      :raises TypeError: if *val* has any other type.  Reporting it here
         names the offending node, instead of returning None and failing
         later with an unrelated AttributeError.
      """
      if isinstance(val, bytes):
         val = val.decode('utf-8')

      if isinstance(val, tuple):
         return Expression(val, name_base, varset)
      elif isinstance(val, Expression):
         return val
      elif isinstance(val, string_type):
         return Variable(val, name_base, varset)
      elif isinstance(val, (bool, float) + integer_types):
         return Constant(val, name_base)
      else:
         raise TypeError('Cannot build search value {0} from {1!r} '
                         '(unsupported type {2})'.format(
                            name_base, val, type(val).__name__))

   def __init__(self, val, name, type_str):
      """Store the fields shared by every kind of value.

      :param val: the Python object this value was built from; its ``str()``
         is kept for __str__, which error messages use.
      :param name: C identifier of the struct generated for this value.
      :param type_str: "constant", "variable" or "expression"; selects the
         ``nir_search_value_*`` enum and the ``nir_search_*`` C struct type.
      """
      self.in_val = str(val)
      self.name = name
      self.type_str = type_str

   def __str__(self):
      return self.in_val

   def get_bit_size(self):
      """Get the physical bit-size that has been chosen for this value, or if
      there is none, the canonical value which currently represents this
      bit-size class. Variables will be preferred, i.e. if there are any
      variables in the equivalence class, the canonical value will be a
      variable. We do this since we'll need to know which variable each value
      is equivalent to when constructing the replacement expression. This is
      the "find" part of the union-find algorithm.
      """
      bit_size = self

      # Follow the _bit_size links until reaching either a physical size
      # (an int) or a Value whose _bit_size is unset (the class root).
      while isinstance(bit_size, Value):
         if bit_size._bit_size is None:
            break
         bit_size = bit_size._bit_size

      if bit_size is not self:
         # Point self straight at the result so later lookups are shorter.
         self._bit_size = bit_size
      return bit_size

   def set_bit_size(self, other):
      """Make self.get_bit_size() return what other.get_bit_size() returned
      before this call, or simply "other" if it is a concrete bit-size (int).
      This is the "union" part of the union-find algorithm.

      Only the root of self's class is re-pointed.  Callers must not call
      this when self's class is already a physical bit size (an int has no
      _bit_size attribute to assign).
      """

      self_bit_size = self.get_bit_size()
      other_bit_size = other if isinstance(other, int) else other.get_bit_size()

      if self_bit_size == other_bit_size:
         return

      self_bit_size._bit_size = other_bit_size

   @property
   def type_enum(self):
      """Name of the C ``nir_search_value_*`` enumerator for this value."""
      return "nir_search_value_" + self.type_str

   @property
   def c_type(self):
      """Name of the C ``nir_search_*`` struct type for this value."""
      return "nir_search_" + self.type_str

   def __c_name(self, cache):
      # If render() found an identical struct already emitted, the cache maps
      # this value's name to that struct's name; use the remapped name.
      if cache is not None and self.name in cache:
         return cache[self.name]
      else:
         return self.name

   def c_value_ptr(self, cache):
      """C expression for the address of this struct's ``value`` member."""
      return "&{0}.value".format(self.__c_name(cache))

   def c_ptr(self, cache):
      """C expression for the address of this value's struct."""
      return "&{0}".format(self.__c_name(cache))

   @property
   def c_bit_size(self):
      """Bit size emitted into the C initializer.

      A physical size is emitted as-is.  A class represented by a variable is
      encoded as -(index + 1), which cannot collide with a physical size or
      with 0.
      """
      bit_size = self.get_bit_size()
      if isinstance(bit_size, int):
         return bit_size
      elif isinstance(bit_size, Variable):
         return -bit_size.index - 1
      else:
         # If the bit-size class is neither a variable, nor an actual bit-size, then
         # - If it's in the search expression, we don't need to check anything
         # - If it's in the replace expression, either it's ambiguous (in which
         # case we'd reject it), or it equals the bit-size of the search value
         # We represent these cases with a 0 bit-size.
         return 0

   __template = mako.template.Template("""{
   { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   ${val.comm_expr_idx}, ${val.comm_exprs},
   ${val.c_opcode()},
   { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

   def render(self, cache):
      """Return the C source for this value's struct definition.

      *cache* (or None to disable de-duplication) maps already-emitted
      initializer text to the struct name that holds it.  When an identical
      initializer already exists, only a comment is returned and *cache*
      records that ``self.name`` is remapped to the existing struct.
      Otherwise a ``static const`` definition is returned and registered in
      *cache*.
      """
      struct_init = self.__template.render(val=self, cache=cache,
                                           Constant=Constant,
                                           Variable=Variable,
                                           Expression=Expression)
      if cache is not None and struct_init in cache:
         # If it's in the cache, register a name remap in the cache and render
         # only a comment saying it's been remapped
         cache[self.name] = cache[struct_init]
         return "/* {} -> {} in the cache */\n".format(self.name,
                                                       cache[struct_init])
      else:
         if cache is not None:
            cache[struct_init] = self.name
         return "static const {} {} = {}\n".format(self.c_type, self.name,
                                                   struct_init)
223b8e80941Smrg
# "<python literal>[@<bits>]", e.g. "1.0@32".
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   """A literal constant in a search or replace expression.

   *val* is either a Python bool/int/float, or a string of the form
   ``"<python literal>[@<bits>]"`` whose literal is parsed with
   ``ast.literal_eval``.  Boolean constants always have a bit size of 1.
   """

   def __init__(self, val, name):
      Value.__init__(self, val, name, "constant")

      # Accept both the native str and the module's text type, so Python 2
      # unicode strings are parsed too.  Previously they fell through to the
      # else branch unparsed and later failed in hex().
      if isinstance(val, (str, string_type)):
         m = _constant_re.match(val)
         self.value = ast.literal_eval(m.group('value'))
         self._bit_size = int(m.group('bits')) if m.group('bits') else None
      else:
         self.value = val
         self._bit_size = None

      if isinstance(self.value, bool):
         assert self._bit_size is None or self._bit_size == 1
         self._bit_size = 1

   def hex(self):
      """Return the C initializer text for the constant's value.

      Bools become NIR_TRUE/NIR_FALSE, integers use Python's hex(), and
      floats are emitted as the hex of their 64-bit double bit pattern.
      """
      # bool must be tested first: bool is a subclass of int.
      if isinstance(self.value, (bool)):
         return 'NIR_TRUE' if self.value else 'NIR_FALSE'
      if isinstance(self.value, integer_types):
         return hex(self.value)
      elif isinstance(self.value, float):
         i = struct.unpack('Q', struct.pack('d', self.value))[0]
         h = hex(i)

         # On Python 2 this 'L' suffix is automatically added, but not on Python 3
         # Adding it explicitly makes the generated file identical, regardless
         # of the Python version running this script.
         if h[-1] != 'L' and i > sys.maxsize:
            h += 'L'

         return h
      else:
         assert False

   def type(self):
      """Return the nir_type_* base type name, or None for other values."""
      if isinstance(self.value, (bool)):
         return "nir_type_bool"
      elif isinstance(self.value, integer_types):
         return "nir_type_int"
      elif isinstance(self.value, float):
         return "nir_type_float"
268b8e80941Smrg
# "[#]name[@[type][bits]][(cond)]", e.g. "#a@32" or "b@float(is_pos)".
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")

class Variable(Value):
   """A named variable in a search or replace expression.

   Syntax: ``"[#]name[@[type][bits]][(cond)]"``.  A leading ``#`` sets
   ``is_constant``, ``@type``/``@bits`` record a required type and bit size,
   and ``(cond)`` is copied verbatim into the generated C initializer
   (presumably the name of a C condition callback; verify against
   nir_search).  Each distinct name gets its index from *varset*.
   """

   def __init__(self, val, name, varset):
      Value.__init__(self, val, name, "variable")

      m = _var_name_re.match(val)
      assert m and m.group('name') is not None

      self.var_name = m.group('name')

      # Prevent common cases where someone puts quotes around a literal
      # constant.  If we want to support names that have numeric or
      # punctuation characters, we can make the first assertion more flexible.
      assert self.var_name.isalpha()
      # Compare by value.  An identity test ("is not") against a literal would
      # not match the string returned by the regex, so the check would never
      # fire.
      assert self.var_name != 'True', \
         "'True' looks like a quoted literal; use the bool constant True"
      assert self.var_name != 'False', \
         "'False' looks like a quoted literal; use the bool constant False"

      self.is_constant = m.group('const') is not None
      self.cond = m.group('cond')
      self.required_type = m.group('type')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None

      # bool variables default to a 1-bit size when no size is given.
      if self.required_type == 'bool':
         if self._bit_size is not None:
            assert self._bit_size in type_sizes(self.required_type)
         else:
            self._bit_size = 1

      if self.required_type is not None:
         assert self.required_type in ('float', 'bool', 'int', 'uint')

      self.index = varset[self.var_name]

   def type(self):
      """Return the nir_type_* base type required of this variable, or None.

      Both 'int' and 'uint' map to nir_type_int.
      """
      if self.required_type == 'bool':
         return "nir_type_bool"
      elif self.required_type in ('int', 'uint'):
         return "nir_type_int"
      elif self.required_type == 'float':
         return "nir_type_float"
312b8e80941Smrg
# "[~]opcode[@bits][(cond)]", e.g. "~fadd@32".
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
   """An ALU operation node built from a tuple ``(opcode, src0, src1, ...)``.

   The opcode string may carry a leading ``~`` (inexact), an ``@<bits>``
   size and a ``(<cond>)`` condition.  Sources are built recursively with
   Value.create and named ``<name_base>_<index>``.
   """

   def __init__(self, expr, name_base, varset):
      Value.__init__(self, expr, name_base, "expression")
      assert isinstance(expr, tuple)

      parsed = _opcode_re.match(expr[0])
      assert parsed and parsed.group('opcode') is not None

      self.opcode = parsed.group('opcode')
      bits = parsed.group('bits')
      self._bit_size = int(bits) if bits else None
      self.inexact = parsed.group('inexact') is not None
      self.cond = parsed.group('cond')

      sources = []
      for idx, src in enumerate(expr[1:]):
         src_name = "{0}_{1}".format(name_base, idx)
         sources.append(Value.create(src, src_name, varset))
      self.sources = sources

      # Generic conversion opcodes are unsized by definition.
      assert self.opcode not in conv_opcode_types or self._bit_size is None, \
         ("Expression cannot use an unsized conversion opcode with "
          "an explicit size; that's silly.")

      self.__index_comm_exprs(0)

   def __index_comm_exprs(self, base_idx):
      """Number the commutative expressions of this subtree, in pre-order.

      Sets ``comm_expr_idx`` (this node's index, or -1 if its opcode is not
      commutative) and ``comm_exprs`` (the count of commutative expressions
      in the subtree rooted here).  Returns ``comm_exprs``.
      """
      is_commutative = (self.opcode not in conv_opcode_types and
                        "commutative" in opcodes[self.opcode].algebraic_properties)
      self.comm_expr_idx = base_idx if is_commutative else -1
      self.comm_exprs = 1 if is_commutative else 0

      for child in (s for s in self.sources if isinstance(s, Expression)):
         child.__index_comm_exprs(base_idx + self.comm_exprs)
         self.comm_exprs += child.comm_exprs

      return self.comm_exprs

   def c_opcode(self):
      """C enumerator name for this expression's opcode."""
      return get_c_opcode(self.opcode)

   def render(self, cache):
      """Render every source first, then this expression's own struct."""
      rendered = [src.render(cache) for src in self.sources]
      return "\n".join(rendered) + super(Expression, self).render(cache)
362b8e80941Smrg
363b8e80941Smrgclass BitSizeValidator(object):
364b8e80941Smrg   """A class for validating bit sizes of expressions.
365b8e80941Smrg
366b8e80941Smrg   NIR supports multiple bit-sizes on expressions in order to handle things
367b8e80941Smrg   such as fp64.  The source and destination of every ALU operation is
368b8e80941Smrg   assigned a type and that type may or may not specify a bit size.  Sources
369b8e80941Smrg   and destinations whose type does not specify a bit size are considered
370b8e80941Smrg   "unsized" and automatically take on the bit size of the corresponding
371b8e80941Smrg   register or SSA value.  NIR has two simple rules for bit sizes that are
372b8e80941Smrg   validated by nir_validator:
373b8e80941Smrg
374b8e80941Smrg    1) A given SSA def or register has a single bit size that is respected by
375b8e80941Smrg       everything that reads from it or writes to it.
376b8e80941Smrg
377b8e80941Smrg    2) The bit sizes of all unsized inputs/outputs on any given ALU
378b8e80941Smrg       instruction must match.  They need not match the sized inputs or
379b8e80941Smrg       outputs but they must match each other.
380b8e80941Smrg
381b8e80941Smrg   In order to keep nir_algebraic relatively simple and easy-to-use,
382b8e80941Smrg   nir_search supports a type of bit-size inference based on the two rules
383b8e80941Smrg   above.  This is similar to type inference in many common programming
384b8e80941Smrg   languages.  If, for instance, you are constructing an add operation and you
385b8e80941Smrg   know the second source is 16-bit, then you know that the other source and
386b8e80941Smrg   the destination must also be 16-bit.  There are, however, cases where this
387b8e80941Smrg   inference can be ambiguous or contradictory.  Consider, for instance, the
388b8e80941Smrg   following transformation:
389b8e80941Smrg
390b8e80941Smrg   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
391b8e80941Smrg
392b8e80941Smrg   This transformation can potentially cause a problem because usub_borrow is
393b8e80941Smrg   well-defined for any bit-size of integer.  However, b2i always generates a
394b8e80941Smrg   32-bit result so it could end up replacing a 64-bit expression with one
395b8e80941Smrg   that takes two 64-bit values and produces a 32-bit value.  As another
396b8e80941Smrg   example, consider this expression:
397b8e80941Smrg
398b8e80941Smrg   (('bcsel', a, b, 0), ('iand', a, b))
399b8e80941Smrg
400b8e80941Smrg   In this case, in the search expression a must be 32-bit but b can
401b8e80941Smrg   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to AND a 32-bit value with a 64-bit value, which would be invalid.
403b8e80941Smrg
404b8e80941Smrg   This class solves that problem by providing a validation layer that proves
405b8e80941Smrg   that a given search-and-replace operation is 100% well-defined before we
406b8e80941Smrg   generate any code.  This ensures that bugs are caught at compile time
407b8e80941Smrg   rather than at run time.
408b8e80941Smrg
409b8e80941Smrg   Each value maintains a "bit-size class", which is either an actual bit size
410b8e80941Smrg   or an equivalence class with other values that must have the same bit size.
411b8e80941Smrg   The validator works by combining bit-size classes with each other according
412b8e80941Smrg   to the NIR rules outlined above, checking that there are no inconsistencies.
413b8e80941Smrg   When doing this for the replacement expression, we make sure to never change
414b8e80941Smrg   the equivalence class of any of the search values. We could make the example
415b8e80941Smrg   transforms above work by doing some extra run-time checking of the search
416b8e80941Smrg   expression, but we make the user specify those constraints themselves, to
417b8e80941Smrg   avoid any surprises. Since the replacement bitsizes can only be connected to
418b8e80941Smrg   the source bitsize via variables (variables must have the same bitsize in
   the source and replacement expressions) or the roots of the expression (the
420b8e80941Smrg   replacement expression must produce the same bit size as the search
421b8e80941Smrg   expression), we prevent merging a variable with anything when processing the
422b8e80941Smrg   replacement expression, or specializing the search bitsize
423b8e80941Smrg   with anything. The former prevents
424b8e80941Smrg
425b8e80941Smrg   (('bcsel', a, b, 0), ('iand', a, b))
426b8e80941Smrg
427b8e80941Smrg   from being allowed, since we'd have to merge the bitsizes for a and b due to
428b8e80941Smrg   the 'iand', while the latter prevents
429b8e80941Smrg
430b8e80941Smrg   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
431b8e80941Smrg
432b8e80941Smrg   from being allowed, since the search expression has the bit size of a and b,
433b8e80941Smrg   which can't be specialized to 32 which is the bitsize of the replace
434b8e80941Smrg   expression. It also prevents something like:
435b8e80941Smrg
436b8e80941Smrg   (('b2i', ('i2b', a)), ('ineq', a, 0))
437b8e80941Smrg
438b8e80941Smrg   since the bitsize of 'b2i', which can be anything, can't be specialized to
439b8e80941Smrg   the bitsize of a.
440b8e80941Smrg
441b8e80941Smrg   After doing all this, we check that every subexpression of the replacement
442b8e80941Smrg   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
   of the search expression, since those are the things that are known when
   constructing the replacement expression. Finally, we record the bitsize
445b8e80941Smrg   needed in nir_search_value so that we know what to do when building the
446b8e80941Smrg   replacement expression.
447b8e80941Smrg   """
448b8e80941Smrg
449b8e80941Smrg   def __init__(self, varset):
450b8e80941Smrg      self._var_classes = [None] * len(varset.names)
451b8e80941Smrg
452b8e80941Smrg   def compare_bitsizes(self, a, b):
453b8e80941Smrg      """Determines which bitsize class is a specialization of the other, or
454b8e80941Smrg      whether neither is. When we merge two different bitsizes, the
455b8e80941Smrg      less-specialized bitsize always points to the more-specialized one, so
456b8e80941Smrg      that calling get_bit_size() always gets you the most specialized bitsize.
457b8e80941Smrg      The specialization partial order is given by:
458b8e80941Smrg      - Physical bitsizes are always the most specialized, and a different
459b8e80941Smrg        bitsize can never specialize another.
460b8e80941Smrg      - In the search expression, variables can always be specialized to each
461b8e80941Smrg        other and to physical bitsizes. In the replace expression, we disallow
462b8e80941Smrg        this to avoid adding extra constraints to the search expression that
463b8e80941Smrg        the user didn't specify.
464b8e80941Smrg      - Expressions and constants without a bitsize can always be specialized to
465b8e80941Smrg        each other and variables, but not the other way around.
466b8e80941Smrg
467b8e80941Smrg        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
468b8e80941Smrg        and None if they are not comparable (neither a <= b nor b <= a).
469b8e80941Smrg      """
470b8e80941Smrg      if isinstance(a, int):
471b8e80941Smrg         if isinstance(b, int):
472b8e80941Smrg            return 0 if a == b else None
473b8e80941Smrg         elif isinstance(b, Variable):
474b8e80941Smrg            return -1 if self.is_search else None
475b8e80941Smrg         else:
476b8e80941Smrg            return -1
477b8e80941Smrg      elif isinstance(a, Variable):
478b8e80941Smrg         if isinstance(b, int):
479b8e80941Smrg            return 1 if self.is_search else None
480b8e80941Smrg         elif isinstance(b, Variable):
481b8e80941Smrg            return 0 if self.is_search or a.index == b.index else None
482b8e80941Smrg         else:
483b8e80941Smrg            return -1
484b8e80941Smrg      else:
485b8e80941Smrg         if isinstance(b, int):
486b8e80941Smrg            return 1
487b8e80941Smrg         elif isinstance(b, Variable):
488b8e80941Smrg            return 1
489b8e80941Smrg         else:
490b8e80941Smrg            return 0
491b8e80941Smrg
492b8e80941Smrg   def unify_bit_size(self, a, b, error_msg):
493b8e80941Smrg      """Record that a must have the same bit-size as b. If both
494b8e80941Smrg      have been assigned conflicting physical bit-sizes, call "error_msg" with
495b8e80941Smrg      the bit-sizes of self and other to get a message and raise an error.
496b8e80941Smrg      In the replace expression, disallow merging variables with other
497b8e80941Smrg      variables and physical bit-sizes as well.
498b8e80941Smrg      """
499b8e80941Smrg      a_bit_size = a.get_bit_size()
500b8e80941Smrg      b_bit_size = b if isinstance(b, int) else b.get_bit_size()
501b8e80941Smrg
502b8e80941Smrg      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)
503b8e80941Smrg
504b8e80941Smrg      assert cmp_result is not None, \
505b8e80941Smrg         error_msg(a_bit_size, b_bit_size)
506b8e80941Smrg
507b8e80941Smrg      if cmp_result < 0:
508b8e80941Smrg         b_bit_size.set_bit_size(a)
509b8e80941Smrg      elif not isinstance(a_bit_size, int):
510b8e80941Smrg         a_bit_size.set_bit_size(b)
511b8e80941Smrg
512b8e80941Smrg   def merge_variables(self, val):
513b8e80941Smrg      """Perform the first part of type inference by merging all the different
514b8e80941Smrg      uses of the same variable. We always do this as if we're in the search
515b8e80941Smrg      expression, even if we're actually not, since otherwise we'd get errors
516b8e80941Smrg      if the search expression specified some constraint but the replace
517b8e80941Smrg      expression didn't, because we'd be merging a variable and a constant.
518b8e80941Smrg      """
519b8e80941Smrg      if isinstance(val, Variable):
520b8e80941Smrg         if self._var_classes[val.index] is None:
521b8e80941Smrg            self._var_classes[val.index] = val
522b8e80941Smrg         else:
523b8e80941Smrg            other = self._var_classes[val.index]
524b8e80941Smrg            self.unify_bit_size(other, val,
525b8e80941Smrg                  lambda other_bit_size, bit_size:
526b8e80941Smrg                     'Variable {} has conflicting bit size requirements: ' \
527b8e80941Smrg                     'it must have bit size {} and {}'.format(
528b8e80941Smrg                        val.var_name, other_bit_size, bit_size))
529b8e80941Smrg      elif isinstance(val, Expression):
530b8e80941Smrg         for src in val.sources:
531b8e80941Smrg            self.merge_variables(src)
532b8e80941Smrg
533b8e80941Smrg   def validate_value(self, val):
534b8e80941Smrg      """Validate the an expression by performing classic Hindley-Milner
535b8e80941Smrg      type inference on bitsizes. This will detect if there are any conflicting
536b8e80941Smrg      requirements, and unify variables so that we know which variables must
537b8e80941Smrg      have the same bitsize. If we're operating on the replace expression, we
538b8e80941Smrg      will refuse to merge different variables together or merge a variable
539b8e80941Smrg      with a constant, in order to prevent surprises due to rules unexpectedly
540b8e80941Smrg      not matching at runtime.
541b8e80941Smrg      """
542b8e80941Smrg      if not isinstance(val, Expression):
543b8e80941Smrg         return
544b8e80941Smrg
545b8e80941Smrg      # Generic conversion ops are special in that they have a single unsized
546b8e80941Smrg      # source and an unsized destination and the two don't have to match.
547b8e80941Smrg      # This means there's no validation or unioning to do here besides the
548b8e80941Smrg      # len(val.sources) check.
549b8e80941Smrg      if val.opcode in conv_opcode_types:
550b8e80941Smrg         assert len(val.sources) == 1, \
551b8e80941Smrg            "Expression {} has {} sources, expected 1".format(
552b8e80941Smrg               val, len(val.sources))
553b8e80941Smrg         self.validate_value(val.sources[0])
554b8e80941Smrg         return
555b8e80941Smrg
556b8e80941Smrg      nir_op = opcodes[val.opcode]
557b8e80941Smrg      assert len(val.sources) == nir_op.num_inputs, \
558b8e80941Smrg         "Expression {} has {} sources, expected {}".format(
559b8e80941Smrg            val, len(val.sources), nir_op.num_inputs)
560b8e80941Smrg
561b8e80941Smrg      for src in val.sources:
562b8e80941Smrg         self.validate_value(src)
563b8e80941Smrg
564b8e80941Smrg      dst_type_bits = type_bits(nir_op.output_type)
565b8e80941Smrg
566b8e80941Smrg      # First, unify all the sources. That way, an error coming up because two
567b8e80941Smrg      # sources have an incompatible bit-size won't produce an error message
568b8e80941Smrg      # involving the destination.
569b8e80941Smrg      first_unsized_src = None
570b8e80941Smrg      for src_type, src in zip(nir_op.input_types, val.sources):
571b8e80941Smrg         src_type_bits = type_bits(src_type)
572b8e80941Smrg         if src_type_bits == 0:
573b8e80941Smrg            if first_unsized_src is None:
574b8e80941Smrg               first_unsized_src = src
575b8e80941Smrg               continue
576b8e80941Smrg
577b8e80941Smrg            if self.is_search:
578b8e80941Smrg               self.unify_bit_size(first_unsized_src, src,
579b8e80941Smrg                  lambda first_unsized_src_bit_size, src_bit_size:
580b8e80941Smrg                     'Source {} of {} must have bit size {}, while source {} ' \
581b8e80941Smrg                     'must have incompatible bit size {}'.format(
582b8e80941Smrg                        first_unsized_src, val, first_unsized_src_bit_size,
583b8e80941Smrg                        src, src_bit_size))
584b8e80941Smrg            else:
585b8e80941Smrg               self.unify_bit_size(first_unsized_src, src,
586b8e80941Smrg                  lambda first_unsized_src_bit_size, src_bit_size:
587b8e80941Smrg                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
588b8e80941Smrg                     'of {} may not have the same bit size when building the ' \
589b8e80941Smrg                     'replacement expression.'.format(
590b8e80941Smrg                        first_unsized_src, first_unsized_src_bit_size, src,
591b8e80941Smrg                        src_bit_size, val))
592b8e80941Smrg         else:
593b8e80941Smrg            if self.is_search:
594b8e80941Smrg               self.unify_bit_size(src, src_type_bits,
595b8e80941Smrg                  lambda src_bit_size, unused:
596b8e80941Smrg                     '{} must have {} bits, but as a source of nir_op_{} '\
597b8e80941Smrg                     'it must have {} bits'.format(
598b8e80941Smrg                        src, src_bit_size, nir_op.name, src_type_bits))
599b8e80941Smrg            else:
600b8e80941Smrg               self.unify_bit_size(src, src_type_bits,
601b8e80941Smrg                  lambda src_bit_size, unused:
602b8e80941Smrg                     '{} has the bit size of {}, but as a source of ' \
603b8e80941Smrg                     'nir_op_{} it must have {} bits, which may not be the ' \
604b8e80941Smrg                     'same'.format(
605b8e80941Smrg                        src, src_bit_size, nir_op.name, src_type_bits))
606b8e80941Smrg
607b8e80941Smrg      if dst_type_bits == 0:
608b8e80941Smrg         if first_unsized_src is not None:
609b8e80941Smrg            if self.is_search:
610b8e80941Smrg               self.unify_bit_size(val, first_unsized_src,
611b8e80941Smrg                  lambda val_bit_size, src_bit_size:
612b8e80941Smrg                     '{} must have the bit size of {}, while its source {} ' \
613b8e80941Smrg                     'must have incompatible bit size {}'.format(
614b8e80941Smrg                        val, val_bit_size, first_unsized_src, src_bit_size))
615b8e80941Smrg            else:
616b8e80941Smrg               self.unify_bit_size(val, first_unsized_src,
617b8e80941Smrg                  lambda val_bit_size, src_bit_size:
618b8e80941Smrg                     '{} must have {} bits, but its source {} ' \
619b8e80941Smrg                     '(bit size of {}) may not have that bit size ' \
620b8e80941Smrg                     'when building the replacement.'.format(
621b8e80941Smrg                        val, val_bit_size, first_unsized_src, src_bit_size))
622b8e80941Smrg      else:
623b8e80941Smrg         self.unify_bit_size(val, dst_type_bits,
624b8e80941Smrg            lambda dst_bit_size, unused:
625b8e80941Smrg               '{} must have {} bits, but as a destination of nir_op_{} ' \
626b8e80941Smrg               'it must have {} bits'.format(
627b8e80941Smrg                  val, dst_bit_size, nir_op.name, dst_type_bits))
628b8e80941Smrg
629b8e80941Smrg   def validate_replace(self, val, search):
630b8e80941Smrg      bit_size = val.get_bit_size()
631b8e80941Smrg      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
632b8e80941Smrg            bit_size == search.get_bit_size(), \
633b8e80941Smrg            'Ambiguous bit size for replacement value {}: ' \
634b8e80941Smrg            'it cannot be deduced from a variable, a fixed bit size ' \
635b8e80941Smrg            'somewhere, or the search expression.'.format(val)
636b8e80941Smrg
637b8e80941Smrg      if isinstance(val, Expression):
638b8e80941Smrg         for src in val.sources:
639b8e80941Smrg            self.validate_replace(src, search)
640b8e80941Smrg
641b8e80941Smrg   def validate(self, search, replace):
642b8e80941Smrg      self.is_search = True
643b8e80941Smrg      self.merge_variables(search)
644b8e80941Smrg      self.merge_variables(replace)
645b8e80941Smrg      self.validate_value(search)
646b8e80941Smrg
647b8e80941Smrg      self.is_search = False
648b8e80941Smrg      self.validate_value(replace)
649b8e80941Smrg
650b8e80941Smrg      # Check that search is always more specialized than replace. Note that
651b8e80941Smrg      # we're doing this in replace mode, disallowing merging variables.
652b8e80941Smrg      search_bit_size = search.get_bit_size()
653b8e80941Smrg      replace_bit_size = replace.get_bit_size()
654b8e80941Smrg      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)
655b8e80941Smrg
656b8e80941Smrg      assert cmp_result is not None and cmp_result <= 0, \
657b8e80941Smrg         'The search expression bit size {} and replace expression ' \
658b8e80941Smrg         'bit size {} may not be the same'.format(
659b8e80941Smrg               search_bit_size, replace_bit_size)
660b8e80941Smrg
661b8e80941Smrg      replace.set_bit_size(search)
662b8e80941Smrg
663b8e80941Smrg      self.validate_replace(replace, search)
664b8e80941Smrg
# Shared counter handing out a unique, increasing id to every
# SearchAndReplace; the id is embedded in the names ("search<id>",
# "replace<id>") given to its search and replace values.
_optimization_ids = itertools.count(start=0)

# Every distinct transform condition (a C expression) seen so far. Slot 0 is
# the unconditional 'true'; each SearchAndReplace records its slot here as
# condition_index.
condition_list = ['true']
668b8e80941Smrg
class SearchAndReplace(object):
   """One algebraic transformation: a search pattern and its replacement.

   ``transform`` is an indexable sequence ``(search, replace[, condition])``.
   ``search`` may already be an Expression and ``replace`` already a Value;
   anything else is parsed into those types. ``condition`` is a C expression
   string emitted verbatim into the generated pass; it defaults to 'true'.
   Construction runs bit-size validation on the resulting pair.
   """
   def __init__(self, transform):
      self.id = next(_optimization_ids)

      search, replace = transform[0], transform[1]
      self.condition = transform[2] if len(transform) > 2 else 'true'

      # Distinct conditions share one slot in the module-wide list, so the
      # generated code evaluates each of them only once per invocation.
      if self.condition not in condition_list:
         condition_list.append(self.condition)
      self.condition_index = condition_list.index(self.condition)

      varset = VarSet()
      if not isinstance(search, Expression):
         search = Expression(search, "search{0}".format(self.id), varset)
      self.search = search

      # Lock the variable set once the search side is parsed (presumably so
      # the replacement cannot introduce new variables).
      varset.lock()

      if not isinstance(replace, Value):
         replace = Value.create(replace, "replace{0}".format(self.id), varset)
      self.replace = replace

      BitSizeValidator(varset).validate(self.search, self.replace)
698b8e80941Smrg
class TreeAutomaton(object):
   """This class calculates a bottom-up tree automaton to quickly search for
   the left-hand sides of transforms. Tree automatons are a generalization of
   classical NFAs and DFAs, where the transition function determines the
   state of the parent node based on the state of its children. We construct a
   deterministic automaton to match patterns, using a similar algorithm to the
   classical NFA to DFA construction. At the moment, it only matches opcodes
   and constants (without checking the actual value), leaving more detailed
   checking to the search function which actually checks the leaves. The
   automaton acts as a quick filter for the search function, requiring only n
   + 1 table lookups for each n-source operation. The implementation is based
   on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
   In the language of that reference, this is a frontier-to-root deterministic
   automaton using only symbol filtering. The filtering is crucial to reduce
   both the time taken to generate the tables and the size of the tables.

   The generated C code stores state indices, filtered-state indices and the
   flattened transition-table index in uint16_t, so construction raises
   OverflowError if the tables grow beyond what that type can address.
   """

   # Number of distinct values of the uint16_t fields used by the generated C
   # code; no state index or table index may reach it.
   _UINT16_LIMIT = 1 << 16

   def __init__(self, transforms):
      self.patterns = [t.search for t in transforms]
      self._compute_items()
      self._build_table()
      self._check_c_limits()
      # Debugging aids:
      #print('num items: {}'.format(len(set(self.items.values()))))
      #print('num states: {}'.format(len(self.states)))
      #for state, patterns in zip(self.states, self.state_patterns):
      #   print('{}: num patterns: {}'.format(state, len(patterns)))

   class IndexMap(object):
      """An indexed list of objects, where one can either lookup an object by
      index or find the index associated to an object quickly using a hash
      table. Compared to a list, it has a constant time index(). Compared to a
      set, it provides a stable iteration order.
      """
      def __init__(self, iterable=()):
         self.objects = []
         self.map = {}
         for obj in iterable:
            self.add(obj)

      def __getitem__(self, i):
         return self.objects[i]

      def __contains__(self, obj):
         return obj in self.map

      def __len__(self):
         return len(self.objects)

      def __iter__(self):
         return iter(self.objects)

      def clear(self):
         self.objects = []
         self.map.clear()

      def index(self, obj):
         return self.map[obj]

      def add(self, obj):
         """Append obj if it is new; return its index either way."""
         if obj in self.map:
            return self.map[obj]
         else:
            index = len(self.objects)
            self.objects.append(obj)
            self.map[obj] = index
            return index

      def __repr__(self):
         return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'

   class Item(object):
      """This represents an "item" in the language of "Tree Automatons." This
      is just a subtree of some pattern, which represents a potential partial
      match at runtime. We deduplicate them, so that identical subtrees of
      different patterns share the same object, and store some extra
      information needed for the main algorithm as well.
      """
      def __init__(self, opcode, children):
         self.opcode = opcode
         self.children = children
         # These are the indices of patterns for which this item is the root node.
         self.patterns = []
         # This is the set of opcodes for parents of this item. Used to speed
         # up filtering.
         self.parent_ops = set()

      def __str__(self):
         return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'

      def __repr__(self):
         return str(self)

   @staticmethod
   def _num_srcs(op):
      """Return the number of sources of search opcode ``op``."""
      # Generic conversion opcodes (like f2b) are matched by their unsized
      # name and always take exactly one source.
      if op in conv_opcode_types:
         return 1
      return opcodes[op].num_inputs

   def _compute_items(self):
      """Build a set of all possible items, deduplicating them."""
      # This is a map from (opcode, sources) to item.
      self.items = {}

      # The set of all opcodes used by the patterns. Used later to avoid
      # building and emitting all the tables for opcodes that aren't used.
      self.opcodes = self.IndexMap()

      def get_item(opcode, children, pattern=None):
         # The length test short-circuits before the opcodes lookup, so the
         # childless pseudo-opcodes __wildcard/__const never reach it.
         commutative = len(children) == 2 \
               and "commutative" in opcodes[opcode].algebraic_properties
         item = self.items.setdefault((opcode, children),
                                      self.Item(opcode, children))
         if commutative:
            # Both source orders of a commutative op map to the same item.
            self.items[opcode, (children[1], children[0])] = item
         if pattern is not None:
            item.patterns.append(pattern)
         return item

      self.wildcard = get_item("__wildcard", ())
      self.const = get_item("__const", ())

      def process_subpattern(src, pattern=None):
         if isinstance(src, Constant):
            # Note: we throw away the actual constant value!
            return self.const
         elif isinstance(src, Variable):
            if src.is_constant:
               return self.const
            else:
               # Note: we throw away which variable it is here! This special
               # item is equivalent to nu in "Tree Automatons."
               return self.wildcard
         else:
            assert isinstance(src, Expression)
            opcode = src.opcode
            stripped = opcode.rstrip('0123456789')
            if stripped in conv_opcode_types:
               # Matches that use conversion opcodes with a specific type,
               # like f2b1, are tricky.  Either we construct the automaton to
               # match specific NIR opcodes like nir_op_f2b1, in which case we
               # need to create separate items for each possible NIR opcode
               # for patterns that have a generic opcode like f2b, or we
               # construct it to match the search opcode, in which case we
               # need to map f2b1 to f2b when constructing the automaton. Here
               # we do the latter.
               opcode = stripped
            self.opcodes.add(opcode)
            children = tuple(process_subpattern(c) for c in src.sources)
            item = get_item(opcode, children, pattern)
            for child in children:
               child.parent_ops.add(opcode)
            return item

      for i, pattern in enumerate(self.patterns):
         process_subpattern(pattern, i)

   def _build_table(self):
      """This is the core algorithm which builds up the transition table. It
      is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
      Comp_a and Filt_{a,i} using integers to identify match sets." It
      simultaneously builds up a list of all possible "match sets" or
      "states", where each match set represents the set of Items that match a
      given instruction, and builds up the transition table between states.
      """
      # Map from opcode + filtered state indices to transitioned state.
      self.table = defaultdict(dict)
      # Bijection from state to index. q in the original algorithm is
      # len(self.states)
      self.states = self.IndexMap()
      # List of pattern matches for each state index.
      self.state_patterns = []
      # Map from state index to filtered state index for each opcode.
      self.filter = defaultdict(list)
      # Bijections from filtered state to filtered state index for each
      # opcode, called the "representor sets" in the original algorithm.
      # q_{a,j} in the original algorithm is len(self.rep[op]).
      self.rep = defaultdict(self.IndexMap)

      # Everything in self.states with a index at least worklist_index is part
      # of the worklist of newly created states. There is also a worklist of
      # newly filtered states for each opcode, for which worklist_indices
      # serves a similar purpose. worklist_index corresponds to p in the
      # original algorithm, while worklist_indices is p_{a,j} (although since
      # we only filter by opcode/symbol, it's really just p_a).
      self.worklist_index = 0
      worklist_indices = defaultdict(int)

      # This is the set of opcodes for which the filtered worklist is non-empty.
      # It's used to avoid scanning opcodes for which there is nothing to
      # process when building the transition table. It corresponds to new_a in
      # the original algorithm.
      new_opcodes = self.IndexMap()

      # Process states on the global worklist, filtering them for each opcode,
      # updating the filter tables, and updating the filtered worklists if any
      # new filtered states are found. Similar to ComputeRepresenterSets() in
      # the original algorithm, although that only processes a single state.
      def process_new_states():
         while self.worklist_index < len(self.states):
            state = self.states[self.worklist_index]

            # Calculate pattern matches for this state. Each pattern is
            # assigned to a unique item, so we don't have to worry about
            # deduplicating them here. However, we do have to sort them so
            # that they're visited at runtime in the order they're specified
            # in the source.
            patterns = sorted(p for item in state for p in item.patterns)
            assert len(self.state_patterns) == self.worklist_index
            self.state_patterns.append(patterns)

            # calculate filter table for this state, and update filtered
            # worklists.
            for op in self.opcodes:
               filt = self.filter[op]
               rep = self.rep[op]
               filtered = frozenset(item for item in state
                                    if op in item.parent_ops)
               if filtered in rep:
                  rep_index = rep.index(filtered)
               else:
                  rep_index = rep.add(filtered)
                  new_opcodes.add(op)
               assert len(filt) == self.worklist_index
               filt.append(rep_index)
            self.worklist_index += 1

      # There are two start states: one which can only match as a wildcard,
      # and one which can match as a wildcard or constant. These will be the
      # states of intrinsics/other instructions and load_const instructions,
      # respectively. The indices of these must match the definitions of
      # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
      # initialize things correctly.
      self.states.add(frozenset((self.wildcard,)))
      self.states.add(frozenset((self.const, self.wildcard)))
      process_new_states()

      while new_opcodes:
         for op in new_opcodes:
            rep = self.rep[op]
            table = self.table[op]
            op_worklist_index = worklist_indices[op]
            num_srcs = self._num_srcs(op)

            # Iterate over all possible source combinations where at least one
            # is on the worklist.
            for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
               if all(src_idx < op_worklist_index for src_idx in src_indices):
                  continue

               srcs = tuple(rep[src_idx] for src_idx in src_indices)

               # Try all possible pairings of source items and add the
               # corresponding parent items. This is Comp_a from the paper.
               parent = set(self.items[op, item_srcs] for item_srcs in
                  itertools.product(*srcs) if (op, item_srcs) in self.items)

               # We could always start matching something else with a
               # wildcard. This is Cl from the paper.
               parent.add(self.wildcard)

               table[src_indices] = self.states.add(frozenset(parent))
            worklist_indices[op] = len(rep)
         new_opcodes.clear()
         process_new_states()

   def _check_c_limits(self):
      """Raise OverflowError if the tables cannot be addressed with uint16_t.

      State indices are emitted into uint16_t arrays, and the runtime C code
      accumulates the transition-table index in a uint16_t. So both the
      number of states and every per-opcode table size (number of filtered
      states raised to the number of sources) must be at most 2**16.
      Otherwise the generated code would silently truncate and look up the
      wrong state.
      """
      num_states = len(self.states)
      if num_states > self._UINT16_LIMIT:
         raise OverflowError(
            'tree automaton has {} states, more than fit in the uint16_t '
            'state indices of the generated C code'.format(num_states))

      for op in self.opcodes:
         table_size = len(self.rep[op]) ** self._num_srcs(op)
         if table_size > self._UINT16_LIMIT:
            raise OverflowError(
               'transition table for {} has {} entries, more than a uint16_t '
               'index can address in the generated C code'.format(
                  op, table_size))
958b8e80941Smrg
# Mako template for the C source of one algebraic pass, rendered by
# AlgebraicPass.render() with: pass_name, xforms (SearchAndReplace objects),
# opcode_xforms (passed in but not referenced below), condition_list,
# automaton (a TreeAutomaton), get_c_opcode and itertools.
#
# The generated pass first walks every block forward (*_pre_block) and
# records a tree-automaton state for each ALU/load_const SSA def. The
# transition-table index is flattened in itertools.product() order. It then
# walks blocks and instructions in reverse (*_block) and tries, in order, the
# transforms listed for each instruction's state whose condition flag is set.
# State indices and the flattened table index are uint16_t in the emitted C.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

struct per_op_table {
   const uint16_t *filter;
   unsigned num_filtered_states;
   const uint16_t *table;
};

/* Note: these must match the start states created in
 * TreeAutomaton._build_table()
 */

/* WILDCARD_STATE = 0 is set by zeroing the state array */
static const uint16_t CONST_STATE = 1;

#endif

<% cache = {} %>
% for xform in xforms:
   ${xform.search.render(cache)}
   ${xform.replace.render(cache)}
% endfor

% for state_id, state_xforms in enumerate(automaton.state_patterns):
% if state_xforms: # avoid emitting a 0-length array for MSVC
static const struct transform ${pass_name}_state${state_id}_xforms[] = {
% for i in state_xforms:
  { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} },
% endfor
};
% endif
% endfor

static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = {
% for op in automaton.opcodes:
   [${get_c_opcode(op)}] = {
      .filter = (uint16_t []) {
      % for e in automaton.filter[op]:
         ${e},
      % endfor
      },
      <%
        num_filtered = len(automaton.rep[op])
      %>
      .num_filtered_states = ${num_filtered},
      .table = (uint16_t []) {
      <%
        num_srcs = len(next(iter(automaton.table[op])))
      %>
      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
         ${automaton.table[op][indices]},
      % endfor
      },
   },
% endfor
};

static void
${pass_name}_pre_block(nir_block *block, uint16_t *states)
{
   nir_foreach_instr(instr, block) {
      switch (instr->type) {
      case nir_instr_type_alu: {
         nir_alu_instr *alu = nir_instr_as_alu(instr);
         nir_op op = alu->op;
         uint16_t search_op = nir_search_op_for_nir_op(op);
         const struct per_op_table *tbl = &${pass_name}_table[search_op];
         if (tbl->num_filtered_states == 0)
            continue;

         /* Calculate the index into the transition table. Note the index
          * calculated must match the iteration order of Python's
          * itertools.product(), which was used to emit the transition
          * table.
          */
         uint16_t index = 0;
         for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
            index *= tbl->num_filtered_states;
            index += tbl->filter[states[alu->src[i].src.ssa->index]];
         }
         states[alu->dest.dest.ssa.index] = tbl->table[index];
         break;
      }

      case nir_instr_type_load_const: {
         nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
         states[load_const->def.index] = CONST_STATE;
         break;
      }

      default:
         break;
      }
   }
}

static bool
${pass_name}_block(nir_builder *build, nir_block *block,
                   const uint16_t *states, const bool *condition_flags)
{
   bool progress = false;

   nir_foreach_instr_reverse_safe(instr, block) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (states[alu->dest.dest.ssa.index]) {
% for i in range(len(automaton.state_patterns)):
      case ${i}:
         % if automaton.state_patterns[i]:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_state${i}_xforms); i++) {
            const struct transform *xform = &${pass_name}_state${i}_xforms[i];
            if (condition_flags[xform->condition_offset] &&
                nir_replace_instr(build, alu, xform->search, xform->replace)) {
               progress = true;
               break;
            }
         }
         % endif
         break;
% endfor
      default: assert(0);
      }
   }

   return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   bool progress = false;

   nir_builder build;
   nir_builder_init(&build, impl);

   /* Note: it's important here that we're allocating a zeroed array, since
    * state 0 is the default state, which means we don't have to visit
    * anything other than constants and ALU instructions.
    */
   uint16_t *states = calloc(impl->ssa_alloc, sizeof(*states));

   nir_foreach_block(block, impl) {
      ${pass_name}_pre_block(block, states);
   }

   nir_foreach_block_reverse(block, impl) {
      progress |= ${pass_name}_block(&build, block, states, condition_flags);
   }

   free(states);

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);
    } else {
#ifndef NDEBUG
      impl->valid_metadata &= ~nir_metadata_not_properly_reset;
#endif
    }

   return progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   const shader_info *info = &shader->info;
   (void) options;
   (void) info;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")
1163b8e80941Smrg
1164b8e80941Smrg
1165b8e80941Smrg
1166b8e80941Smrgclass AlgebraicPass(object):
1167b8e80941Smrg   def __init__(self, pass_name, transforms):
1168b8e80941Smrg      self.xforms = []
1169b8e80941Smrg      self.opcode_xforms = defaultdict(lambda : [])
1170b8e80941Smrg      self.pass_name = pass_name
1171b8e80941Smrg
1172b8e80941Smrg      error = False
1173b8e80941Smrg
1174b8e80941Smrg      for xform in transforms:
1175b8e80941Smrg         if not isinstance(xform, SearchAndReplace):
1176b8e80941Smrg            try:
1177b8e80941Smrg               xform = SearchAndReplace(xform)
1178b8e80941Smrg            except:
1179b8e80941Smrg               print("Failed to parse transformation:", file=sys.stderr)
1180b8e80941Smrg               print("  " + str(xform), file=sys.stderr)
1181b8e80941Smrg               traceback.print_exc(file=sys.stderr)
1182b8e80941Smrg               print('', file=sys.stderr)
1183b8e80941Smrg               error = True
1184b8e80941Smrg               continue
1185b8e80941Smrg
1186b8e80941Smrg         self.xforms.append(xform)
1187b8e80941Smrg         if xform.search.opcode in conv_opcode_types:
1188b8e80941Smrg            dst_type = conv_opcode_types[xform.search.opcode]
1189b8e80941Smrg            for size in type_sizes(dst_type):
1190b8e80941Smrg               sized_opcode = xform.search.opcode + str(size)
1191b8e80941Smrg               self.opcode_xforms[sized_opcode].append(xform)
1192b8e80941Smrg         else:
1193b8e80941Smrg            self.opcode_xforms[xform.search.opcode].append(xform)
1194b8e80941Smrg
1195b8e80941Smrg      self.automaton = TreeAutomaton(self.xforms)
1196b8e80941Smrg
1197b8e80941Smrg      if error:
1198b8e80941Smrg         sys.exit(1)
1199b8e80941Smrg
1200b8e80941Smrg
1201b8e80941Smrg   def render(self):
1202b8e80941Smrg      return _algebraic_pass_template.render(pass_name=self.pass_name,
1203b8e80941Smrg                                             xforms=self.xforms,
1204b8e80941Smrg                                             opcode_xforms=self.opcode_xforms,
1205b8e80941Smrg                                             condition_list=condition_list,
1206b8e80941Smrg                                             automaton=self.automaton,
1207b8e80941Smrg                                             get_c_opcode=get_c_opcode,
1208b8e80941Smrg                                             itertools=itertools)
1209