101e04c3fSmrg#
201e04c3fSmrg# Copyright (C) 2014 Intel Corporation
301e04c3fSmrg#
401e04c3fSmrg# Permission is hereby granted, free of charge, to any person obtaining a
501e04c3fSmrg# copy of this software and associated documentation files (the "Software"),
601e04c3fSmrg# to deal in the Software without restriction, including without limitation
701e04c3fSmrg# the rights to use, copy, modify, merge, publish, distribute, sublicense,
801e04c3fSmrg# and/or sell copies of the Software, and to permit persons to whom the
901e04c3fSmrg# Software is furnished to do so, subject to the following conditions:
1001e04c3fSmrg#
1101e04c3fSmrg# The above copyright notice and this permission notice (including the next
1201e04c3fSmrg# paragraph) shall be included in all copies or substantial portions of the
1301e04c3fSmrg# Software.
1401e04c3fSmrg#
1501e04c3fSmrg# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
1601e04c3fSmrg# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
1701e04c3fSmrg# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
1801e04c3fSmrg# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
1901e04c3fSmrg# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
2001e04c3fSmrg# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
2101e04c3fSmrg# IN THE SOFTWARE.
2201e04c3fSmrg#
2301e04c3fSmrg# Authors:
2401e04c3fSmrg#    Jason Ekstrand (jason@jlekstrand.net)
2501e04c3fSmrg
2601e04c3fSmrgimport ast
277e102996Smayafrom collections import defaultdict
2801e04c3fSmrgimport itertools
2901e04c3fSmrgimport struct
3001e04c3fSmrgimport sys
3101e04c3fSmrgimport mako.template
3201e04c3fSmrgimport re
3301e04c3fSmrgimport traceback
3401e04c3fSmrg
357e102996Smayafrom nir_opcodes import opcodes, type_sizes
367e102996Smaya
# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
nir_search_max_comm_ops = 8

# Unsized conversion opcodes that exist only inside nir_search, mapped to the
# base type of the value they produce.
conv_opcode_types = dict(
    i2f='float',
    u2f='float',
    f2f='float',
    f2u='uint',
    f2i='int',
    u2u='uint',
    i2i='int',
    b2f='float',
    b2i='int',
    i2b='bool',
    f2b='bool',
)

def get_c_opcode(op):
   """Return the C enumerator spelled for opcode *op* in generated code.

   Search-only conversion opcodes (the keys of conv_opcode_types) use the
   nir_search_op_ prefix; every other opcode uses nir_op_.
   """
   prefix = 'nir_search_op_' if op in conv_opcode_types else 'nir_op_'
   return prefix + op
617e102996Smaya
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   """Return the explicit bit size of a type string such as "float32".

   A bare base type ("int", "bool", ...) yields 0.  The string must begin
   with one of int, uint, bool or float (checked with an assert).
   """
   match = _type_re.match(type_str)
   assert match.group('type')

   bits = match.group('bits')
   return 0 if bits is None else int(bits)
7201e04c3fSmrg
# Represents a set of variables, each with a unique id
class VarSet(object):
   """Map variable names to dense integer ids, assigned in first-use order.

   After lock() is called, looking up a name that was not seen before fails
   an assertion ("Unknown replacement variable").
   """

   def __init__(self):
      self.names = {}
      self.ids = itertools.count()
      self.immutable = False

   def __getitem__(self, name):
      try:
         return self.names[name]
      except KeyError:
         pass

      # Only previously unseen names reach this point.
      assert not self.immutable, "Unknown replacement variable: " + name
      new_id = next(self.ids)
      self.names[name] = new_id
      return new_id

   def lock(self):
      """Forbid the creation of any further ids."""
      self.immutable = True
8901e04c3fSmrg
class Value(object):
   """Base class for one node of a search or replacement pattern.

   The concrete nodes are Constant, Variable and Expression.  Each node
   stores the name of the C variable it is rendered to and a type string
   that selects the nir_search_* struct.  It also has a bit-size class,
   tracked with union-find through self._bit_size, which the subclasses
   initialize.
   """

   @staticmethod
   def create(val, name_base, varset):
      """Turn one element of a pattern into a Value.

      - bytes are decoded as UTF-8 and then treated like a string;
      - tuples become Expressions;
      - an existing Expression is returned unchanged;
      - strings become Variables (ids are allocated from varset);
      - bool, int and float literals become Constants.

      name_base is used as the C variable name of the new node.

      Raises TypeError for any other input, rather than returning None and
      failing later in an unrelated place.
      """
      if isinstance(val, bytes):
         val = val.decode('utf-8')

      if isinstance(val, tuple):
         return Expression(val, name_base, varset)
      elif isinstance(val, Expression):
         return val
      elif isinstance(val, str):
         return Variable(val, name_base, varset)
      elif isinstance(val, (bool, float, int)):
         return Constant(val, name_base)

      raise TypeError("Cannot create a search/replace value from {!r} "
                      "of type {}".format(val, type(val).__name__))

   def __init__(self, val, name, type_str):
      # Textual form of the pattern element, returned by __str__.
      self.in_val = str(val)
      # Name of the generated C variable.
      self.name = name
      # "constant", "variable" or "expression"; selects the
      # nir_search_value_* enum and the nir_search_* struct type.
      self.type_str = type_str

   def __str__(self):
      return self.in_val

   def get_bit_size(self):
      """Get the physical bit-size that has been chosen for this value, or if
      there is none, the canonical value which currently represents this
      bit-size class. Variables will be preferred, i.e. if there are any
      variables in the equivalence class, the canonical value will be a
      variable. We do this since we'll need to know which variable each value
      is equivalent to when constructing the replacement expression. This is
      the "find" part of the union-find algorithm.
      """
      bit_size = self

      while isinstance(bit_size, Value):
         if bit_size._bit_size is None:
            break
         bit_size = bit_size._bit_size

      # Remember the representative so the next lookup from self is direct.
      if bit_size is not self:
         self._bit_size = bit_size
      return bit_size

   def set_bit_size(self, other):
      """Make self.get_bit_size() return what other.get_bit_size() return
      before calling this, or just "other" if it's a concrete bit-size. This is
      the "union" part of the union-find algorithm.
      """

      self_bit_size = self.get_bit_size()
      other_bit_size = other if isinstance(other, int) else other.get_bit_size()

      if self_bit_size == other_bit_size:
         return

      self_bit_size._bit_size = other_bit_size

   @property
   def type_enum(self):
      """C enumerator naming this node's kind (nir_search_value_*)."""
      return "nir_search_value_" + self.type_str

   @property
   def c_type(self):
      """C struct type used to declare this node (nir_search_*)."""
      return "nir_search_" + self.type_str

   def __c_name(self, cache):
      # A name may have been remapped to an identical, already-emitted
      # definition by render().
      if cache is not None and self.name in cache:
         return cache[self.name]
      else:
         return self.name

   def c_value_ptr(self, cache):
      return "&{0}.value".format(self.__c_name(cache))

   def c_ptr(self, cache):
      return "&{0}".format(self.__c_name(cache))

   @property
   def c_bit_size(self):
      """Bit size as encoded in the generated C.

      A positive number is a physical bit size.  -(i + 1) means "the same
      bit size as variable i".  0 means there is nothing to check or
      specialize (see below).
      """
      bit_size = self.get_bit_size()
      if isinstance(bit_size, int):
         return bit_size
      elif isinstance(bit_size, Variable):
         return -bit_size.index - 1
      else:
         # If the bit-size class is neither a variable, nor an actual bit-size, then
         # - If it's in the search expression, we don't need to check anything
         # - If it's in the replace expression, either it's ambiguous (in which
         # case we'd reject it), or it equals the bit-size of the search value
         # We represent these cases with a 0 bit-size.
         return 0

   __template = mako.template.Template("""{
   { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
   ${val.swizzle()},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},
   ${val.comm_expr_idx}, ${val.comm_exprs},
   ${val.c_opcode()},
   { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

   def render(self, cache):
      """Return the C definition of this node.

      When cache is a dict, it maps rendered initializers to the names they
      were first emitted under.  An identical initializer rendered again is
      not re-emitted: this node's name is remapped to the earlier one and
      only a comment is returned.
      """
      struct_init = self.__template.render(val=self, cache=cache,
                                           Constant=Constant,
                                           Variable=Variable,
                                           Expression=Expression)
      if cache is not None and struct_init in cache:
         # If it's in the cache, register a name remap in the cache and render
         # only a comment saying it's been remapped
         cache[self.name] = cache[struct_init]
         return "/* {} -> {} in the cache */\n".format(self.name,
                                                       cache[struct_init])
      else:
         if cache is not None:
            cache[struct_init] = self.name
         return "static const {} {} = {}\n".format(self.c_type, self.name,
                                                   struct_init)
21701e04c3fSmrg
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   """A literal bool, int or float appearing in a pattern.

   A string input is parsed as "<literal>[@bits]", and the literal part goes
   through ast.literal_eval.  Boolean constants are always 1 bit wide.
   """

   def __init__(self, val, name):
      Value.__init__(self, val, name, "constant")

      explicit_bits = None
      if isinstance(val, str):
         parsed = _constant_re.match(val)
         literal = ast.literal_eval(parsed.group('value'))
         if parsed.group('bits'):
            explicit_bits = int(parsed.group('bits'))
      else:
         literal = val

      self.value = literal
      self._bit_size = explicit_bits

      if isinstance(literal, bool):
         assert explicit_bits is None or explicit_bits == 1
         self._bit_size = 1

   def hex(self):
      """Return the value as a C initializer: NIR_TRUE/NIR_FALSE for bools,
      hex() for ints, or the hex of the raw 64-bit double pattern for floats.
      """
      value = self.value
      if isinstance(value, bool):
         return 'NIR_TRUE' if value else 'NIR_FALSE'
      if isinstance(value, int):
         return hex(value)
      if isinstance(value, float):
         (raw_bits,) = struct.unpack('Q', struct.pack('d', value))
         return hex(raw_bits)
      assert False

   def type(self):
      """Return the nir_type_* name for the literal's Python type."""
      value = self.value
      if isinstance(value, bool):
         return "nir_type_bool"
      if isinstance(value, int):
         return "nir_type_int"
      if isinstance(value, float):
         return "nir_type_float"
      return None

   def equivalent(self, other):
      """Check that two constants are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      """
      return isinstance(other, type(self)) and self.value == other.value
2667ec681f3Smrg
# The $ at the end forces there to be an error if any part of the string
# doesn't match one of the field patterns.  A swizzle may use the xyzw names
# or the letters a-p, matching the table in Variable.swizzle().
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?"
                          r"(?P<swiz>\.[xyzwa-p]+)?"
                          r"$")

class Variable(Value):
   """A named pattern variable, for example "a", "#b", "c@32" or "d(cond).xy".

   The syntax (see _var_name_re) is:
   - an optional '#', which sets is_constant;
   - the alphabetic name;
   - an optional '@[type][bits]' constraint;
   - an optional parenthesized condition;
   - an optional swizzle.
   All uses of the same name share one index from the VarSet.
   """

   # Swizzle letter -> component index.  x, y, z and w name components 0-3;
   # a-p name components 0-15.
   _swizzle_components = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3,
                          'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3,
                          'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7,
                          'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11,
                          'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 }

   # Length of the identity swizzle emitted by swizzle().  The C-side swizzle
   # array is presumably this size (NIR_MAX_VEC_COMPONENTS); confirm against
   # nir_search.h if it changes.
   _max_swizzle_components = 16

   def __init__(self, val, name, varset):
      Value.__init__(self, val, name, "variable")

      m = _var_name_re.match(val)
      assert m and m.group('name') is not None, \
            "Malformed variable name \"{}\".".format(val)

      self.var_name = m.group('name')

      # Prevent common cases where someone puts quotes around a literal
      # constant.  If we want to support names that have numeric or
      # punctuation characters, we can make the first assertion more flexible.
      assert self.var_name.isalpha()
      assert self.var_name != 'True'
      assert self.var_name != 'False'

      self.is_constant = m.group('const') is not None
      self.cond = m.group('cond')
      self.required_type = m.group('type')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.swiz = m.group('swiz')

      if self.swiz is not None:
         # The swizzle is emitted as a C initializer list.  Reject one that
         # is longer than the full identity swizzle.
         assert len(self.swiz) - 1 <= self._max_swizzle_components, \
               "Swizzle \"{}\" of variable \"{}\" has more than {} " \
               "components.".format(self.swiz, self.var_name,
                                    self._max_swizzle_components)

      if self.required_type == 'bool':
         if self._bit_size is not None:
            assert self._bit_size in type_sizes(self.required_type)
         else:
            self._bit_size = 1

      if self.required_type is not None:
         assert self.required_type in ('float', 'bool', 'int', 'uint')

      self.index = varset[self.var_name]

   def type(self):
      """Return the nir_type_* required by the '@type' constraint.

      Returns None when there is no constraint.  int and uint both map to
      nir_type_int.
      """
      if self.required_type == 'bool':
         return "nir_type_bool"
      elif self.required_type in ('int', 'uint'):
         return "nir_type_int"
      elif self.required_type == 'float':
         return "nir_type_float"

   def equivalent(self, other):
      """Check that two variables are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      """
      if not isinstance(other, type(self)):
         return False

      return self.index == other.index

   def swizzle(self):
      """Return the C initializer for this variable's swizzle.

      Without an explicit swizzle this is the identity over 16 components.
      """
      if self.swiz is not None:
         comps = [str(self._swizzle_components[c]) for c in self.swiz[1:]]
         return '{' + ', '.join(comps) + '}'
      return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'
3397ec681f3Smrg
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
   """An opcode applied to up to four sources, built from a tuple such as
   ('~fadd@32(cond)', 'a', ('fneg', 'b')).

   The first tuple element has the form "[~|!]opcode[@bits][(cond)]":
   - '~' marks the expression inexact and '!' marks it exact;
   - '@bits' fixes the bit size;
   - '(cond)' is emitted into the generated C as the condition.
   The remaining elements become sources via Value.create().
   """

   def __init__(self, expr, name_base, varset):
      Value.__init__(self, expr, name_base, "expression")
      assert isinstance(expr, tuple)

      m = _opcode_re.match(expr[0])
      assert m and m.group('opcode') is not None

      self.opcode = m.group('opcode')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.inexact = m.group('inexact') is not None
      self.exact = m.group('exact') is not None
      self.cond = m.group('cond')

      assert not self.inexact or not self.exact, \
            'Expression cannot be both exact and inexact.'

      # "many-comm-expr" isn't really a condition.  It's notification to the
      # generator that this pattern is known to have too many commutative
      # expressions, and an error should not be generated for this case.
      self.many_commutative_expressions = False
      if self.cond and "many-comm-expr" in self.cond:
         # Split the condition into a comma-separated list, ignoring
         # whitespace around entries and empty entries, so that forms like
         # "(foo, many-comm-expr)" work.  Remove "many-comm-expr".  If there
         # is anything left, put it back together.
         c = [part.strip() for part in self.cond[1:-1].split(",")]
         c = [part for part in c if part]
         c.remove("many-comm-expr")

         self.cond = "({})".format(",".join(c)) if c else None
         self.many_commutative_expressions = True

      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                       for (i, src) in enumerate(expr[1:]) ]

      # nir_search_expression::srcs is hard-coded to 4
      assert len(self.sources) <= 4

      if self.opcode in conv_opcode_types:
         assert self._bit_size is None, \
                'Expression cannot use an unsized conversion opcode with ' \
                'an explicit size; that\'s silly.'

      self.__index_comm_exprs(0)

   def equivalent(self, other):
      """Check that two expressions are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      This implementation does not check for equivalence due to commutativity,
      but it could.

      """
      if not isinstance(other, type(self)):
         return False

      if len(self.sources) != len(other.sources):
         return False

      if self.opcode != other.opcode:
         return False

      return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))

   def __index_comm_exprs(self, base_idx):
      """Recursively count and index commutative expressions.

      Assigns comm_expr_idx (or -1 if this node is not commutative) and sets
      comm_exprs to the number of commutative expressions in this subtree,
      which is also returned.  An opcode only counts as commutative here if
      it has the "2src_commutative" property and its first two sources are
      not equivalent.
      """
      self.comm_exprs = 0

      # A note about the explicit "len(self.sources)" check. The list of
      # sources comes from user input, and that input might be bad.  Check
      # that the expected second source exists before accessing it. Without
      # this check, a unit test that does "('iadd', 'a')" will crash.
      if self.opcode not in conv_opcode_types and \
         "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
         len(self.sources) >= 2 and \
         not self.sources[0].equivalent(self.sources[1]):
         self.comm_expr_idx = base_idx
         self.comm_exprs += 1
      else:
         self.comm_expr_idx = -1

      for s in self.sources:
         if isinstance(s, Expression):
            s.__index_comm_exprs(base_idx + self.comm_exprs)
            self.comm_exprs += s.comm_exprs

      return self.comm_exprs

   def c_opcode(self):
      return get_c_opcode(self.opcode)

   def render(self, cache):
      """Render all sources first, then this expression, which refers to them."""
      srcs = "\n".join(src.render(cache) for src in self.sources)
      return srcs + super(Expression, self).render(cache)
43901e04c3fSmrg
44001e04c3fSmrgclass BitSizeValidator(object):
44101e04c3fSmrg   """A class for validating bit sizes of expressions.
44201e04c3fSmrg
44301e04c3fSmrg   NIR supports multiple bit-sizes on expressions in order to handle things
44401e04c3fSmrg   such as fp64.  The source and destination of every ALU operation is
44501e04c3fSmrg   assigned a type and that type may or may not specify a bit size.  Sources
44601e04c3fSmrg   and destinations whose type does not specify a bit size are considered
44701e04c3fSmrg   "unsized" and automatically take on the bit size of the corresponding
44801e04c3fSmrg   register or SSA value.  NIR has two simple rules for bit sizes that are
44901e04c3fSmrg   validated by nir_validator:
45001e04c3fSmrg
45101e04c3fSmrg    1) A given SSA def or register has a single bit size that is respected by
45201e04c3fSmrg       everything that reads from it or writes to it.
45301e04c3fSmrg
45401e04c3fSmrg    2) The bit sizes of all unsized inputs/outputs on any given ALU
45501e04c3fSmrg       instruction must match.  They need not match the sized inputs or
45601e04c3fSmrg       outputs but they must match each other.
45701e04c3fSmrg
45801e04c3fSmrg   In order to keep nir_algebraic relatively simple and easy-to-use,
45901e04c3fSmrg   nir_search supports a type of bit-size inference based on the two rules
46001e04c3fSmrg   above.  This is similar to type inference in many common programming
46101e04c3fSmrg   languages.  If, for instance, you are constructing an add operation and you
46201e04c3fSmrg   know the second source is 16-bit, then you know that the other source and
46301e04c3fSmrg   the destination must also be 16-bit.  There are, however, cases where this
46401e04c3fSmrg   inference can be ambiguous or contradictory.  Consider, for instance, the
46501e04c3fSmrg   following transformation:
46601e04c3fSmrg
4677e102996Smaya   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
46801e04c3fSmrg
46901e04c3fSmrg   This transformation can potentially cause a problem because usub_borrow is
47001e04c3fSmrg   well-defined for any bit-size of integer.  However, b2i always generates a
47101e04c3fSmrg   32-bit result so it could end up replacing a 64-bit expression with one
47201e04c3fSmrg   that takes two 64-bit values and produces a 32-bit value.  As another
47301e04c3fSmrg   example, consider this expression:
47401e04c3fSmrg
47501e04c3fSmrg   (('bcsel', a, b, 0), ('iand', a, b))
47601e04c3fSmrg
47701e04c3fSmrg   In this case, in the search expression a must be 32-bit but b can
47801e04c3fSmrg   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to AND a 32-bit value with a 64-bit value, which would be invalid.
48001e04c3fSmrg
48101e04c3fSmrg   This class solves that problem by providing a validation layer that proves
48201e04c3fSmrg   that a given search-and-replace operation is 100% well-defined before we
48301e04c3fSmrg   generate any code.  This ensures that bugs are caught at compile time
48401e04c3fSmrg   rather than at run time.
48501e04c3fSmrg
4867e102996Smaya   Each value maintains a "bit-size class", which is either an actual bit size
4877e102996Smaya   or an equivalence class with other values that must have the same bit size.
4887e102996Smaya   The validator works by combining bit-size classes with each other according
4897e102996Smaya   to the NIR rules outlined above, checking that there are no inconsistencies.
4907e102996Smaya   When doing this for the replacement expression, we make sure to never change
4917e102996Smaya   the equivalence class of any of the search values. We could make the example
4927e102996Smaya   transforms above work by doing some extra run-time checking of the search
4937e102996Smaya   expression, but we make the user specify those constraints themselves, to
4947e102996Smaya   avoid any surprises. Since the replacement bitsizes can only be connected to
4957e102996Smaya   the source bitsize via variables (variables must have the same bitsize in
   the source and replacement expressions) or the roots of the expression (the
4977e102996Smaya   replacement expression must produce the same bit size as the search
4987e102996Smaya   expression), we prevent merging a variable with anything when processing the
4997e102996Smaya   replacement expression, or specializing the search bitsize
5007e102996Smaya   with anything. The former prevents
50101e04c3fSmrg
5027e102996Smaya   (('bcsel', a, b, 0), ('iand', a, b))
50301e04c3fSmrg
5047e102996Smaya   from being allowed, since we'd have to merge the bitsizes for a and b due to
5057e102996Smaya   the 'iand', while the latter prevents
50601e04c3fSmrg
5077e102996Smaya   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
50801e04c3fSmrg
5097e102996Smaya   from being allowed, since the search expression has the bit size of a and b,
5107e102996Smaya   which can't be specialized to 32 which is the bitsize of the replace
5117e102996Smaya   expression. It also prevents something like:
51201e04c3fSmrg
5137e102996Smaya   (('b2i', ('i2b', a)), ('ineq', a, 0))
51401e04c3fSmrg
5157e102996Smaya   since the bitsize of 'b2i', which can be anything, can't be specialized to
5167e102996Smaya   the bitsize of a.
51701e04c3fSmrg
5187e102996Smaya   After doing all this, we check that every subexpression of the replacement
5197e102996Smaya   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
   of the search expression, since those are the things that are known when
   constructing the replacement expression. Finally, we record the bitsize
5227e102996Smaya   needed in nir_search_value so that we know what to do when building the
5237e102996Smaya   replacement expression.
5247e102996Smaya   """
52501e04c3fSmrg
5267e102996Smaya   def __init__(self, varset):
5277e102996Smaya      self._var_classes = [None] * len(varset.names)
5287e102996Smaya
5297e102996Smaya   def compare_bitsizes(self, a, b):
5307e102996Smaya      """Determines which bitsize class is a specialization of the other, or
5317e102996Smaya      whether neither is. When we merge two different bitsizes, the
5327e102996Smaya      less-specialized bitsize always points to the more-specialized one, so
5337e102996Smaya      that calling get_bit_size() always gets you the most specialized bitsize.
5347e102996Smaya      The specialization partial order is given by:
5357e102996Smaya      - Physical bitsizes are always the most specialized, and a different
5367e102996Smaya        bitsize can never specialize another.
5377e102996Smaya      - In the search expression, variables can always be specialized to each
5387e102996Smaya        other and to physical bitsizes. In the replace expression, we disallow
5397e102996Smaya        this to avoid adding extra constraints to the search expression that
5407e102996Smaya        the user didn't specify.
5417e102996Smaya      - Expressions and constants without a bitsize can always be specialized to
5427e102996Smaya        each other and variables, but not the other way around.
5437e102996Smaya
5447e102996Smaya        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
5457e102996Smaya        and None if they are not comparable (neither a <= b nor b <= a).
5467e102996Smaya      """
5477e102996Smaya      if isinstance(a, int):
5487e102996Smaya         if isinstance(b, int):
5497e102996Smaya            return 0 if a == b else None
5507e102996Smaya         elif isinstance(b, Variable):
5517e102996Smaya            return -1 if self.is_search else None
5527e102996Smaya         else:
5537e102996Smaya            return -1
5547e102996Smaya      elif isinstance(a, Variable):
5557e102996Smaya         if isinstance(b, int):
5567e102996Smaya            return 1 if self.is_search else None
5577e102996Smaya         elif isinstance(b, Variable):
5587e102996Smaya            return 0 if self.is_search or a.index == b.index else None
5597e102996Smaya         else:
5607e102996Smaya            return -1
5617e102996Smaya      else:
5627e102996Smaya         if isinstance(b, int):
5637e102996Smaya            return 1
5647e102996Smaya         elif isinstance(b, Variable):
5657e102996Smaya            return 1
5667e102996Smaya         else:
5677e102996Smaya            return 0
5687e102996Smaya
5697e102996Smaya   def unify_bit_size(self, a, b, error_msg):
5707e102996Smaya      """Record that a must have the same bit-size as b. If both
5717e102996Smaya      have been assigned conflicting physical bit-sizes, call "error_msg" with
5727e102996Smaya      the bit-sizes of self and other to get a message and raise an error.
5737e102996Smaya      In the replace expression, disallow merging variables with other
5747e102996Smaya      variables and physical bit-sizes as well.
5757e102996Smaya      """
5767e102996Smaya      a_bit_size = a.get_bit_size()
5777e102996Smaya      b_bit_size = b if isinstance(b, int) else b.get_bit_size()
5787e102996Smaya
5797e102996Smaya      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)
5807e102996Smaya
5817e102996Smaya      assert cmp_result is not None, \
5827e102996Smaya         error_msg(a_bit_size, b_bit_size)
5837e102996Smaya
5847e102996Smaya      if cmp_result < 0:
5857e102996Smaya         b_bit_size.set_bit_size(a)
5867e102996Smaya      elif not isinstance(a_bit_size, int):
5877e102996Smaya         a_bit_size.set_bit_size(b)
5887e102996Smaya
5897e102996Smaya   def merge_variables(self, val):
5907e102996Smaya      """Perform the first part of type inference by merging all the different
5917e102996Smaya      uses of the same variable. We always do this as if we're in the search
5927e102996Smaya      expression, even if we're actually not, since otherwise we'd get errors
5937e102996Smaya      if the search expression specified some constraint but the replace
5947e102996Smaya      expression didn't, because we'd be merging a variable and a constant.
5957e102996Smaya      """
5967e102996Smaya      if isinstance(val, Variable):
5977e102996Smaya         if self._var_classes[val.index] is None:
5987e102996Smaya            self._var_classes[val.index] = val
5997e102996Smaya         else:
6007e102996Smaya            other = self._var_classes[val.index]
6017e102996Smaya            self.unify_bit_size(other, val,
6027e102996Smaya                  lambda other_bit_size, bit_size:
6037e102996Smaya                     'Variable {} has conflicting bit size requirements: ' \
6047e102996Smaya                     'it must have bit size {} and {}'.format(
6057e102996Smaya                        val.var_name, other_bit_size, bit_size))
60601e04c3fSmrg      elif isinstance(val, Expression):
6077e102996Smaya         for src in val.sources:
6087e102996Smaya            self.merge_variables(src)
6097e102996Smaya
6107e102996Smaya   def validate_value(self, val):
6117e102996Smaya      """Validate the an expression by performing classic Hindley-Milner
6127e102996Smaya      type inference on bitsizes. This will detect if there are any conflicting
6137e102996Smaya      requirements, and unify variables so that we know which variables must
6147e102996Smaya      have the same bitsize. If we're operating on the replace expression, we
6157e102996Smaya      will refuse to merge different variables together or merge a variable
6167e102996Smaya      with a constant, in order to prevent surprises due to rules unexpectedly
6177e102996Smaya      not matching at runtime.
6187e102996Smaya      """
6197e102996Smaya      if not isinstance(val, Expression):
6207e102996Smaya         return
6217e102996Smaya
6227e102996Smaya      # Generic conversion ops are special in that they have a single unsized
6237e102996Smaya      # source and an unsized destination and the two don't have to match.
6247e102996Smaya      # This means there's no validation or unioning to do here besides the
6257e102996Smaya      # len(val.sources) check.
6267e102996Smaya      if val.opcode in conv_opcode_types:
6277e102996Smaya         assert len(val.sources) == 1, \
6287e102996Smaya            "Expression {} has {} sources, expected 1".format(
6297e102996Smaya               val, len(val.sources))
6307e102996Smaya         self.validate_value(val.sources[0])
6317e102996Smaya         return
6327e102996Smaya
6337e102996Smaya      nir_op = opcodes[val.opcode]
6347e102996Smaya      assert len(val.sources) == nir_op.num_inputs, \
6357e102996Smaya         "Expression {} has {} sources, expected {}".format(
6367e102996Smaya            val, len(val.sources), nir_op.num_inputs)
6377e102996Smaya
6387e102996Smaya      for src in val.sources:
6397e102996Smaya         self.validate_value(src)
6407e102996Smaya
6417e102996Smaya      dst_type_bits = type_bits(nir_op.output_type)
6427e102996Smaya
6437e102996Smaya      # First, unify all the sources. That way, an error coming up because two
6447e102996Smaya      # sources have an incompatible bit-size won't produce an error message
6457e102996Smaya      # involving the destination.
6467e102996Smaya      first_unsized_src = None
6477e102996Smaya      for src_type, src in zip(nir_op.input_types, val.sources):
6487e102996Smaya         src_type_bits = type_bits(src_type)
6497e102996Smaya         if src_type_bits == 0:
6507e102996Smaya            if first_unsized_src is None:
6517e102996Smaya               first_unsized_src = src
65201e04c3fSmrg               continue
65301e04c3fSmrg
6547e102996Smaya            if self.is_search:
6557e102996Smaya               self.unify_bit_size(first_unsized_src, src,
6567e102996Smaya                  lambda first_unsized_src_bit_size, src_bit_size:
6577e102996Smaya                     'Source {} of {} must have bit size {}, while source {} ' \
6587e102996Smaya                     'must have incompatible bit size {}'.format(
6597e102996Smaya                        first_unsized_src, val, first_unsized_src_bit_size,
6607e102996Smaya                        src, src_bit_size))
66101e04c3fSmrg            else:
6627e102996Smaya               self.unify_bit_size(first_unsized_src, src,
6637e102996Smaya                  lambda first_unsized_src_bit_size, src_bit_size:
6647e102996Smaya                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
6657e102996Smaya                     'of {} may not have the same bit size when building the ' \
6667e102996Smaya                     'replacement expression.'.format(
6677e102996Smaya                        first_unsized_src, first_unsized_src_bit_size, src,
6687e102996Smaya                        src_bit_size, val))
66901e04c3fSmrg         else:
6707e102996Smaya            if self.is_search:
6717e102996Smaya               self.unify_bit_size(src, src_type_bits,
6727e102996Smaya                  lambda src_bit_size, unused:
6737e102996Smaya                     '{} must have {} bits, but as a source of nir_op_{} '\
6747e102996Smaya                     'it must have {} bits'.format(
6757e102996Smaya                        src, src_bit_size, nir_op.name, src_type_bits))
6767e102996Smaya            else:
6777e102996Smaya               self.unify_bit_size(src, src_type_bits,
6787e102996Smaya                  lambda src_bit_size, unused:
6797e102996Smaya                     '{} has the bit size of {}, but as a source of ' \
6807e102996Smaya                     'nir_op_{} it must have {} bits, which may not be the ' \
6817e102996Smaya                     'same'.format(
6827e102996Smaya                        src, src_bit_size, nir_op.name, src_type_bits))
6837e102996Smaya
6847e102996Smaya      if dst_type_bits == 0:
6857e102996Smaya         if first_unsized_src is not None:
6867e102996Smaya            if self.is_search:
6877e102996Smaya               self.unify_bit_size(val, first_unsized_src,
6887e102996Smaya                  lambda val_bit_size, src_bit_size:
6897e102996Smaya                     '{} must have the bit size of {}, while its source {} ' \
6907e102996Smaya                     'must have incompatible bit size {}'.format(
6917e102996Smaya                        val, val_bit_size, first_unsized_src, src_bit_size))
69201e04c3fSmrg            else:
6937e102996Smaya               self.unify_bit_size(val, first_unsized_src,
6947e102996Smaya                  lambda val_bit_size, src_bit_size:
6957e102996Smaya                     '{} must have {} bits, but its source {} ' \
6967e102996Smaya                     '(bit size of {}) may not have that bit size ' \
6977e102996Smaya                     'when building the replacement.'.format(
6987e102996Smaya                        val, val_bit_size, first_unsized_src, src_bit_size))
6997e102996Smaya      else:
7007e102996Smaya         self.unify_bit_size(val, dst_type_bits,
7017e102996Smaya            lambda dst_bit_size, unused:
7027e102996Smaya               '{} must have {} bits, but as a destination of nir_op_{} ' \
7037e102996Smaya               'it must have {} bits'.format(
7047e102996Smaya                  val, dst_bit_size, nir_op.name, dst_type_bits))
7057e102996Smaya
7067e102996Smaya   def validate_replace(self, val, search):
7077e102996Smaya      bit_size = val.get_bit_size()
7087e102996Smaya      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
7097e102996Smaya            bit_size == search.get_bit_size(), \
7107e102996Smaya            'Ambiguous bit size for replacement value {}: ' \
7117e102996Smaya            'it cannot be deduced from a variable, a fixed bit size ' \
7127e102996Smaya            'somewhere, or the search expression.'.format(val)
7137e102996Smaya
7147e102996Smaya      if isinstance(val, Expression):
7157e102996Smaya         for src in val.sources:
7167e102996Smaya            self.validate_replace(src, search)
7177e102996Smaya
7187e102996Smaya   def validate(self, search, replace):
7197e102996Smaya      self.is_search = True
7207e102996Smaya      self.merge_variables(search)
7217e102996Smaya      self.merge_variables(replace)
7227e102996Smaya      self.validate_value(search)
72301e04c3fSmrg
7247e102996Smaya      self.is_search = False
7257e102996Smaya      self.validate_value(replace)
72601e04c3fSmrg
7277e102996Smaya      # Check that search is always more specialized than replace. Note that
7287e102996Smaya      # we're doing this in replace mode, disallowing merging variables.
7297e102996Smaya      search_bit_size = search.get_bit_size()
7307e102996Smaya      replace_bit_size = replace.get_bit_size()
7317e102996Smaya      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)
73201e04c3fSmrg
7337e102996Smaya      assert cmp_result is not None and cmp_result <= 0, \
7347e102996Smaya         'The search expression bit size {} and replace expression ' \
7357e102996Smaya         'bit size {} may not be the same'.format(
7367e102996Smaya               search_bit_size, replace_bit_size)
73701e04c3fSmrg
7387e102996Smaya      replace.set_bit_size(search)
73901e04c3fSmrg
7407e102996Smaya      self.validate_replace(replace, search)
74101e04c3fSmrg
74201e04c3fSmrg_optimization_ids = itertools.count()
74301e04c3fSmrg
74401e04c3fSmrgcondition_list = ['true']
74501e04c3fSmrg
74601e04c3fSmrgclass SearchAndReplace(object):
74701e04c3fSmrg   def __init__(self, transform):
74801e04c3fSmrg      self.id = next(_optimization_ids)
74901e04c3fSmrg
75001e04c3fSmrg      search = transform[0]
75101e04c3fSmrg      replace = transform[1]
75201e04c3fSmrg      if len(transform) > 2:
75301e04c3fSmrg         self.condition = transform[2]
75401e04c3fSmrg      else:
75501e04c3fSmrg         self.condition = 'true'
75601e04c3fSmrg
75701e04c3fSmrg      if self.condition not in condition_list:
75801e04c3fSmrg         condition_list.append(self.condition)
75901e04c3fSmrg      self.condition_index = condition_list.index(self.condition)
76001e04c3fSmrg
76101e04c3fSmrg      varset = VarSet()
76201e04c3fSmrg      if isinstance(search, Expression):
76301e04c3fSmrg         self.search = search
76401e04c3fSmrg      else:
76501e04c3fSmrg         self.search = Expression(search, "search{0}".format(self.id), varset)
76601e04c3fSmrg
76701e04c3fSmrg      varset.lock()
76801e04c3fSmrg
76901e04c3fSmrg      if isinstance(replace, Value):
77001e04c3fSmrg         self.replace = replace
77101e04c3fSmrg      else:
77201e04c3fSmrg         self.replace = Value.create(replace, "replace{0}".format(self.id), varset)
77301e04c3fSmrg
77401e04c3fSmrg      BitSizeValidator(varset).validate(self.search, self.replace)
77501e04c3fSmrg
7767e102996Smayaclass TreeAutomaton(object):
7777e102996Smaya   """This class calculates a bottom-up tree automaton to quickly search for
7787e102996Smaya   the left-hand sides of tranforms. Tree automatons are a generalization of
7797e102996Smaya   classical NFA's and DFA's, where the transition function determines the
7807e102996Smaya   state of the parent node based on the state of its children. We construct a
7817e102996Smaya   deterministic automaton to match patterns, using a similar algorithm to the
7827e102996Smaya   classical NFA to DFA construction. At the moment, it only matches opcodes
7837e102996Smaya   and constants (without checking the actual value), leaving more detailed
7847e102996Smaya   checking to the search function which actually checks the leaves. The
7857e102996Smaya   automaton acts as a quick filter for the search function, requiring only n
7867e102996Smaya   + 1 table lookups for each n-source operation. The implementation is based
7877e102996Smaya   on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
7887e102996Smaya   In the language of that reference, this is a frontier-to-root deterministic
7897e102996Smaya   automaton using only symbol filtering. The filtering is crucial to reduce
7907e102996Smaya   both the time taken to generate the tables and the size of the tables.
7917e102996Smaya   """
7927e102996Smaya   def __init__(self, transforms):
7937e102996Smaya      self.patterns = [t.search for t in transforms]
7947e102996Smaya      self._compute_items()
7957e102996Smaya      self._build_table()
7967e102996Smaya      #print('num items: {}'.format(len(set(self.items.values()))))
7977e102996Smaya      #print('num states: {}'.format(len(self.states)))
7987e102996Smaya      #for state, patterns in zip(self.states, self.patterns):
7997e102996Smaya      #   print('{}: num patterns: {}'.format(state, len(patterns)))
8007e102996Smaya
8017e102996Smaya   class IndexMap(object):
8027e102996Smaya      """An indexed list of objects, where one can either lookup an object by
8037e102996Smaya      index or find the index associated to an object quickly using a hash
8047e102996Smaya      table. Compared to a list, it has a constant time index(). Compared to a
8057e102996Smaya      set, it provides a stable iteration order.
8067e102996Smaya      """
8077e102996Smaya      def __init__(self, iterable=()):
8087e102996Smaya         self.objects = []
8097e102996Smaya         self.map = {}
8107e102996Smaya         for obj in iterable:
8117e102996Smaya            self.add(obj)
8127e102996Smaya
8137e102996Smaya      def __getitem__(self, i):
8147e102996Smaya         return self.objects[i]
8157e102996Smaya
8167e102996Smaya      def __contains__(self, obj):
8177e102996Smaya         return obj in self.map
8187e102996Smaya
8197e102996Smaya      def __len__(self):
8207e102996Smaya         return len(self.objects)
8217e102996Smaya
8227e102996Smaya      def __iter__(self):
8237e102996Smaya         return iter(self.objects)
8247e102996Smaya
8257e102996Smaya      def clear(self):
8267e102996Smaya         self.objects = []
8277e102996Smaya         self.map.clear()
8287e102996Smaya
8297e102996Smaya      def index(self, obj):
8307e102996Smaya         return self.map[obj]
8317e102996Smaya
8327e102996Smaya      def add(self, obj):
8337e102996Smaya         if obj in self.map:
8347e102996Smaya            return self.map[obj]
8357e102996Smaya         else:
8367e102996Smaya            index = len(self.objects)
8377e102996Smaya            self.objects.append(obj)
8387e102996Smaya            self.map[obj] = index
8397e102996Smaya            return index
8407e102996Smaya
8417e102996Smaya      def __repr__(self):
8427e102996Smaya         return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'
8437e102996Smaya
8447e102996Smaya   class Item(object):
8457e102996Smaya      """This represents an "item" in the language of "Tree Automatons." This
8467e102996Smaya      is just a subtree of some pattern, which represents a potential partial
8477e102996Smaya      match at runtime. We deduplicate them, so that identical subtrees of
8487e102996Smaya      different patterns share the same object, and store some extra
8497e102996Smaya      information needed for the main algorithm as well.
8507e102996Smaya      """
8517e102996Smaya      def __init__(self, opcode, children):
8527e102996Smaya         self.opcode = opcode
8537e102996Smaya         self.children = children
8547e102996Smaya         # These are the indices of patterns for which this item is the root node.
8557e102996Smaya         self.patterns = []
8567e102996Smaya         # This the set of opcodes for parents of this item. Used to speed up
8577e102996Smaya         # filtering.
8587e102996Smaya         self.parent_ops = set()
8597e102996Smaya
8607e102996Smaya      def __str__(self):
8617e102996Smaya         return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'
8627e102996Smaya
8637e102996Smaya      def __repr__(self):
8647e102996Smaya         return str(self)
8657e102996Smaya
8667e102996Smaya   def _compute_items(self):
8677e102996Smaya      """Build a set of all possible items, deduplicating them."""
8687e102996Smaya      # This is a map from (opcode, sources) to item.
8697e102996Smaya      self.items = {}
8707e102996Smaya
8717e102996Smaya      # The set of all opcodes used by the patterns. Used later to avoid
8727e102996Smaya      # building and emitting all the tables for opcodes that aren't used.
8737e102996Smaya      self.opcodes = self.IndexMap()
8747e102996Smaya
8757e102996Smaya      def get_item(opcode, children, pattern=None):
8767ec681f3Smrg         commutative = len(children) >= 2 \
8777ec681f3Smrg               and "2src_commutative" in opcodes[opcode].algebraic_properties
8787e102996Smaya         item = self.items.setdefault((opcode, children),
8797e102996Smaya                                      self.Item(opcode, children))
8807e102996Smaya         if commutative:
8817ec681f3Smrg            self.items[opcode, (children[1], children[0]) + children[2:]] = item
8827e102996Smaya         if pattern is not None:
8837e102996Smaya            item.patterns.append(pattern)
8847e102996Smaya         return item
8857e102996Smaya
8867e102996Smaya      self.wildcard = get_item("__wildcard", ())
8877e102996Smaya      self.const = get_item("__const", ())
8887e102996Smaya
8897e102996Smaya      def process_subpattern(src, pattern=None):
8907e102996Smaya         if isinstance(src, Constant):
8917e102996Smaya            # Note: we throw away the actual constant value!
8927e102996Smaya            return self.const
8937e102996Smaya         elif isinstance(src, Variable):
8947e102996Smaya            if src.is_constant:
8957e102996Smaya               return self.const
8967e102996Smaya            else:
8977e102996Smaya               # Note: we throw away which variable it is here! This special
8987e102996Smaya               # item is equivalent to nu in "Tree Automatons."
8997e102996Smaya               return self.wildcard
9007e102996Smaya         else:
9017e102996Smaya            assert isinstance(src, Expression)
9027e102996Smaya            opcode = src.opcode
9037e102996Smaya            stripped = opcode.rstrip('0123456789')
9047e102996Smaya            if stripped in conv_opcode_types:
9057e102996Smaya               # Matches that use conversion opcodes with a specific type,
9067e102996Smaya               # like f2b1, are tricky.  Either we construct the automaton to
9077e102996Smaya               # match specific NIR opcodes like nir_op_f2b1, in which case we
9087e102996Smaya               # need to create separate items for each possible NIR opcode
9097e102996Smaya               # for patterns that have a generic opcode like f2b, or we
9107e102996Smaya               # construct it to match the search opcode, in which case we
9117e102996Smaya               # need to map f2b1 to f2b when constructing the automaton. Here
9127e102996Smaya               # we do the latter.
9137e102996Smaya               opcode = stripped
9147e102996Smaya            self.opcodes.add(opcode)
9157e102996Smaya            children = tuple(process_subpattern(c) for c in src.sources)
9167e102996Smaya            item = get_item(opcode, children, pattern)
9177e102996Smaya            for i, child in enumerate(children):
9187e102996Smaya               child.parent_ops.add(opcode)
9197e102996Smaya            return item
9207e102996Smaya
9217e102996Smaya      for i, pattern in enumerate(self.patterns):
9227e102996Smaya         process_subpattern(pattern, i)
9237e102996Smaya
9247e102996Smaya   def _build_table(self):
9257e102996Smaya      """This is the core algorithm which builds up the transition table. It
9267e102996Smaya      is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
9277e102996Smaya      Comp_a and Filt_{a,i} using integers to identify match sets." It
9287e102996Smaya      simultaneously builds up a list of all possible "match sets" or
9297e102996Smaya      "states", where each match set represents the set of Item's that match a
9307e102996Smaya      given instruction, and builds up the transition table between states.
9317e102996Smaya      """
9327e102996Smaya      # Map from opcode + filtered state indices to transitioned state.
9337e102996Smaya      self.table = defaultdict(dict)
9347e102996Smaya      # Bijection from state to index. q in the original algorithm is
9357e102996Smaya      # len(self.states)
9367e102996Smaya      self.states = self.IndexMap()
9377e102996Smaya      # List of pattern matches for each state index.
9387e102996Smaya      self.state_patterns = []
9397e102996Smaya      # Map from state index to filtered state index for each opcode.
9407e102996Smaya      self.filter = defaultdict(list)
9417e102996Smaya      # Bijections from filtered state to filtered state index for each
9427e102996Smaya      # opcode, called the "representor sets" in the original algorithm.
9437e102996Smaya      # q_{a,j} in the original algorithm is len(self.rep[op]).
9447e102996Smaya      self.rep = defaultdict(self.IndexMap)
9457e102996Smaya
9467e102996Smaya      # Everything in self.states with a index at least worklist_index is part
9477e102996Smaya      # of the worklist of newly created states. There is also a worklist of
9487e102996Smaya      # newly fitered states for each opcode, for which worklist_indices
9497e102996Smaya      # serves a similar purpose. worklist_index corresponds to p in the
9507e102996Smaya      # original algorithm, while worklist_indices is p_{a,j} (although since
9517e102996Smaya      # we only filter by opcode/symbol, it's really just p_a).
9527e102996Smaya      self.worklist_index = 0
9537e102996Smaya      worklist_indices = defaultdict(lambda: 0)
9547e102996Smaya
9557e102996Smaya      # This is the set of opcodes for which the filtered worklist is non-empty.
9567e102996Smaya      # It's used to avoid scanning opcodes for which there is nothing to
9577e102996Smaya      # process when building the transition table. It corresponds to new_a in
9587e102996Smaya      # the original algorithm.
9597e102996Smaya      new_opcodes = self.IndexMap()
9607e102996Smaya
9617e102996Smaya      # Process states on the global worklist, filtering them for each opcode,
9627e102996Smaya      # updating the filter tables, and updating the filtered worklists if any
9637e102996Smaya      # new filtered states are found. Similar to ComputeRepresenterSets() in
9647e102996Smaya      # the original algorithm, although that only processes a single state.
9657e102996Smaya      def process_new_states():
9667e102996Smaya         while self.worklist_index < len(self.states):
9677e102996Smaya            state = self.states[self.worklist_index]
9687e102996Smaya
9697e102996Smaya            # Calculate pattern matches for this state. Each pattern is
9707e102996Smaya            # assigned to a unique item, so we don't have to worry about
9717e102996Smaya            # deduplicating them here. However, we do have to sort them so
9727e102996Smaya            # that they're visited at runtime in the order they're specified
9737e102996Smaya            # in the source.
9747e102996Smaya            patterns = list(sorted(p for item in state for p in item.patterns))
9757e102996Smaya            assert len(self.state_patterns) == self.worklist_index
9767e102996Smaya            self.state_patterns.append(patterns)
9777e102996Smaya
9787e102996Smaya            # calculate filter table for this state, and update filtered
9797e102996Smaya            # worklists.
9807e102996Smaya            for op in self.opcodes:
9817e102996Smaya               filt = self.filter[op]
9827e102996Smaya               rep = self.rep[op]
9837e102996Smaya               filtered = frozenset(item for item in state if \
9847e102996Smaya                  op in item.parent_ops)
9857e102996Smaya               if filtered in rep:
9867e102996Smaya                  rep_index = rep.index(filtered)
9877e102996Smaya               else:
9887e102996Smaya                  rep_index = rep.add(filtered)
9897e102996Smaya                  new_opcodes.add(op)
9907e102996Smaya               assert len(filt) == self.worklist_index
9917e102996Smaya               filt.append(rep_index)
9927e102996Smaya            self.worklist_index += 1
9937e102996Smaya
9947e102996Smaya      # There are two start states: one which can only match as a wildcard,
9957e102996Smaya      # and one which can match as a wildcard or constant. These will be the
9967e102996Smaya      # states of intrinsics/other instructions and load_const instructions,
9977e102996Smaya      # respectively. The indices of these must match the definitions of
9987e102996Smaya      # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
9997e102996Smaya      # initialize things correctly.
10007e102996Smaya      self.states.add(frozenset((self.wildcard,)))
10017e102996Smaya      self.states.add(frozenset((self.const,self.wildcard)))
10027e102996Smaya      process_new_states()
10037e102996Smaya
10047e102996Smaya      while len(new_opcodes) > 0:
10057e102996Smaya         for op in new_opcodes:
10067e102996Smaya            rep = self.rep[op]
10077e102996Smaya            table = self.table[op]
10087e102996Smaya            op_worklist_index = worklist_indices[op]
10097e102996Smaya            if op in conv_opcode_types:
10107e102996Smaya               num_srcs = 1
10117e102996Smaya            else:
10127e102996Smaya               num_srcs = opcodes[op].num_inputs
10137e102996Smaya
10147e102996Smaya            # Iterate over all possible source combinations where at least one
10157e102996Smaya            # is on the worklist.
10167e102996Smaya            for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
10177e102996Smaya               if all(src_idx < op_worklist_index for src_idx in src_indices):
10187e102996Smaya                  continue
10197e102996Smaya
10207e102996Smaya               srcs = tuple(rep[src_idx] for src_idx in src_indices)
10217e102996Smaya
10227e102996Smaya               # Try all possible pairings of source items and add the
10237e102996Smaya               # corresponding parent items. This is Comp_a from the paper.
10247e102996Smaya               parent = set(self.items[op, item_srcs] for item_srcs in
10257e102996Smaya                  itertools.product(*srcs) if (op, item_srcs) in self.items)
10267e102996Smaya
10277e102996Smaya               # We could always start matching something else with a
10287e102996Smaya               # wildcard. This is Cl from the paper.
10297e102996Smaya               parent.add(self.wildcard)
10307e102996Smaya
10317e102996Smaya               table[src_indices] = self.states.add(frozenset(parent))
10327e102996Smaya            worklist_indices[op] = len(rep)
10337e102996Smaya         new_opcodes.clear()
10347e102996Smaya         process_new_states()
10357e102996Smaya
103601e04c3fSmrg_algebraic_pass_template = mako.template.Template("""
103701e04c3fSmrg#include "nir.h"
103801e04c3fSmrg#include "nir_builder.h"
103901e04c3fSmrg#include "nir_search.h"
104001e04c3fSmrg#include "nir_search_helpers.h"
104101e04c3fSmrg
10427ec681f3Smrg/* What follows is NIR algebraic transform code for the following ${len(xforms)}
10437ec681f3Smrg * transforms:
10447ec681f3Smrg% for xform in xforms:
10457ec681f3Smrg *    ${xform.search} => ${xform.replace}
10467ec681f3Smrg% endfor
10477e102996Smaya */
10487e102996Smaya
10497e102996Smaya<% cache = {} %>
10507e102996Smaya% for xform in xforms:
10517e102996Smaya   ${xform.search.render(cache)}
10527e102996Smaya   ${xform.replace.render(cache)}
105301e04c3fSmrg% endfor
105401e04c3fSmrg
10557e102996Smaya% for state_id, state_xforms in enumerate(automaton.state_patterns):
10567e102996Smaya% if state_xforms: # avoid emitting a 0-length array for MSVC
10577e102996Smayastatic const struct transform ${pass_name}_state${state_id}_xforms[] = {
10587e102996Smaya% for i in state_xforms:
10597e102996Smaya  { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} },
106001e04c3fSmrg% endfor
106101e04c3fSmrg};
10627e102996Smaya% endif
106301e04c3fSmrg% endfor
106401e04c3fSmrg
10657e102996Smayastatic const struct per_op_table ${pass_name}_table[nir_num_search_ops] = {
10667e102996Smaya% for op in automaton.opcodes:
10677e102996Smaya   [${get_c_opcode(op)}] = {
10687e102996Smaya      .filter = (uint16_t []) {
10697e102996Smaya      % for e in automaton.filter[op]:
10707e102996Smaya         ${e},
10717e102996Smaya      % endfor
10727e102996Smaya      },
10737e102996Smaya      <%
10747e102996Smaya        num_filtered = len(automaton.rep[op])
10757e102996Smaya      %>
10767e102996Smaya      .num_filtered_states = ${num_filtered},
10777e102996Smaya      .table = (uint16_t []) {
10787e102996Smaya      <%
10797e102996Smaya        num_srcs = len(next(iter(automaton.table[op])))
10807e102996Smaya      %>
10817e102996Smaya      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
10827e102996Smaya         ${automaton.table[op][indices]},
10837e102996Smaya      % endfor
10847e102996Smaya      },
10857e102996Smaya   },
10867e102996Smaya% endfor
10877e102996Smaya};
10887e102996Smaya
10897ec681f3Smrgconst struct transform *${pass_name}_transforms[] = {
10907e102996Smaya% for i in range(len(automaton.state_patterns)):
10917ec681f3Smrg   % if automaton.state_patterns[i]:
10927ec681f3Smrg   ${pass_name}_state${i}_xforms,
10937ec681f3Smrg   % else:
10947ec681f3Smrg   NULL,
10957ec681f3Smrg   % endif
10967e102996Smaya% endfor
10977ec681f3Smrg};
109801e04c3fSmrg
10997ec681f3Smrgconst uint16_t ${pass_name}_transform_counts[] = {
11007ec681f3Smrg% for i in range(len(automaton.state_patterns)):
11017ec681f3Smrg   % if automaton.state_patterns[i]:
11027ec681f3Smrg   (uint16_t)ARRAY_SIZE(${pass_name}_state${i}_xforms),
11037ec681f3Smrg   % else:
11047ec681f3Smrg   0,
11057ec681f3Smrg   % endif
11067ec681f3Smrg% endfor
11077ec681f3Smrg};
110801e04c3fSmrg
110901e04c3fSmrgbool
111001e04c3fSmrg${pass_name}(nir_shader *shader)
111101e04c3fSmrg{
111201e04c3fSmrg   bool progress = false;
111301e04c3fSmrg   bool condition_flags[${len(condition_list)}];
111401e04c3fSmrg   const nir_shader_compiler_options *options = shader->options;
11157e102996Smaya   const shader_info *info = &shader->info;
111601e04c3fSmrg   (void) options;
11177e102996Smaya   (void) info;
111801e04c3fSmrg
111901e04c3fSmrg   % for index, condition in enumerate(condition_list):
112001e04c3fSmrg   condition_flags[${index}] = ${condition};
112101e04c3fSmrg   % endfor
112201e04c3fSmrg
112301e04c3fSmrg   nir_foreach_function(function, shader) {
11247ec681f3Smrg      if (function->impl) {
11257ec681f3Smrg         progress |= nir_algebraic_impl(function->impl, condition_flags,
11267ec681f3Smrg                                        ${pass_name}_transforms,
11277ec681f3Smrg                                        ${pass_name}_transform_counts,
11287ec681f3Smrg                                        ${pass_name}_table);
11297ec681f3Smrg      }
113001e04c3fSmrg   }
113101e04c3fSmrg
113201e04c3fSmrg   return progress;
113301e04c3fSmrg}
113401e04c3fSmrg""")
113501e04c3fSmrg
11367e102996Smaya
113701e04c3fSmrgclass AlgebraicPass(object):
113801e04c3fSmrg   def __init__(self, pass_name, transforms):
11397e102996Smaya      self.xforms = []
11407e102996Smaya      self.opcode_xforms = defaultdict(lambda : [])
114101e04c3fSmrg      self.pass_name = pass_name
114201e04c3fSmrg
114301e04c3fSmrg      error = False
114401e04c3fSmrg
114501e04c3fSmrg      for xform in transforms:
114601e04c3fSmrg         if not isinstance(xform, SearchAndReplace):
114701e04c3fSmrg            try:
114801e04c3fSmrg               xform = SearchAndReplace(xform)
114901e04c3fSmrg            except:
115001e04c3fSmrg               print("Failed to parse transformation:", file=sys.stderr)
115101e04c3fSmrg               print("  " + str(xform), file=sys.stderr)
115201e04c3fSmrg               traceback.print_exc(file=sys.stderr)
115301e04c3fSmrg               print('', file=sys.stderr)
115401e04c3fSmrg               error = True
115501e04c3fSmrg               continue
115601e04c3fSmrg
11577e102996Smaya         self.xforms.append(xform)
11587e102996Smaya         if xform.search.opcode in conv_opcode_types:
11597e102996Smaya            dst_type = conv_opcode_types[xform.search.opcode]
11607e102996Smaya            for size in type_sizes(dst_type):
11617e102996Smaya               sized_opcode = xform.search.opcode + str(size)
11627e102996Smaya               self.opcode_xforms[sized_opcode].append(xform)
11637e102996Smaya         else:
11647e102996Smaya            self.opcode_xforms[xform.search.opcode].append(xform)
116501e04c3fSmrg
11667ec681f3Smrg         # Check to make sure the search pattern does not unexpectedly contain
11677ec681f3Smrg         # more commutative expressions than match_expression (nir_search.c)
11687ec681f3Smrg         # can handle.
11697ec681f3Smrg         comm_exprs = xform.search.comm_exprs
11707ec681f3Smrg
11717ec681f3Smrg         if xform.search.many_commutative_expressions:
11727ec681f3Smrg            if comm_exprs <= nir_search_max_comm_ops:
11737ec681f3Smrg               print("Transform expected to have too many commutative " \
11747ec681f3Smrg                     "expression but did not " \
11757ec681f3Smrg                     "({} <= {}).".format(comm_exprs, nir_search_max_comm_op),
11767ec681f3Smrg                     file=sys.stderr)
11777ec681f3Smrg               print("  " + str(xform), file=sys.stderr)
11787ec681f3Smrg               traceback.print_exc(file=sys.stderr)
11797ec681f3Smrg               print('', file=sys.stderr)
11807ec681f3Smrg               error = True
11817ec681f3Smrg         else:
11827ec681f3Smrg            if comm_exprs > nir_search_max_comm_ops:
11837ec681f3Smrg               print("Transformation with too many commutative expressions " \
11847ec681f3Smrg                     "({} > {}).  Modify pattern or annotate with " \
11857ec681f3Smrg                     "\"many-comm-expr\".".format(comm_exprs,
11867ec681f3Smrg                                                  nir_search_max_comm_ops),
11877ec681f3Smrg                     file=sys.stderr)
11887ec681f3Smrg               print("  " + str(xform.search), file=sys.stderr)
11897ec681f3Smrg               print("{}".format(xform.search.cond), file=sys.stderr)
11907ec681f3Smrg               error = True
11917ec681f3Smrg
11927e102996Smaya      self.automaton = TreeAutomaton(self.xforms)
119301e04c3fSmrg
119401e04c3fSmrg      if error:
119501e04c3fSmrg         sys.exit(1)
119601e04c3fSmrg
11977e102996Smaya
119801e04c3fSmrg   def render(self):
119901e04c3fSmrg      return _algebraic_pass_template.render(pass_name=self.pass_name,
12007e102996Smaya                                             xforms=self.xforms,
12017e102996Smaya                                             opcode_xforms=self.opcode_xforms,
12027e102996Smaya                                             condition_list=condition_list,
12037e102996Smaya                                             automaton=self.automaton,
12047e102996Smaya                                             get_c_opcode=get_c_opcode,
12057e102996Smaya                                             itertools=itertools)
1206