101e04c3fSmrg# 201e04c3fSmrg# Copyright (C) 2014 Intel Corporation 301e04c3fSmrg# 401e04c3fSmrg# Permission is hereby granted, free of charge, to any person obtaining a 501e04c3fSmrg# copy of this software and associated documentation files (the "Software"), 601e04c3fSmrg# to deal in the Software without restriction, including without limitation 701e04c3fSmrg# the rights to use, copy, modify, merge, publish, distribute, sublicense, 801e04c3fSmrg# and/or sell copies of the Software, and to permit persons to whom the 901e04c3fSmrg# Software is furnished to do so, subject to the following conditions: 1001e04c3fSmrg# 1101e04c3fSmrg# The above copyright notice and this permission notice (including the next 1201e04c3fSmrg# paragraph) shall be included in all copies or substantial portions of the 1301e04c3fSmrg# Software. 1401e04c3fSmrg# 1501e04c3fSmrg# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 1601e04c3fSmrg# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 1701e04c3fSmrg# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 1801e04c3fSmrg# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 1901e04c3fSmrg# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 2001e04c3fSmrg# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 2101e04c3fSmrg# IN THE SOFTWARE. 2201e04c3fSmrg# 2301e04c3fSmrg# Authors: 2401e04c3fSmrg# Jason Ekstrand (jason@jlekstrand.net) 2501e04c3fSmrg 2601e04c3fSmrgimport ast 277e102996Smayafrom collections import defaultdict 2801e04c3fSmrgimport itertools 2901e04c3fSmrgimport struct 3001e04c3fSmrgimport sys 3101e04c3fSmrgimport mako.template 3201e04c3fSmrgimport re 3301e04c3fSmrgimport traceback 3401e04c3fSmrg 357e102996Smayafrom nir_opcodes import opcodes, type_sizes 367e102996Smaya 377ec681f3Smrg# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c 387ec681f3Smrgnir_search_max_comm_ops = 8 397ec681f3Smrg 407e102996Smaya# These opcodes are only employed by nir_search. This provides a mapping from 417e102996Smaya# opcode to destination type. 427e102996Smayaconv_opcode_types = { 437e102996Smaya 'i2f' : 'float', 447e102996Smaya 'u2f' : 'float', 457e102996Smaya 'f2f' : 'float', 467e102996Smaya 'f2u' : 'uint', 477e102996Smaya 'f2i' : 'int', 487e102996Smaya 'u2u' : 'uint', 497e102996Smaya 'i2i' : 'int', 507e102996Smaya 'b2f' : 'float', 517e102996Smaya 'b2i' : 'int', 527e102996Smaya 'i2b' : 'bool', 537e102996Smaya 'f2b' : 'bool', 547e102996Smaya} 557e102996Smaya 567e102996Smayadef get_c_opcode(op): 577e102996Smaya if op in conv_opcode_types: 587e102996Smaya return 'nir_search_op_' + op 597e102996Smaya else: 607e102996Smaya return 'nir_op_' + op 617e102996Smaya 6201e04c3fSmrg_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?") 6301e04c3fSmrg 6401e04c3fSmrgdef type_bits(type_str): 6501e04c3fSmrg m = _type_re.match(type_str) 6601e04c3fSmrg assert m.group('type') 6701e04c3fSmrg 6801e04c3fSmrg if m.group('bits') is None: 6901e04c3fSmrg return 0 7001e04c3fSmrg else: 7101e04c3fSmrg return int(m.group('bits')) 7201e04c3fSmrg 7301e04c3fSmrg# Represents a set of variables, each with a unique id 7401e04c3fSmrgclass VarSet(object): 7501e04c3fSmrg def __init__(self): 7601e04c3fSmrg self.names = {} 7701e04c3fSmrg self.ids = itertools.count() 7801e04c3fSmrg self.immutable = False; 7901e04c3fSmrg 8001e04c3fSmrg def __getitem__(self, name): 8101e04c3fSmrg if name not in self.names: 8201e04c3fSmrg assert not self.immutable, "Unknown replacement variable: " + name 8301e04c3fSmrg self.names[name] = next(self.ids) 8401e04c3fSmrg 8501e04c3fSmrg return self.names[name] 8601e04c3fSmrg 8701e04c3fSmrg def lock(self): 8801e04c3fSmrg self.immutable = True 8901e04c3fSmrg 9001e04c3fSmrgclass Value(object): 9101e04c3fSmrg @staticmethod 9201e04c3fSmrg def create(val, name_base, varset): 9301e04c3fSmrg if isinstance(val, bytes): 9401e04c3fSmrg val = val.decode('utf-8') 9501e04c3fSmrg 9601e04c3fSmrg if isinstance(val, tuple): 9701e04c3fSmrg return Expression(val, name_base, varset) 9801e04c3fSmrg elif isinstance(val, Expression): 9901e04c3fSmrg return val 1007ec681f3Smrg elif isinstance(val, str): 10101e04c3fSmrg return Variable(val, name_base, varset) 1027ec681f3Smrg elif isinstance(val, (bool, float, int)): 10301e04c3fSmrg return Constant(val, name_base) 10401e04c3fSmrg 10501e04c3fSmrg def __init__(self, val, name, type_str): 10601e04c3fSmrg self.in_val = str(val) 10701e04c3fSmrg self.name = name 10801e04c3fSmrg self.type_str = type_str 10901e04c3fSmrg 11001e04c3fSmrg def __str__(self): 11101e04c3fSmrg return self.in_val 11201e04c3fSmrg 1137e102996Smaya def get_bit_size(self): 1147e102996Smaya """Get the physical bit-size that has been chosen for this value, or if 1157e102996Smaya there is none, the canonical value which currently represents this 1167e102996Smaya bit-size class. Variables will be preferred, i.e. if there are any 1177e102996Smaya variables in the equivalence class, the canonical value will be a 1187e102996Smaya variable. We do this since we'll need to know which variable each value 1197e102996Smaya is equivalent to when constructing the replacement expression. This is 1207e102996Smaya the "find" part of the union-find algorithm. 1217e102996Smaya """ 1227e102996Smaya bit_size = self 1237e102996Smaya 1247e102996Smaya while isinstance(bit_size, Value): 1257e102996Smaya if bit_size._bit_size is None: 1267e102996Smaya break 1277e102996Smaya bit_size = bit_size._bit_size 1287e102996Smaya 1297e102996Smaya if bit_size is not self: 1307e102996Smaya self._bit_size = bit_size 1317e102996Smaya return bit_size 1327e102996Smaya 1337e102996Smaya def set_bit_size(self, other): 1347e102996Smaya """Make self.get_bit_size() return what other.get_bit_size() return 1357e102996Smaya before calling this, or just "other" if it's a concrete bit-size. This is 1367e102996Smaya the "union" part of the union-find algorithm. 1377e102996Smaya """ 1387e102996Smaya 1397e102996Smaya self_bit_size = self.get_bit_size() 1407e102996Smaya other_bit_size = other if isinstance(other, int) else other.get_bit_size() 1417e102996Smaya 1427e102996Smaya if self_bit_size == other_bit_size: 1437e102996Smaya return 1447e102996Smaya 1457e102996Smaya self_bit_size._bit_size = other_bit_size 1467e102996Smaya 14701e04c3fSmrg @property 14801e04c3fSmrg def type_enum(self): 14901e04c3fSmrg return "nir_search_value_" + self.type_str 15001e04c3fSmrg 15101e04c3fSmrg @property 15201e04c3fSmrg def c_type(self): 15301e04c3fSmrg return "nir_search_" + self.type_str 15401e04c3fSmrg 1557e102996Smaya def __c_name(self, cache): 1567e102996Smaya if cache is not None and self.name in cache: 1577e102996Smaya return cache[self.name] 1587e102996Smaya else: 1597e102996Smaya return self.name 1607e102996Smaya 1617e102996Smaya def c_value_ptr(self, cache): 1627e102996Smaya return "&{0}.value".format(self.__c_name(cache)) 1637e102996Smaya 1647e102996Smaya def c_ptr(self, cache): 1657e102996Smaya return "&{0}".format(self.__c_name(cache)) 1667e102996Smaya 16701e04c3fSmrg @property 1687e102996Smaya def c_bit_size(self): 1697e102996Smaya bit_size = self.get_bit_size() 1707e102996Smaya if isinstance(bit_size, int): 1717e102996Smaya return bit_size 1727e102996Smaya elif isinstance(bit_size, Variable): 1737e102996Smaya return -bit_size.index - 1 1747e102996Smaya else: 1757e102996Smaya # If the bit-size class is neither a variable, nor an actual bit-size, then 1767e102996Smaya # - If it's in the search expression, we don't need to check anything 1777e102996Smaya # - If it's in the replace expression, either it's ambiguous (in which 1787e102996Smaya # case we'd reject it), or it equals the bit-size of the search value 1797e102996Smaya # We represent these cases with a 0 bit-size. 1807e102996Smaya return 0 1817e102996Smaya 1827e102996Smaya __template = mako.template.Template("""{ 1837e102996Smaya { ${val.type_enum}, ${val.c_bit_size} }, 1847e102996Smaya% if isinstance(val, Constant): 1857e102996Smaya ${val.type()}, { ${val.hex()} /* ${val.value} */ }, 1867e102996Smaya% elif isinstance(val, Variable): 1877e102996Smaya ${val.index}, /* ${val.var_name} */ 1887e102996Smaya ${'true' if val.is_constant else 'false'}, 1897e102996Smaya ${val.type() or 'nir_type_invalid' }, 1907e102996Smaya ${val.cond if val.cond else 'NULL'}, 1917ec681f3Smrg ${val.swizzle()}, 1927e102996Smaya% elif isinstance(val, Expression): 1937ec681f3Smrg ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'}, 1947e102996Smaya ${val.comm_expr_idx}, ${val.comm_exprs}, 1957e102996Smaya ${val.c_opcode()}, 1967e102996Smaya { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} }, 1977e102996Smaya ${val.cond if val.cond else 'NULL'}, 1987e102996Smaya% endif 1997e102996Smaya};""") 20001e04c3fSmrg 2017e102996Smaya def render(self, cache): 2027e102996Smaya struct_init = self.__template.render(val=self, cache=cache, 2037e102996Smaya Constant=Constant, 2047e102996Smaya Variable=Variable, 2057e102996Smaya Expression=Expression) 2067e102996Smaya if cache is not None and struct_init in cache: 2077e102996Smaya # If it's in the cache, register a name remap in the cache and render 2087e102996Smaya # only a comment saying it's been remapped 2097e102996Smaya cache[self.name] = cache[struct_init] 2107e102996Smaya return "/* {} -> {} in the cache */\n".format(self.name, 2117e102996Smaya cache[struct_init]) 2127e102996Smaya else: 2137e102996Smaya if cache is not None: 2147e102996Smaya cache[struct_init] = self.name 2157e102996Smaya return "static const {} {} = {}\n".format(self.c_type, self.name, 2167e102996Smaya struct_init) 21701e04c3fSmrg 21801e04c3fSmrg_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?") 21901e04c3fSmrg 22001e04c3fSmrgclass Constant(Value): 22101e04c3fSmrg def __init__(self, val, name): 22201e04c3fSmrg Value.__init__(self, val, name, "constant") 22301e04c3fSmrg 22401e04c3fSmrg if isinstance(val, (str)): 22501e04c3fSmrg m = _constant_re.match(val) 22601e04c3fSmrg self.value = ast.literal_eval(m.group('value')) 2277e102996Smaya self._bit_size = int(m.group('bits')) if m.group('bits') else None 22801e04c3fSmrg else: 22901e04c3fSmrg self.value = val 2307e102996Smaya self._bit_size = None 23101e04c3fSmrg 23201e04c3fSmrg if isinstance(self.value, bool): 2337e102996Smaya assert self._bit_size is None or self._bit_size == 1 2347e102996Smaya self._bit_size = 1 23501e04c3fSmrg 23601e04c3fSmrg def hex(self): 23701e04c3fSmrg if isinstance(self.value, (bool)): 23801e04c3fSmrg return 'NIR_TRUE' if self.value else 'NIR_FALSE' 2397ec681f3Smrg if isinstance(self.value, int): 24001e04c3fSmrg return hex(self.value) 24101e04c3fSmrg elif isinstance(self.value, float): 2427ec681f3Smrg return hex(struct.unpack('Q', struct.pack('d', self.value))[0]) 24301e04c3fSmrg else: 24401e04c3fSmrg assert False 24501e04c3fSmrg 24601e04c3fSmrg def type(self): 24701e04c3fSmrg if isinstance(self.value, (bool)): 24801e04c3fSmrg return "nir_type_bool" 2497ec681f3Smrg elif isinstance(self.value, int): 25001e04c3fSmrg return "nir_type_int" 25101e04c3fSmrg elif isinstance(self.value, float): 25201e04c3fSmrg return "nir_type_float" 25301e04c3fSmrg 2547ec681f3Smrg def equivalent(self, other): 2557ec681f3Smrg """Check that two constants are equivalent. 2567ec681f3Smrg 2577ec681f3Smrg This is check is much weaker than equality. One generally cannot be 2587ec681f3Smrg used in place of the other. Using this implementation for the __eq__ 2597ec681f3Smrg will break BitSizeValidator. 2607ec681f3Smrg 2617ec681f3Smrg """ 2627ec681f3Smrg if not isinstance(other, type(self)): 2637ec681f3Smrg return False 2647ec681f3Smrg 2657ec681f3Smrg return self.value == other.value 2667ec681f3Smrg 2677ec681f3Smrg# The $ at the end forces there to be an error if any part of the string 2687ec681f3Smrg# doesn't match one of the field patterns. 26901e04c3fSmrg_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)" 27001e04c3fSmrg r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?" 2717ec681f3Smrg r"(?P<cond>\([^\)]+\))?" 2727ec681f3Smrg r"(?P<swiz>\.[xyzw]+)?" 2737ec681f3Smrg r"$") 27401e04c3fSmrg 27501e04c3fSmrgclass Variable(Value): 27601e04c3fSmrg def __init__(self, val, name, varset): 27701e04c3fSmrg Value.__init__(self, val, name, "variable") 27801e04c3fSmrg 27901e04c3fSmrg m = _var_name_re.match(val) 2807ec681f3Smrg assert m and m.group('name') is not None, \ 2817ec681f3Smrg "Malformed variable name \"{}\".".format(val) 28201e04c3fSmrg 28301e04c3fSmrg self.var_name = m.group('name') 2847e102996Smaya 2857e102996Smaya # Prevent common cases where someone puts quotes around a literal 2867e102996Smaya # constant. If we want to support names that have numeric or 2877e102996Smaya # punctuation characters, we can me the first assertion more flexible. 2887e102996Smaya assert self.var_name.isalpha() 2897ec681f3Smrg assert self.var_name != 'True' 2907ec681f3Smrg assert self.var_name != 'False' 2917e102996Smaya 29201e04c3fSmrg self.is_constant = m.group('const') is not None 29301e04c3fSmrg self.cond = m.group('cond') 29401e04c3fSmrg self.required_type = m.group('type') 2957e102996Smaya self._bit_size = int(m.group('bits')) if m.group('bits') else None 2967ec681f3Smrg self.swiz = m.group('swiz') 29701e04c3fSmrg 29801e04c3fSmrg if self.required_type == 'bool': 2997e102996Smaya if self._bit_size is not None: 3007e102996Smaya assert self._bit_size in type_sizes(self.required_type) 3017e102996Smaya else: 3027e102996Smaya self._bit_size = 1 30301e04c3fSmrg 30401e04c3fSmrg if self.required_type is not None: 30501e04c3fSmrg assert self.required_type in ('float', 'bool', 'int', 'uint') 30601e04c3fSmrg 30701e04c3fSmrg self.index = varset[self.var_name] 30801e04c3fSmrg 30901e04c3fSmrg def type(self): 31001e04c3fSmrg if self.required_type == 'bool': 31101e04c3fSmrg return "nir_type_bool" 31201e04c3fSmrg elif self.required_type in ('int', 'uint'): 31301e04c3fSmrg return "nir_type_int" 31401e04c3fSmrg elif self.required_type == 'float': 31501e04c3fSmrg return "nir_type_float" 31601e04c3fSmrg 3177ec681f3Smrg def equivalent(self, other): 3187ec681f3Smrg """Check that two variables are equivalent. 3197ec681f3Smrg 3207ec681f3Smrg This is check is much weaker than equality. One generally cannot be 3217ec681f3Smrg used in place of the other. Using this implementation for the __eq__ 3227ec681f3Smrg will break BitSizeValidator. 3237ec681f3Smrg 3247ec681f3Smrg """ 3257ec681f3Smrg if not isinstance(other, type(self)): 3267ec681f3Smrg return False 3277ec681f3Smrg 3287ec681f3Smrg return self.index == other.index 3297ec681f3Smrg 3307ec681f3Smrg def swizzle(self): 3317ec681f3Smrg if self.swiz is not None: 3327ec681f3Smrg swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3, 3337ec681f3Smrg 'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3, 3347ec681f3Smrg 'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7, 3357ec681f3Smrg 'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11, 3367ec681f3Smrg 'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 } 3377ec681f3Smrg return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}' 3387ec681f3Smrg return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}' 3397ec681f3Smrg 3407ec681f3Smrg_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?" 34101e04c3fSmrg r"(?P<cond>\([^\)]+\))?") 34201e04c3fSmrg 34301e04c3fSmrgclass Expression(Value): 34401e04c3fSmrg def __init__(self, expr, name_base, varset): 34501e04c3fSmrg Value.__init__(self, expr, name_base, "expression") 34601e04c3fSmrg assert isinstance(expr, tuple) 34701e04c3fSmrg 34801e04c3fSmrg m = _opcode_re.match(expr[0]) 34901e04c3fSmrg assert m and m.group('opcode') is not None 35001e04c3fSmrg 35101e04c3fSmrg self.opcode = m.group('opcode') 3527e102996Smaya self._bit_size = int(m.group('bits')) if m.group('bits') else None 35301e04c3fSmrg self.inexact = m.group('inexact') is not None 3547ec681f3Smrg self.exact = m.group('exact') is not None 35501e04c3fSmrg self.cond = m.group('cond') 3567ec681f3Smrg 3577ec681f3Smrg assert not self.inexact or not self.exact, \ 3587ec681f3Smrg 'Expression cannot be both exact and inexact.' 3597ec681f3Smrg 3607ec681f3Smrg # "many-comm-expr" isn't really a condition. It's notification to the 3617ec681f3Smrg # generator that this pattern is known to have too many commutative 3627ec681f3Smrg # expressions, and an error should not be generated for this case. 3637ec681f3Smrg self.many_commutative_expressions = False 3647ec681f3Smrg if self.cond and self.cond.find("many-comm-expr") >= 0: 3657ec681f3Smrg # Split the condition into a comma-separated list. Remove 3667ec681f3Smrg # "many-comm-expr". If there is anything left, put it back together. 3677ec681f3Smrg c = self.cond[1:-1].split(",") 3687ec681f3Smrg c.remove("many-comm-expr") 3697ec681f3Smrg 3707ec681f3Smrg self.cond = "({})".format(",".join(c)) if c else None 3717ec681f3Smrg self.many_commutative_expressions = True 3727ec681f3Smrg 37301e04c3fSmrg self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset) 37401e04c3fSmrg for (i, src) in enumerate(expr[1:]) ] 37501e04c3fSmrg 3767ec681f3Smrg # nir_search_expression::srcs is hard-coded to 4 3777ec681f3Smrg assert len(self.sources) <= 4 3787ec681f3Smrg 3797e102996Smaya if self.opcode in conv_opcode_types: 3807e102996Smaya assert self._bit_size is None, \ 3817e102996Smaya 'Expression cannot use an unsized conversion opcode with ' \ 3827e102996Smaya 'an explicit size; that\'s silly.' 3837e102996Smaya 3847e102996Smaya self.__index_comm_exprs(0) 3857e102996Smaya 3867ec681f3Smrg def equivalent(self, other): 3877ec681f3Smrg """Check that two variables are equivalent. 3887ec681f3Smrg 3897ec681f3Smrg This is check is much weaker than equality. One generally cannot be 3907ec681f3Smrg used in place of the other. Using this implementation for the __eq__ 3917ec681f3Smrg will break BitSizeValidator. 3927ec681f3Smrg 3937ec681f3Smrg This implementation does not check for equivalence due to commutativity, 3947ec681f3Smrg but it could. 3957ec681f3Smrg 3967ec681f3Smrg """ 3977ec681f3Smrg if not isinstance(other, type(self)): 3987ec681f3Smrg return False 3997ec681f3Smrg 4007ec681f3Smrg if len(self.sources) != len(other.sources): 4017ec681f3Smrg return False 4027ec681f3Smrg 4037ec681f3Smrg if self.opcode != other.opcode: 4047ec681f3Smrg return False 4057ec681f3Smrg 4067ec681f3Smrg return all(s.equivalent(o) for s, o in zip(self.sources, other.sources)) 4077ec681f3Smrg 4087e102996Smaya def __index_comm_exprs(self, base_idx): 4097e102996Smaya """Recursively count and index commutative expressions 4107e102996Smaya """ 4117e102996Smaya self.comm_exprs = 0 4127ec681f3Smrg 4137ec681f3Smrg # A note about the explicit "len(self.sources)" check. The list of 4147ec681f3Smrg # sources comes from user input, and that input might be bad. Check 4157ec681f3Smrg # that the expected second source exists before accessing it. Without 4167ec681f3Smrg # this check, a unit test that does "('iadd', 'a')" will crash. 4177e102996Smaya if self.opcode not in conv_opcode_types and \ 4187ec681f3Smrg "2src_commutative" in opcodes[self.opcode].algebraic_properties and \ 4197ec681f3Smrg len(self.sources) >= 2 and \ 4207ec681f3Smrg not self.sources[0].equivalent(self.sources[1]): 4217e102996Smaya self.comm_expr_idx = base_idx 4227e102996Smaya self.comm_exprs += 1 42301e04c3fSmrg else: 4247e102996Smaya self.comm_expr_idx = -1 42501e04c3fSmrg 4267e102996Smaya for s in self.sources: 4277e102996Smaya if isinstance(s, Expression): 4287e102996Smaya s.__index_comm_exprs(base_idx + self.comm_exprs) 4297e102996Smaya self.comm_exprs += s.comm_exprs 43001e04c3fSmrg 4317e102996Smaya return self.comm_exprs 43201e04c3fSmrg 4337e102996Smaya def c_opcode(self): 4347e102996Smaya return get_c_opcode(self.opcode) 4357e102996Smaya 4367e102996Smaya def render(self, cache): 4377e102996Smaya srcs = "\n".join(src.render(cache) for src in self.sources) 4387e102996Smaya return srcs + super(Expression, self).render(cache) 43901e04c3fSmrg 44001e04c3fSmrgclass BitSizeValidator(object): 44101e04c3fSmrg """A class for validating bit sizes of expressions. 44201e04c3fSmrg 44301e04c3fSmrg NIR supports multiple bit-sizes on expressions in order to handle things 44401e04c3fSmrg such as fp64. The source and destination of every ALU operation is 44501e04c3fSmrg assigned a type and that type may or may not specify a bit size. Sources 44601e04c3fSmrg and destinations whose type does not specify a bit size are considered 44701e04c3fSmrg "unsized" and automatically take on the bit size of the corresponding 44801e04c3fSmrg register or SSA value. NIR has two simple rules for bit sizes that are 44901e04c3fSmrg validated by nir_validator: 45001e04c3fSmrg 45101e04c3fSmrg 1) A given SSA def or register has a single bit size that is respected by 45201e04c3fSmrg everything that reads from it or writes to it. 45301e04c3fSmrg 45401e04c3fSmrg 2) The bit sizes of all unsized inputs/outputs on any given ALU 45501e04c3fSmrg instruction must match. They need not match the sized inputs or 45601e04c3fSmrg outputs but they must match each other. 45701e04c3fSmrg 45801e04c3fSmrg In order to keep nir_algebraic relatively simple and easy-to-use, 45901e04c3fSmrg nir_search supports a type of bit-size inference based on the two rules 46001e04c3fSmrg above. This is similar to type inference in many common programming 46101e04c3fSmrg languages. If, for instance, you are constructing an add operation and you 46201e04c3fSmrg know the second source is 16-bit, then you know that the other source and 46301e04c3fSmrg the destination must also be 16-bit. There are, however, cases where this 46401e04c3fSmrg inference can be ambiguous or contradictory. Consider, for instance, the 46501e04c3fSmrg following transformation: 46601e04c3fSmrg 4677e102996Smaya (('usub_borrow', a, b), ('b2i@32', ('ult', a, b))) 46801e04c3fSmrg 46901e04c3fSmrg This transformation can potentially cause a problem because usub_borrow is 47001e04c3fSmrg well-defined for any bit-size of integer. However, b2i always generates a 47101e04c3fSmrg 32-bit result so it could end up replacing a 64-bit expression with one 47201e04c3fSmrg that takes two 64-bit values and produces a 32-bit value. As another 47301e04c3fSmrg example, consider this expression: 47401e04c3fSmrg 47501e04c3fSmrg (('bcsel', a, b, 0), ('iand', a, b)) 47601e04c3fSmrg 47701e04c3fSmrg In this case, in the search expression a must be 32-bit but b can 47801e04c3fSmrg potentially have any bit size. If we had a 64-bit b value, we would end up 47901e04c3fSmrg trying to and a 32-bit value with a 64-bit value which would be invalid 48001e04c3fSmrg 48101e04c3fSmrg This class solves that problem by providing a validation layer that proves 48201e04c3fSmrg that a given search-and-replace operation is 100% well-defined before we 48301e04c3fSmrg generate any code. This ensures that bugs are caught at compile time 48401e04c3fSmrg rather than at run time. 48501e04c3fSmrg 4867e102996Smaya Each value maintains a "bit-size class", which is either an actual bit size 4877e102996Smaya or an equivalence class with other values that must have the same bit size. 4887e102996Smaya The validator works by combining bit-size classes with each other according 4897e102996Smaya to the NIR rules outlined above, checking that there are no inconsistencies. 4907e102996Smaya When doing this for the replacement expression, we make sure to never change 4917e102996Smaya the equivalence class of any of the search values. We could make the example 4927e102996Smaya transforms above work by doing some extra run-time checking of the search 4937e102996Smaya expression, but we make the user specify those constraints themselves, to 4947e102996Smaya avoid any surprises. Since the replacement bitsizes can only be connected to 4957e102996Smaya the source bitsize via variables (variables must have the same bitsize in 4967e102996Smaya the source and replacment expressions) or the roots of the expression (the 4977e102996Smaya replacement expression must produce the same bit size as the search 4987e102996Smaya expression), we prevent merging a variable with anything when processing the 4997e102996Smaya replacement expression, or specializing the search bitsize 5007e102996Smaya with anything. The former prevents 50101e04c3fSmrg 5027e102996Smaya (('bcsel', a, b, 0), ('iand', a, b)) 50301e04c3fSmrg 5047e102996Smaya from being allowed, since we'd have to merge the bitsizes for a and b due to 5057e102996Smaya the 'iand', while the latter prevents 50601e04c3fSmrg 5077e102996Smaya (('usub_borrow', a, b), ('b2i@32', ('ult', a, b))) 50801e04c3fSmrg 5097e102996Smaya from being allowed, since the search expression has the bit size of a and b, 5107e102996Smaya which can't be specialized to 32 which is the bitsize of the replace 5117e102996Smaya expression. It also prevents something like: 51201e04c3fSmrg 5137e102996Smaya (('b2i', ('i2b', a)), ('ineq', a, 0)) 51401e04c3fSmrg 5157e102996Smaya since the bitsize of 'b2i', which can be anything, can't be specialized to 5167e102996Smaya the bitsize of a. 51701e04c3fSmrg 5187e102996Smaya After doing all this, we check that every subexpression of the replacement 5197e102996Smaya was assigned a constant bitsize, the bitsize of a variable, or the bitsize 5207e102996Smaya of the search expresssion, since those are the things that are known when 5217e102996Smaya constructing the replacement expresssion. Finally, we record the bitsize 5227e102996Smaya needed in nir_search_value so that we know what to do when building the 5237e102996Smaya replacement expression. 5247e102996Smaya """ 52501e04c3fSmrg 5267e102996Smaya def __init__(self, varset): 5277e102996Smaya self._var_classes = [None] * len(varset.names) 5287e102996Smaya 5297e102996Smaya def compare_bitsizes(self, a, b): 5307e102996Smaya """Determines which bitsize class is a specialization of the other, or 5317e102996Smaya whether neither is. When we merge two different bitsizes, the 5327e102996Smaya less-specialized bitsize always points to the more-specialized one, so 5337e102996Smaya that calling get_bit_size() always gets you the most specialized bitsize. 5347e102996Smaya The specialization partial order is given by: 5357e102996Smaya - Physical bitsizes are always the most specialized, and a different 5367e102996Smaya bitsize can never specialize another. 5377e102996Smaya - In the search expression, variables can always be specialized to each 5387e102996Smaya other and to physical bitsizes. In the replace expression, we disallow 5397e102996Smaya this to avoid adding extra constraints to the search expression that 5407e102996Smaya the user didn't specify. 5417e102996Smaya - Expressions and constants without a bitsize can always be specialized to 5427e102996Smaya each other and variables, but not the other way around. 5437e102996Smaya 5447e102996Smaya We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b, 5457e102996Smaya and None if they are not comparable (neither a <= b nor b <= a). 5467e102996Smaya """ 5477e102996Smaya if isinstance(a, int): 5487e102996Smaya if isinstance(b, int): 5497e102996Smaya return 0 if a == b else None 5507e102996Smaya elif isinstance(b, Variable): 5517e102996Smaya return -1 if self.is_search else None 5527e102996Smaya else: 5537e102996Smaya return -1 5547e102996Smaya elif isinstance(a, Variable): 5557e102996Smaya if isinstance(b, int): 5567e102996Smaya return 1 if self.is_search else None 5577e102996Smaya elif isinstance(b, Variable): 5587e102996Smaya return 0 if self.is_search or a.index == b.index else None 5597e102996Smaya else: 5607e102996Smaya return -1 5617e102996Smaya else: 5627e102996Smaya if isinstance(b, int): 5637e102996Smaya return 1 5647e102996Smaya elif isinstance(b, Variable): 5657e102996Smaya return 1 5667e102996Smaya else: 5677e102996Smaya return 0 5687e102996Smaya 5697e102996Smaya def unify_bit_size(self, a, b, error_msg): 5707e102996Smaya """Record that a must have the same bit-size as b. If both 5717e102996Smaya have been assigned conflicting physical bit-sizes, call "error_msg" with 5727e102996Smaya the bit-sizes of self and other to get a message and raise an error. 5737e102996Smaya In the replace expression, disallow merging variables with other 5747e102996Smaya variables and physical bit-sizes as well. 5757e102996Smaya """ 5767e102996Smaya a_bit_size = a.get_bit_size() 5777e102996Smaya b_bit_size = b if isinstance(b, int) else b.get_bit_size() 5787e102996Smaya 5797e102996Smaya cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size) 5807e102996Smaya 5817e102996Smaya assert cmp_result is not None, \ 5827e102996Smaya error_msg(a_bit_size, b_bit_size) 5837e102996Smaya 5847e102996Smaya if cmp_result < 0: 5857e102996Smaya b_bit_size.set_bit_size(a) 5867e102996Smaya elif not isinstance(a_bit_size, int): 5877e102996Smaya a_bit_size.set_bit_size(b) 5887e102996Smaya 5897e102996Smaya def merge_variables(self, val): 5907e102996Smaya """Perform the first part of type inference by merging all the different 5917e102996Smaya uses of the same variable. We always do this as if we're in the search 5927e102996Smaya expression, even if we're actually not, since otherwise we'd get errors 5937e102996Smaya if the search expression specified some constraint but the replace 5947e102996Smaya expression didn't, because we'd be merging a variable and a constant. 5957e102996Smaya """ 5967e102996Smaya if isinstance(val, Variable): 5977e102996Smaya if self._var_classes[val.index] is None: 5987e102996Smaya self._var_classes[val.index] = val 5997e102996Smaya else: 6007e102996Smaya other = self._var_classes[val.index] 6017e102996Smaya self.unify_bit_size(other, val, 6027e102996Smaya lambda other_bit_size, bit_size: 6037e102996Smaya 'Variable {} has conflicting bit size requirements: ' \ 6047e102996Smaya 'it must have bit size {} and {}'.format( 6057e102996Smaya val.var_name, other_bit_size, bit_size)) 60601e04c3fSmrg elif isinstance(val, Expression): 6077e102996Smaya for src in val.sources: 6087e102996Smaya self.merge_variables(src) 6097e102996Smaya 6107e102996Smaya def validate_value(self, val): 6117e102996Smaya """Validate the an expression by performing classic Hindley-Milner 6127e102996Smaya type inference on bitsizes. This will detect if there are any conflicting 6137e102996Smaya requirements, and unify variables so that we know which variables must 6147e102996Smaya have the same bitsize. If we're operating on the replace expression, we 6157e102996Smaya will refuse to merge different variables together or merge a variable 6167e102996Smaya with a constant, in order to prevent surprises due to rules unexpectedly 6177e102996Smaya not matching at runtime. 6187e102996Smaya """ 6197e102996Smaya if not isinstance(val, Expression): 6207e102996Smaya return 6217e102996Smaya 6227e102996Smaya # Generic conversion ops are special in that they have a single unsized 6237e102996Smaya # source and an unsized destination and the two don't have to match. 6247e102996Smaya # This means there's no validation or unioning to do here besides the 6257e102996Smaya # len(val.sources) check. 6267e102996Smaya if val.opcode in conv_opcode_types: 6277e102996Smaya assert len(val.sources) == 1, \ 6287e102996Smaya "Expression {} has {} sources, expected 1".format( 6297e102996Smaya val, len(val.sources)) 6307e102996Smaya self.validate_value(val.sources[0]) 6317e102996Smaya return 6327e102996Smaya 6337e102996Smaya nir_op = opcodes[val.opcode] 6347e102996Smaya assert len(val.sources) == nir_op.num_inputs, \ 6357e102996Smaya "Expression {} has {} sources, expected {}".format( 6367e102996Smaya val, len(val.sources), nir_op.num_inputs) 6377e102996Smaya 6387e102996Smaya for src in val.sources: 6397e102996Smaya self.validate_value(src) 6407e102996Smaya 6417e102996Smaya dst_type_bits = type_bits(nir_op.output_type) 6427e102996Smaya 6437e102996Smaya # First, unify all the sources. That way, an error coming up because two 6447e102996Smaya # sources have an incompatible bit-size won't produce an error message 6457e102996Smaya # involving the destination. 6467e102996Smaya first_unsized_src = None 6477e102996Smaya for src_type, src in zip(nir_op.input_types, val.sources): 6487e102996Smaya src_type_bits = type_bits(src_type) 6497e102996Smaya if src_type_bits == 0: 6507e102996Smaya if first_unsized_src is None: 6517e102996Smaya first_unsized_src = src 65201e04c3fSmrg continue 65301e04c3fSmrg 6547e102996Smaya if self.is_search: 6557e102996Smaya self.unify_bit_size(first_unsized_src, src, 6567e102996Smaya lambda first_unsized_src_bit_size, src_bit_size: 6577e102996Smaya 'Source {} of {} must have bit size {}, while source {} ' \ 6587e102996Smaya 'must have incompatible bit size {}'.format( 6597e102996Smaya first_unsized_src, val, first_unsized_src_bit_size, 6607e102996Smaya src, src_bit_size)) 66101e04c3fSmrg else: 6627e102996Smaya self.unify_bit_size(first_unsized_src, src, 6637e102996Smaya lambda first_unsized_src_bit_size, src_bit_size: 6647e102996Smaya 'Sources {} (bit size of {}) and {} (bit size of {}) ' \ 6657e102996Smaya 'of {} may not have the same bit size when building the ' \ 6667e102996Smaya 'replacement expression.'.format( 6677e102996Smaya first_unsized_src, first_unsized_src_bit_size, src, 6687e102996Smaya src_bit_size, val)) 66901e04c3fSmrg else: 6707e102996Smaya if self.is_search: 6717e102996Smaya self.unify_bit_size(src, src_type_bits, 6727e102996Smaya lambda src_bit_size, unused: 6737e102996Smaya '{} must have {} bits, but as a source of nir_op_{} '\ 6747e102996Smaya 'it must have {} bits'.format( 6757e102996Smaya src, src_bit_size, nir_op.name, src_type_bits)) 6767e102996Smaya else: 6777e102996Smaya self.unify_bit_size(src, src_type_bits, 6787e102996Smaya lambda src_bit_size, unused: 6797e102996Smaya '{} has the bit size of {}, but as a source of ' \ 6807e102996Smaya 'nir_op_{} it must have {} bits, which may not be the ' \ 6817e102996Smaya 'same'.format( 6827e102996Smaya src, src_bit_size, nir_op.name, src_type_bits)) 6837e102996Smaya 6847e102996Smaya if dst_type_bits == 0: 6857e102996Smaya if first_unsized_src is not None: 6867e102996Smaya if self.is_search: 6877e102996Smaya self.unify_bit_size(val, first_unsized_src, 6887e102996Smaya lambda val_bit_size, src_bit_size: 6897e102996Smaya '{} must have the bit size of {}, while its source {} ' \ 6907e102996Smaya 'must have incompatible bit size {}'.format( 6917e102996Smaya val, val_bit_size, first_unsized_src, src_bit_size)) 69201e04c3fSmrg else: 6937e102996Smaya self.unify_bit_size(val, first_unsized_src, 6947e102996Smaya lambda val_bit_size, src_bit_size: 6957e102996Smaya '{} must have {} bits, but its source {} ' \ 6967e102996Smaya '(bit size of {}) may not have that bit size ' \ 6977e102996Smaya 'when building the replacement.'.format( 6987e102996Smaya val, val_bit_size, first_unsized_src, src_bit_size)) 6997e102996Smaya else: 7007e102996Smaya self.unify_bit_size(val, dst_type_bits, 7017e102996Smaya lambda dst_bit_size, unused: 7027e102996Smaya '{} must have {} bits, but as a destination of nir_op_{} ' \ 7037e102996Smaya 'it must have {} bits'.format( 7047e102996Smaya val, dst_bit_size, nir_op.name, dst_type_bits)) 7057e102996Smaya 7067e102996Smaya def validate_replace(self, val, search): 7077e102996Smaya bit_size = val.get_bit_size() 7087e102996Smaya assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \ 7097e102996Smaya bit_size == search.get_bit_size(), \ 7107e102996Smaya 'Ambiguous bit size for replacement value {}: ' \ 7117e102996Smaya 'it cannot be deduced from a variable, a fixed bit size ' \ 7127e102996Smaya 'somewhere, or the search expression.'.format(val) 7137e102996Smaya 7147e102996Smaya if isinstance(val, Expression): 7157e102996Smaya for src in val.sources: 7167e102996Smaya self.validate_replace(src, search) 7177e102996Smaya 7187e102996Smaya def validate(self, search, replace): 7197e102996Smaya self.is_search = True 7207e102996Smaya self.merge_variables(search) 7217e102996Smaya self.merge_variables(replace) 7227e102996Smaya self.validate_value(search) 72301e04c3fSmrg 7247e102996Smaya self.is_search = False 7257e102996Smaya self.validate_value(replace) 72601e04c3fSmrg 7277e102996Smaya # Check that search is always more specialized than replace. Note that 7287e102996Smaya # we're doing this in replace mode, disallowing merging variables. 7297e102996Smaya search_bit_size = search.get_bit_size() 7307e102996Smaya replace_bit_size = replace.get_bit_size() 7317e102996Smaya cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size) 73201e04c3fSmrg 7337e102996Smaya assert cmp_result is not None and cmp_result <= 0, \ 7347e102996Smaya 'The search expression bit size {} and replace expression ' \ 7357e102996Smaya 'bit size {} may not be the same'.format( 7367e102996Smaya search_bit_size, replace_bit_size) 73701e04c3fSmrg 7387e102996Smaya replace.set_bit_size(search) 73901e04c3fSmrg 7407e102996Smaya self.validate_replace(replace, search) 74101e04c3fSmrg 74201e04c3fSmrg_optimization_ids = itertools.count() 74301e04c3fSmrg 74401e04c3fSmrgcondition_list = ['true'] 74501e04c3fSmrg 74601e04c3fSmrgclass SearchAndReplace(object): 74701e04c3fSmrg def __init__(self, transform): 74801e04c3fSmrg self.id = next(_optimization_ids) 74901e04c3fSmrg 75001e04c3fSmrg search = transform[0] 75101e04c3fSmrg replace = transform[1] 75201e04c3fSmrg if len(transform) > 2: 75301e04c3fSmrg self.condition = transform[2] 75401e04c3fSmrg else: 75501e04c3fSmrg self.condition = 'true' 75601e04c3fSmrg 75701e04c3fSmrg if self.condition not in condition_list: 75801e04c3fSmrg condition_list.append(self.condition) 75901e04c3fSmrg self.condition_index = condition_list.index(self.condition) 76001e04c3fSmrg 76101e04c3fSmrg varset = VarSet() 76201e04c3fSmrg if isinstance(search, Expression): 76301e04c3fSmrg self.search = search 76401e04c3fSmrg else: 76501e04c3fSmrg self.search = Expression(search, "search{0}".format(self.id), varset) 76601e04c3fSmrg 76701e04c3fSmrg varset.lock() 76801e04c3fSmrg 76901e04c3fSmrg if isinstance(replace, Value): 77001e04c3fSmrg self.replace = replace 77101e04c3fSmrg else: 77201e04c3fSmrg self.replace = Value.create(replace, "replace{0}".format(self.id), varset) 77301e04c3fSmrg 77401e04c3fSmrg BitSizeValidator(varset).validate(self.search, self.replace) 77501e04c3fSmrg 7767e102996Smayaclass TreeAutomaton(object): 7777e102996Smaya """This class calculates a bottom-up tree automaton to quickly search for 7787e102996Smaya the left-hand sides of tranforms. Tree automatons are a generalization of 7797e102996Smaya classical NFA's and DFA's, where the transition function determines the 7807e102996Smaya state of the parent node based on the state of its children. We construct a 7817e102996Smaya deterministic automaton to match patterns, using a similar algorithm to the 7827e102996Smaya classical NFA to DFA construction. At the moment, it only matches opcodes 7837e102996Smaya and constants (without checking the actual value), leaving more detailed 7847e102996Smaya checking to the search function which actually checks the leaves. The 7857e102996Smaya automaton acts as a quick filter for the search function, requiring only n 7867e102996Smaya + 1 table lookups for each n-source operation. The implementation is based 7877e102996Smaya on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit." 7887e102996Smaya In the language of that reference, this is a frontier-to-root deterministic 7897e102996Smaya automaton using only symbol filtering. The filtering is crucial to reduce 7907e102996Smaya both the time taken to generate the tables and the size of the tables. 7917e102996Smaya """ 7927e102996Smaya def __init__(self, transforms): 7937e102996Smaya self.patterns = [t.search for t in transforms] 7947e102996Smaya self._compute_items() 7957e102996Smaya self._build_table() 7967e102996Smaya #print('num items: {}'.format(len(set(self.items.values())))) 7977e102996Smaya #print('num states: {}'.format(len(self.states))) 7987e102996Smaya #for state, patterns in zip(self.states, self.patterns): 7997e102996Smaya # print('{}: num patterns: {}'.format(state, len(patterns))) 8007e102996Smaya 8017e102996Smaya class IndexMap(object): 8027e102996Smaya """An indexed list of objects, where one can either lookup an object by 8037e102996Smaya index or find the index associated to an object quickly using a hash 8047e102996Smaya table. Compared to a list, it has a constant time index(). Compared to a 8057e102996Smaya set, it provides a stable iteration order. 8067e102996Smaya """ 8077e102996Smaya def __init__(self, iterable=()): 8087e102996Smaya self.objects = [] 8097e102996Smaya self.map = {} 8107e102996Smaya for obj in iterable: 8117e102996Smaya self.add(obj) 8127e102996Smaya 8137e102996Smaya def __getitem__(self, i): 8147e102996Smaya return self.objects[i] 8157e102996Smaya 8167e102996Smaya def __contains__(self, obj): 8177e102996Smaya return obj in self.map 8187e102996Smaya 8197e102996Smaya def __len__(self): 8207e102996Smaya return len(self.objects) 8217e102996Smaya 8227e102996Smaya def __iter__(self): 8237e102996Smaya return iter(self.objects) 8247e102996Smaya 8257e102996Smaya def clear(self): 8267e102996Smaya self.objects = [] 8277e102996Smaya self.map.clear() 8287e102996Smaya 8297e102996Smaya def index(self, obj): 8307e102996Smaya return self.map[obj] 8317e102996Smaya 8327e102996Smaya def add(self, obj): 8337e102996Smaya if obj in self.map: 8347e102996Smaya return self.map[obj] 8357e102996Smaya else: 8367e102996Smaya index = len(self.objects) 8377e102996Smaya self.objects.append(obj) 8387e102996Smaya self.map[obj] = index 8397e102996Smaya return index 8407e102996Smaya 8417e102996Smaya def __repr__(self): 8427e102996Smaya return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])' 8437e102996Smaya 8447e102996Smaya class Item(object): 8457e102996Smaya """This represents an "item" in the language of "Tree Automatons." This 8467e102996Smaya is just a subtree of some pattern, which represents a potential partial 8477e102996Smaya match at runtime. We deduplicate them, so that identical subtrees of 8487e102996Smaya different patterns share the same object, and store some extra 8497e102996Smaya information needed for the main algorithm as well. 8507e102996Smaya """ 8517e102996Smaya def __init__(self, opcode, children): 8527e102996Smaya self.opcode = opcode 8537e102996Smaya self.children = children 8547e102996Smaya # These are the indices of patterns for which this item is the root node. 8557e102996Smaya self.patterns = [] 8567e102996Smaya # This the set of opcodes for parents of this item. Used to speed up 8577e102996Smaya # filtering. 8587e102996Smaya self.parent_ops = set() 8597e102996Smaya 8607e102996Smaya def __str__(self): 8617e102996Smaya return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')' 8627e102996Smaya 8637e102996Smaya def __repr__(self): 8647e102996Smaya return str(self) 8657e102996Smaya 8667e102996Smaya def _compute_items(self): 8677e102996Smaya """Build a set of all possible items, deduplicating them.""" 8687e102996Smaya # This is a map from (opcode, sources) to item. 8697e102996Smaya self.items = {} 8707e102996Smaya 8717e102996Smaya # The set of all opcodes used by the patterns. Used later to avoid 8727e102996Smaya # building and emitting all the tables for opcodes that aren't used. 8737e102996Smaya self.opcodes = self.IndexMap() 8747e102996Smaya 8757e102996Smaya def get_item(opcode, children, pattern=None): 8767ec681f3Smrg commutative = len(children) >= 2 \ 8777ec681f3Smrg and "2src_commutative" in opcodes[opcode].algebraic_properties 8787e102996Smaya item = self.items.setdefault((opcode, children), 8797e102996Smaya self.Item(opcode, children)) 8807e102996Smaya if commutative: 8817ec681f3Smrg self.items[opcode, (children[1], children[0]) + children[2:]] = item 8827e102996Smaya if pattern is not None: 8837e102996Smaya item.patterns.append(pattern) 8847e102996Smaya return item 8857e102996Smaya 8867e102996Smaya self.wildcard = get_item("__wildcard", ()) 8877e102996Smaya self.const = get_item("__const", ()) 8887e102996Smaya 8897e102996Smaya def process_subpattern(src, pattern=None): 8907e102996Smaya if isinstance(src, Constant): 8917e102996Smaya # Note: we throw away the actual constant value! 8927e102996Smaya return self.const 8937e102996Smaya elif isinstance(src, Variable): 8947e102996Smaya if src.is_constant: 8957e102996Smaya return self.const 8967e102996Smaya else: 8977e102996Smaya # Note: we throw away which variable it is here! This special 8987e102996Smaya # item is equivalent to nu in "Tree Automatons." 8997e102996Smaya return self.wildcard 9007e102996Smaya else: 9017e102996Smaya assert isinstance(src, Expression) 9027e102996Smaya opcode = src.opcode 9037e102996Smaya stripped = opcode.rstrip('0123456789') 9047e102996Smaya if stripped in conv_opcode_types: 9057e102996Smaya # Matches that use conversion opcodes with a specific type, 9067e102996Smaya # like f2b1, are tricky. Either we construct the automaton to 9077e102996Smaya # match specific NIR opcodes like nir_op_f2b1, in which case we 9087e102996Smaya # need to create separate items for each possible NIR opcode 9097e102996Smaya # for patterns that have a generic opcode like f2b, or we 9107e102996Smaya # construct it to match the search opcode, in which case we 9117e102996Smaya # need to map f2b1 to f2b when constructing the automaton. Here 9127e102996Smaya # we do the latter. 9137e102996Smaya opcode = stripped 9147e102996Smaya self.opcodes.add(opcode) 9157e102996Smaya children = tuple(process_subpattern(c) for c in src.sources) 9167e102996Smaya item = get_item(opcode, children, pattern) 9177e102996Smaya for i, child in enumerate(children): 9187e102996Smaya child.parent_ops.add(opcode) 9197e102996Smaya return item 9207e102996Smaya 9217e102996Smaya for i, pattern in enumerate(self.patterns): 9227e102996Smaya process_subpattern(pattern, i) 9237e102996Smaya 9247e102996Smaya def _build_table(self): 9257e102996Smaya """This is the core algorithm which builds up the transition table. It 9267e102996Smaya is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl . 9277e102996Smaya Comp_a and Filt_{a,i} using integers to identify match sets." It 9287e102996Smaya simultaneously builds up a list of all possible "match sets" or 9297e102996Smaya "states", where each match set represents the set of Item's that match a 9307e102996Smaya given instruction, and builds up the transition table between states. 9317e102996Smaya """ 9327e102996Smaya # Map from opcode + filtered state indices to transitioned state. 9337e102996Smaya self.table = defaultdict(dict) 9347e102996Smaya # Bijection from state to index. q in the original algorithm is 9357e102996Smaya # len(self.states) 9367e102996Smaya self.states = self.IndexMap() 9377e102996Smaya # List of pattern matches for each state index. 9387e102996Smaya self.state_patterns = [] 9397e102996Smaya # Map from state index to filtered state index for each opcode. 9407e102996Smaya self.filter = defaultdict(list) 9417e102996Smaya # Bijections from filtered state to filtered state index for each 9427e102996Smaya # opcode, called the "representor sets" in the original algorithm. 9437e102996Smaya # q_{a,j} in the original algorithm is len(self.rep[op]). 9447e102996Smaya self.rep = defaultdict(self.IndexMap) 9457e102996Smaya 9467e102996Smaya # Everything in self.states with a index at least worklist_index is part 9477e102996Smaya # of the worklist of newly created states. There is also a worklist of 9487e102996Smaya # newly fitered states for each opcode, for which worklist_indices 9497e102996Smaya # serves a similar purpose. worklist_index corresponds to p in the 9507e102996Smaya # original algorithm, while worklist_indices is p_{a,j} (although since 9517e102996Smaya # we only filter by opcode/symbol, it's really just p_a). 9527e102996Smaya self.worklist_index = 0 9537e102996Smaya worklist_indices = defaultdict(lambda: 0) 9547e102996Smaya 9557e102996Smaya # This is the set of opcodes for which the filtered worklist is non-empty. 9567e102996Smaya # It's used to avoid scanning opcodes for which there is nothing to 9577e102996Smaya # process when building the transition table. It corresponds to new_a in 9587e102996Smaya # the original algorithm. 9597e102996Smaya new_opcodes = self.IndexMap() 9607e102996Smaya 9617e102996Smaya # Process states on the global worklist, filtering them for each opcode, 9627e102996Smaya # updating the filter tables, and updating the filtered worklists if any 9637e102996Smaya # new filtered states are found. Similar to ComputeRepresenterSets() in 9647e102996Smaya # the original algorithm, although that only processes a single state. 9657e102996Smaya def process_new_states(): 9667e102996Smaya while self.worklist_index < len(self.states): 9677e102996Smaya state = self.states[self.worklist_index] 9687e102996Smaya 9697e102996Smaya # Calculate pattern matches for this state. Each pattern is 9707e102996Smaya # assigned to a unique item, so we don't have to worry about 9717e102996Smaya # deduplicating them here. However, we do have to sort them so 9727e102996Smaya # that they're visited at runtime in the order they're specified 9737e102996Smaya # in the source. 9747e102996Smaya patterns = list(sorted(p for item in state for p in item.patterns)) 9757e102996Smaya assert len(self.state_patterns) == self.worklist_index 9767e102996Smaya self.state_patterns.append(patterns) 9777e102996Smaya 9787e102996Smaya # calculate filter table for this state, and update filtered 9797e102996Smaya # worklists. 9807e102996Smaya for op in self.opcodes: 9817e102996Smaya filt = self.filter[op] 9827e102996Smaya rep = self.rep[op] 9837e102996Smaya filtered = frozenset(item for item in state if \ 9847e102996Smaya op in item.parent_ops) 9857e102996Smaya if filtered in rep: 9867e102996Smaya rep_index = rep.index(filtered) 9877e102996Smaya else: 9887e102996Smaya rep_index = rep.add(filtered) 9897e102996Smaya new_opcodes.add(op) 9907e102996Smaya assert len(filt) == self.worklist_index 9917e102996Smaya filt.append(rep_index) 9927e102996Smaya self.worklist_index += 1 9937e102996Smaya 9947e102996Smaya # There are two start states: one which can only match as a wildcard, 9957e102996Smaya # and one which can match as a wildcard or constant. These will be the 9967e102996Smaya # states of intrinsics/other instructions and load_const instructions, 9977e102996Smaya # respectively. The indices of these must match the definitions of 9987e102996Smaya # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can 9997e102996Smaya # initialize things correctly. 10007e102996Smaya self.states.add(frozenset((self.wildcard,))) 10017e102996Smaya self.states.add(frozenset((self.const,self.wildcard))) 10027e102996Smaya process_new_states() 10037e102996Smaya 10047e102996Smaya while len(new_opcodes) > 0: 10057e102996Smaya for op in new_opcodes: 10067e102996Smaya rep = self.rep[op] 10077e102996Smaya table = self.table[op] 10087e102996Smaya op_worklist_index = worklist_indices[op] 10097e102996Smaya if op in conv_opcode_types: 10107e102996Smaya num_srcs = 1 10117e102996Smaya else: 10127e102996Smaya num_srcs = opcodes[op].num_inputs 10137e102996Smaya 10147e102996Smaya # Iterate over all possible source combinations where at least one 10157e102996Smaya # is on the worklist. 10167e102996Smaya for src_indices in itertools.product(range(len(rep)), repeat=num_srcs): 10177e102996Smaya if all(src_idx < op_worklist_index for src_idx in src_indices): 10187e102996Smaya continue 10197e102996Smaya 10207e102996Smaya srcs = tuple(rep[src_idx] for src_idx in src_indices) 10217e102996Smaya 10227e102996Smaya # Try all possible pairings of source items and add the 10237e102996Smaya # corresponding parent items. This is Comp_a from the paper. 10247e102996Smaya parent = set(self.items[op, item_srcs] for item_srcs in 10257e102996Smaya itertools.product(*srcs) if (op, item_srcs) in self.items) 10267e102996Smaya 10277e102996Smaya # We could always start matching something else with a 10287e102996Smaya # wildcard. This is Cl from the paper. 10297e102996Smaya parent.add(self.wildcard) 10307e102996Smaya 10317e102996Smaya table[src_indices] = self.states.add(frozenset(parent)) 10327e102996Smaya worklist_indices[op] = len(rep) 10337e102996Smaya new_opcodes.clear() 10347e102996Smaya process_new_states() 10357e102996Smaya 103601e04c3fSmrg_algebraic_pass_template = mako.template.Template(""" 103701e04c3fSmrg#include "nir.h" 103801e04c3fSmrg#include "nir_builder.h" 103901e04c3fSmrg#include "nir_search.h" 104001e04c3fSmrg#include "nir_search_helpers.h" 104101e04c3fSmrg 10427ec681f3Smrg/* What follows is NIR algebraic transform code for the following ${len(xforms)} 10437ec681f3Smrg * transforms: 10447ec681f3Smrg% for xform in xforms: 10457ec681f3Smrg * ${xform.search} => ${xform.replace} 10467ec681f3Smrg% endfor 10477e102996Smaya */ 10487e102996Smaya 10497e102996Smaya<% cache = {} %> 10507e102996Smaya% for xform in xforms: 10517e102996Smaya ${xform.search.render(cache)} 10527e102996Smaya ${xform.replace.render(cache)} 105301e04c3fSmrg% endfor 105401e04c3fSmrg 10557e102996Smaya% for state_id, state_xforms in enumerate(automaton.state_patterns): 10567e102996Smaya% if state_xforms: # avoid emitting a 0-length array for MSVC 10577e102996Smayastatic const struct transform ${pass_name}_state${state_id}_xforms[] = { 10587e102996Smaya% for i in state_xforms: 10597e102996Smaya { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} }, 106001e04c3fSmrg% endfor 106101e04c3fSmrg}; 10627e102996Smaya% endif 106301e04c3fSmrg% endfor 106401e04c3fSmrg 10657e102996Smayastatic const struct per_op_table ${pass_name}_table[nir_num_search_ops] = { 10667e102996Smaya% for op in automaton.opcodes: 10677e102996Smaya [${get_c_opcode(op)}] = { 10687e102996Smaya .filter = (uint16_t []) { 10697e102996Smaya % for e in automaton.filter[op]: 10707e102996Smaya ${e}, 10717e102996Smaya % endfor 10727e102996Smaya }, 10737e102996Smaya <% 10747e102996Smaya num_filtered = len(automaton.rep[op]) 10757e102996Smaya %> 10767e102996Smaya .num_filtered_states = ${num_filtered}, 10777e102996Smaya .table = (uint16_t []) { 10787e102996Smaya <% 10797e102996Smaya num_srcs = len(next(iter(automaton.table[op]))) 10807e102996Smaya %> 10817e102996Smaya % for indices in itertools.product(range(num_filtered), repeat=num_srcs): 10827e102996Smaya ${automaton.table[op][indices]}, 10837e102996Smaya % endfor 10847e102996Smaya }, 10857e102996Smaya }, 10867e102996Smaya% endfor 10877e102996Smaya}; 10887e102996Smaya 10897ec681f3Smrgconst struct transform *${pass_name}_transforms[] = { 10907e102996Smaya% for i in range(len(automaton.state_patterns)): 10917ec681f3Smrg % if automaton.state_patterns[i]: 10927ec681f3Smrg ${pass_name}_state${i}_xforms, 10937ec681f3Smrg % else: 10947ec681f3Smrg NULL, 10957ec681f3Smrg % endif 10967e102996Smaya% endfor 10977ec681f3Smrg}; 109801e04c3fSmrg 10997ec681f3Smrgconst uint16_t ${pass_name}_transform_counts[] = { 11007ec681f3Smrg% for i in range(len(automaton.state_patterns)): 11017ec681f3Smrg % if automaton.state_patterns[i]: 11027ec681f3Smrg (uint16_t)ARRAY_SIZE(${pass_name}_state${i}_xforms), 11037ec681f3Smrg % else: 11047ec681f3Smrg 0, 11057ec681f3Smrg % endif 11067ec681f3Smrg% endfor 11077ec681f3Smrg}; 110801e04c3fSmrg 110901e04c3fSmrgbool 111001e04c3fSmrg${pass_name}(nir_shader *shader) 111101e04c3fSmrg{ 111201e04c3fSmrg bool progress = false; 111301e04c3fSmrg bool condition_flags[${len(condition_list)}]; 111401e04c3fSmrg const nir_shader_compiler_options *options = shader->options; 11157e102996Smaya const shader_info *info = &shader->info; 111601e04c3fSmrg (void) options; 11177e102996Smaya (void) info; 111801e04c3fSmrg 111901e04c3fSmrg % for index, condition in enumerate(condition_list): 112001e04c3fSmrg condition_flags[${index}] = ${condition}; 112101e04c3fSmrg % endfor 112201e04c3fSmrg 112301e04c3fSmrg nir_foreach_function(function, shader) { 11247ec681f3Smrg if (function->impl) { 11257ec681f3Smrg progress |= nir_algebraic_impl(function->impl, condition_flags, 11267ec681f3Smrg ${pass_name}_transforms, 11277ec681f3Smrg ${pass_name}_transform_counts, 11287ec681f3Smrg ${pass_name}_table); 11297ec681f3Smrg } 113001e04c3fSmrg } 113101e04c3fSmrg 113201e04c3fSmrg return progress; 113301e04c3fSmrg} 113401e04c3fSmrg""") 113501e04c3fSmrg 11367e102996Smaya 113701e04c3fSmrgclass AlgebraicPass(object): 113801e04c3fSmrg def __init__(self, pass_name, transforms): 11397e102996Smaya self.xforms = [] 11407e102996Smaya self.opcode_xforms = defaultdict(lambda : []) 114101e04c3fSmrg self.pass_name = pass_name 114201e04c3fSmrg 114301e04c3fSmrg error = False 114401e04c3fSmrg 114501e04c3fSmrg for xform in transforms: 114601e04c3fSmrg if not isinstance(xform, SearchAndReplace): 114701e04c3fSmrg try: 114801e04c3fSmrg xform = SearchAndReplace(xform) 114901e04c3fSmrg except: 115001e04c3fSmrg print("Failed to parse transformation:", file=sys.stderr) 115101e04c3fSmrg print(" " + str(xform), file=sys.stderr) 115201e04c3fSmrg traceback.print_exc(file=sys.stderr) 115301e04c3fSmrg print('', file=sys.stderr) 115401e04c3fSmrg error = True 115501e04c3fSmrg continue 115601e04c3fSmrg 11577e102996Smaya self.xforms.append(xform) 11587e102996Smaya if xform.search.opcode in conv_opcode_types: 11597e102996Smaya dst_type = conv_opcode_types[xform.search.opcode] 11607e102996Smaya for size in type_sizes(dst_type): 11617e102996Smaya sized_opcode = xform.search.opcode + str(size) 11627e102996Smaya self.opcode_xforms[sized_opcode].append(xform) 11637e102996Smaya else: 11647e102996Smaya self.opcode_xforms[xform.search.opcode].append(xform) 116501e04c3fSmrg 11667ec681f3Smrg # Check to make sure the search pattern does not unexpectedly contain 11677ec681f3Smrg # more commutative expressions than match_expression (nir_search.c) 11687ec681f3Smrg # can handle. 11697ec681f3Smrg comm_exprs = xform.search.comm_exprs 11707ec681f3Smrg 11717ec681f3Smrg if xform.search.many_commutative_expressions: 11727ec681f3Smrg if comm_exprs <= nir_search_max_comm_ops: 11737ec681f3Smrg print("Transform expected to have too many commutative " \ 11747ec681f3Smrg "expression but did not " \ 11757ec681f3Smrg "({} <= {}).".format(comm_exprs, nir_search_max_comm_op), 11767ec681f3Smrg file=sys.stderr) 11777ec681f3Smrg print(" " + str(xform), file=sys.stderr) 11787ec681f3Smrg traceback.print_exc(file=sys.stderr) 11797ec681f3Smrg print('', file=sys.stderr) 11807ec681f3Smrg error = True 11817ec681f3Smrg else: 11827ec681f3Smrg if comm_exprs > nir_search_max_comm_ops: 11837ec681f3Smrg print("Transformation with too many commutative expressions " \ 11847ec681f3Smrg "({} > {}). Modify pattern or annotate with " \ 11857ec681f3Smrg "\"many-comm-expr\".".format(comm_exprs, 11867ec681f3Smrg nir_search_max_comm_ops), 11877ec681f3Smrg file=sys.stderr) 11887ec681f3Smrg print(" " + str(xform.search), file=sys.stderr) 11897ec681f3Smrg print("{}".format(xform.search.cond), file=sys.stderr) 11907ec681f3Smrg error = True 11917ec681f3Smrg 11927e102996Smaya self.automaton = TreeAutomaton(self.xforms) 119301e04c3fSmrg 119401e04c3fSmrg if error: 119501e04c3fSmrg sys.exit(1) 119601e04c3fSmrg 11977e102996Smaya 119801e04c3fSmrg def render(self): 119901e04c3fSmrg return _algebraic_pass_template.render(pass_name=self.pass_name, 12007e102996Smaya xforms=self.xforms, 12017e102996Smaya opcode_xforms=self.opcode_xforms, 12027e102996Smaya condition_list=condition_list, 12037e102996Smaya automaton=self.automaton, 12047e102996Smaya get_c_opcode=get_c_opcode, 12057e102996Smaya itertools=itertools) 1206