#
# Copyright (C) 2014 Intel Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice (including the next
# paragraph) shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# Authors:
#    Jason Ekstrand (jason@jlekstrand.net)

from __future__ import print_function
import ast
from collections import defaultdict
import itertools
import struct
import sys
import mako.template
import re
import traceback

from nir_opcodes import opcodes, type_sizes

# These opcodes are only employed by nir_search.
# This provides a mapping from opcode to destination type.
conv_opcode_types = {
    'i2f': 'float',
    'u2f': 'float',
    'f2f': 'float',
    'f2u': 'uint',
    'f2i': 'int',
    'u2u': 'uint',
    'i2i': 'int',
    'b2f': 'float',
    'b2i': 'int',
    'i2b': 'bool',
    'f2b': 'bool',
}


def get_c_opcode(op):
    """Return the C enumerant name for the opcode string ``op``.

    Opcodes listed in ``conv_opcode_types`` map to ``nir_search_op_<op>``;
    every other opcode maps to ``nir_op_<op>``.
    """
    if op in conv_opcode_types:
        return 'nir_search_op_' + op
    else:
        return 'nir_op_' + op


# Python 2 / Python 3 compatibility aliases used by the isinstance() checks
# below.
if sys.version_info < (3, 0):
    integer_types = (int, long)
    string_type = unicode

else:
    integer_types = (int, )
    string_type = str

_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")


def type_bits(type_str):
    """Return the bit size embedded in ``type_str`` (e.g. 32 for 'float32').

    Returns 0 when the type string carries no size suffix.  The string must
    start with one of the base types matched by ``_type_re``.
    """
    m = _type_re.match(type_str)
    assert m.group('type')

    if m.group('bits') is None:
        return 0
    else:
        return int(m.group('bits'))


# Represents a set of variables, each with a unique id
class VarSet(object):
    def __init__(self):
        self.names = {}
        self.ids = itertools.count()
        self.immutable = False

    def __getitem__(self, name):
        """Return the id of ``name``, allocating a new id on first use.

        Once lock() has been called, looking up an unknown name fails the
        assertion instead of allocating.
        """
        if name not in self.names:
            assert not self.immutable, "Unknown replacement variable: " + name
            self.names[name] = next(self.ids)

        return self.names[name]

    def lock(self):
        """Forbid the allocation of any further variable ids."""
        self.immutable = True


class Value(object):
    """Common base of Constant, Variable and Expression.

    Subclasses are responsible for setting ``self._bit_size`` (an int, another
    Value, or None); the union-find helpers below rely on it.
    """

    @staticmethod
    def create(val, name_base, varset):
        """Build the Value subclass that corresponds to the Python object
        ``val``.

        - tuple                -> Expression
        - Expression instance  -> returned unchanged
        - string (or bytes)    -> Variable
        - bool / float / int   -> Constant

        Raises TypeError for anything else, rather than silently returning
        None and failing later in an unrelated place.
        """
        if isinstance(val, bytes):
            val = val.decode('utf-8')

        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        elif isinstance(val, Expression):
            return val
        elif isinstance(val, string_type):
            return Variable(val, name_base, varset)
        elif isinstance(val, (bool, float) + integer_types):
            return Constant(val, name_base)

        raise TypeError(
            "Cannot create a search value from {!r} (type {})".format(
                val, type(val).__name__))

    def __init__(self, val, name, type_str):
        self.in_val = str(val)
        self.name = name
        self.type_str = type_str

    def __str__(self):
        return self.in_val

    def get_bit_size(self):
        """Get the physical bit-size that has been chosen for this value, or if
        there is none, the canonical value which currently represents this
        bit-size class. Variables will be preferred, i.e. if there are any
        variables in the equivalence class, the canonical value will be a
        variable. We do this since we'll need to know which variable each value
        is equivalent to when constructing the replacement expression. This is
        the "find" part of the union-find algorithm.
        """
        bit_size = self

        # Follow the chain of _bit_size links until we hit either a concrete
        # int or a Value with no link (the representative of its class).
        while isinstance(bit_size, Value):
            if bit_size._bit_size is None:
                break
            bit_size = bit_size._bit_size

        # Remember the result so later lookups start from it directly.
        if bit_size is not self:
            self._bit_size = bit_size
        return bit_size

    def set_bit_size(self, other):
        """Make self.get_bit_size() return what other.get_bit_size() return
        before calling this, or just "other" if it's a concrete bit-size. This is
        the "union" part of the union-find algorithm.
        """

        self_bit_size = self.get_bit_size()
        other_bit_size = other if isinstance(other, int) else other.get_bit_size()

        if self_bit_size == other_bit_size:
            return

        self_bit_size._bit_size = other_bit_size

    @property
    def type_enum(self):
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        return "nir_search_" + self.type_str

    def __c_name(self, cache):
        """Return the C identifier for this value, honouring any remap that
        render() registered in ``cache``."""
        if cache is not None and self.name in cache:
            return cache[self.name]
        else:
            return self.name

    def c_value_ptr(self, cache):
        return "&{0}.value".format(self.__c_name(cache))

    def c_ptr(self, cache):
        return "&{0}".format(self.__c_name(cache))

    @property
    def c_bit_size(self):
        """Bit size as emitted into the generated C initializer.

        A positive int is a concrete size, ``-index - 1`` refers to the
        variable with that index, and 0 covers the remaining cases described
        in the comment below.
        """
        bit_size = self.get_bit_size()
        if isinstance(bit_size, int):
            return bit_size
        elif isinstance(bit_size, Variable):
            return -bit_size.index - 1
        else:
            # If the bit-size class is neither a variable, nor an actual bit-size, then
            # - If it's in the search expression, we don't need to check anything
            # - If it's in the replace expression, either it's ambiguous (in which
            # case we'd reject it), or it equals the bit-size of the search value
            # We represent these cases with a 0 bit-size.
            return 0

    __template = mako.template.Template("""{
   { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   ${val.comm_expr_idx}, ${val.comm_exprs},
   ${val.c_opcode()},
   { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

    def render(self, cache):
        """Return the C source text declaring this value.

        ``cache`` (a dict, or None to disable caching) maps rendered
        initializer text to the name of the first value that produced it, and
        value names to the name they were remapped to.  When an identical
        initializer was already emitted only a comment is returned.
        """
        struct_init = self.__template.render(val=self, cache=cache,
                                             Constant=Constant,
                                             Variable=Variable,
                                             Expression=Expression)
        if cache is not None and struct_init in cache:
            # If it's in the cache, register a name remap in the cache and render
            # only a comment saying it's been remapped
            cache[self.name] = cache[struct_init]
            return "/* {} -> {} in the cache */\n".format(self.name,
                                                          cache[struct_init])
        else:
            if cache is not None:
                cache[struct_init] = self.name
            return "static const {} {} = {}\n".format(self.c_type, self.name,
                                                      struct_init)


_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")


class Constant(Value):
    """A literal constant, optionally carrying an explicit ``@bits`` size when
    given as a string."""

    def __init__(self, val, name):
        Value.__init__(self, val, name, "constant")

        if isinstance(val, (str)):
            m = _constant_re.match(val)
            self.value = ast.literal_eval(m.group('value'))
            self._bit_size = int(m.group('bits')) if m.group('bits') else None
        else:
            self.value = val
            self._bit_size = None

        # Booleans are always 1-bit.
        if isinstance(self.value, bool):
            assert self._bit_size is None or self._bit_size == 1
            self._bit_size = 1

    def hex(self):
        """Return the C spelling of the constant's value."""
        # bool must be tested first: it is a subclass of int.
        if isinstance(self.value, (bool)):
            return 'NIR_TRUE' if self.value else 'NIR_FALSE'
        if isinstance(self.value, integer_types):
            return hex(self.value)
        elif isinstance(self.value, float):
            # Reinterpret the double's bytes as an unsigned 64-bit integer.
            i = struct.unpack('Q', struct.pack('d', self.value))[0]
            h = hex(i)

            # On Python 2 this 'L' suffix is automatically added, but not on Python 3
            # Adding it explicitly makes the generated file identical, regardless
            # of the Python version running this script.
            if h[-1] != 'L' and i > sys.maxsize:
                h += 'L'

            return h
        else:
            assert False

    def type(self):
        """Return the nir_type_* name matching the Python type of the value
        (None if the value is none of bool, int or float)."""
        if isinstance(self.value, (bool)):
            return "nir_type_bool"
        elif isinstance(self.value, integer_types):
            return "nir_type_int"
        elif isinstance(self.value, float):
            return "nir_type_float"


_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")


class Variable(Value):
    """A named variable parsed from a string matched by ``_var_name_re``:
    optional leading '#', the name, optional '@type/bits' and an optional
    parenthesised condition."""

    def __init__(self, val, name, varset):
        Value.__init__(self, val, name, "variable")

        m = _var_name_re.match(val)
        assert m and m.group('name') is not None

        self.var_name = m.group('name')

        # Prevent common cases where someone puts quotes around a literal
        # constant.  If we want to support names that have numeric or
        # punctuation characters, we can make the first assertion more flexible.
        assert self.var_name.isalpha()
        # Compare by value: identity comparison against a string literal is
        # not a reliable equality test.
        assert self.var_name != 'True'
        assert self.var_name != 'False'

        self.is_constant = m.group('const') is not None
        self.cond = m.group('cond')
        self.required_type = m.group('type')
        self._bit_size = int(m.group('bits')) if m.group('bits') else None

        # A bool with no explicit size defaults to 1 bit.
        if self.required_type == 'bool':
            if self._bit_size is not None:
                assert self._bit_size in type_sizes(self.required_type)
            else:
                self._bit_size = 1

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        self.index = varset[self.var_name]

    def type(self):
        """Return the nir_type_* name for the required type, or None when the
        variable has no required type."""
        if self.required_type == 'bool':
            return "nir_type_bool"
        elif self.required_type in ('int', 'uint'):
            return "nir_type_int"
        elif self.required_type == 'float':
            return "nir_type_float"


_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")


class Expression(Value):
    """An operation applied to sources, built from a tuple whose first element
    is the opcode string (matched by ``_opcode_re``) and whose remaining
    elements are converted with Value.create()."""

    def __init__(self, expr, name_base, varset):
        Value.__init__(self, expr, name_base, "expression")
        assert isinstance(expr, tuple)

        m = _opcode_re.match(expr[0])
        assert m and m.group('opcode') is not None

        self.opcode = m.group('opcode')
        self._bit_size = int(m.group('bits')) if m.group('bits') else None
        self.inexact = m.group('inexact') is not None
        self.cond = m.group('cond')
        self.sources = [Value.create(src, "{0}_{1}".format(name_base, i), varset)
                        for (i, src) in enumerate(expr[1:])]

        if self.opcode in conv_opcode_types:
            assert self._bit_size is None, \
                'Expression cannot use an unsized conversion opcode with ' \
                'an explicit size; that\'s silly.'

        self.__index_comm_exprs(0)

    def __index_comm_exprs(self, base_idx):
        """Recursively count and index commutative expressions
        """
        self.comm_exprs = 0
        if self.opcode not in conv_opcode_types and \
           "commutative" in opcodes[self.opcode].algebraic_properties:
            self.comm_expr_idx = base_idx
            self.comm_exprs += 1
        else:
            self.comm_expr_idx = -1

        for s in self.sources:
            if isinstance(s, Expression):
                s.__index_comm_exprs(base_idx + self.comm_exprs)
                self.comm_exprs += s.comm_exprs

        return self.comm_exprs

    def c_opcode(self):
        return get_c_opcode(self.opcode)

    def render(self, cache):
        """Render every source first, then this expression itself."""
        srcs = "\n".join(src.render(cache) for src in self.sources)
        return srcs + super(Expression, self).render(cache)


class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64.  The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size.  Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value.  NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected by
        everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match.  They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above.  This is similar to type inference in many common programming
    languages.  If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit.  There are, however, cases where this
    inference can be ambiguous or contradictory.  Consider, for instance, the
    following transformation:

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer.  However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value.  As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size.  If we had a 64-bit b value, we would end up
    trying to and a 32-bit value with a 64-bit value which would be invalid.

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code.  This ensures that bugs are caught at compile time
    rather than at run time.

    Each value maintains a "bit-size class", which is either an actual bit size
    or an equivalence class with other values that must have the same bit size.
    The validator works by combining bit-size classes with each other according
    to the NIR rules outlined above, checking that there are no inconsistencies.
    When doing this for the replacement expression, we make sure to never change
    the equivalence class of any of the search values. We could make the example
    transforms above work by doing some extra run-time checking of the search
    expression, but we make the user specify those constraints themselves, to
    avoid any surprises. Since the replacement bitsizes can only be connected to
    the source bitsize via variables (variables must have the same bitsize in
    the source and replacement expressions) or the roots of the expression (the
    replacement expression must produce the same bit size as the search
    expression), we prevent merging a variable with anything when processing the
    replacement expression, or specializing the search bitsize
    with anything. The former prevents

    (('bcsel', a, b, 0), ('iand', a, b))

    from being allowed, since we'd have to merge the bitsizes for a and b due to
    the 'iand', while the latter prevents

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    from being allowed, since the search expression has the bit size of a and b,
    which can't be specialized to 32 which is the bitsize of the replace
    expression. It also prevents something like:

    (('b2i', ('i2b', a)), ('ineq', a, 0))

    since the bitsize of 'b2i', which can be anything, can't be specialized to
    the bitsize of a.

    After doing all this, we check that every subexpression of the replacement
    was assigned a constant bitsize, the bitsize of a variable, or the bitsize
    of the search expression, since those are the things that are known when
    constructing the replacement expression. Finally, we record the bitsize
    needed in nir_search_value so that we know what to do when building the
    replacement expression.
    """

    def __init__(self, varset):
        # One slot per variable id; filled in by merge_variables().
        self._var_classes = [None] * len(varset.names)

    def compare_bitsizes(self, a, b):
        """Determines which bitsize class is a specialization of the other, or
        whether neither is. When we merge two different bitsizes, the
        less-specialized bitsize always points to the more-specialized one, so
        that calling get_bit_size() always gets you the most specialized bitsize.
        The specialization partial order is given by:
        - Physical bitsizes are always the most specialized, and a different
          bitsize can never specialize another.
        - In the search expression, variables can always be specialized to each
          other and to physical bitsizes. In the replace expression, we disallow
          this to avoid adding extra constraints to the search expression that
          the user didn't specify.
        - Expressions and constants without a bitsize can always be specialized to
          each other and variables, but not the other way around.

        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
        and None if they are not comparable (neither a <= b nor b <= a).
        """
        if isinstance(a, int):
            if isinstance(b, int):
                return 0 if a == b else None
            elif isinstance(b, Variable):
                return -1 if self.is_search else None
            else:
                return -1
        elif isinstance(a, Variable):
            if isinstance(b, int):
                return 1 if self.is_search else None
            elif isinstance(b, Variable):
                return 0 if self.is_search or a.index == b.index else None
            else:
                return -1
        else:
            if isinstance(b, int):
                return 1
            elif isinstance(b, Variable):
                return 1
            else:
                return 0

    def unify_bit_size(self, a, b, error_msg):
        """Record that a must have the same bit-size as b. If both
        have been assigned conflicting physical bit-sizes, call "error_msg" with
        the bit-sizes of a and b to get a message and raise an error.
        In the replace expression, disallow merging variables with other
        variables and physical bit-sizes as well.
        """
        a_bit_size = a.get_bit_size()
        b_bit_size = b if isinstance(b, int) else b.get_bit_size()

        cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

        assert cmp_result is not None, \
            error_msg(a_bit_size, b_bit_size)

        # Always point the less-specialized class at the more-specialized one.
        if cmp_result < 0:
            b_bit_size.set_bit_size(a)
        elif not isinstance(a_bit_size, int):
            a_bit_size.set_bit_size(b)

    def merge_variables(self, val):
        """Perform the first part of type inference by merging all the different
        uses of the same variable. We always do this as if we're in the search
        expression, even if we're actually not, since otherwise we'd get errors
        if the search expression specified some constraint but the replace
        expression didn't, because we'd be merging a variable and a constant.
        """
        if isinstance(val, Variable):
            if self._var_classes[val.index] is None:
                self._var_classes[val.index] = val
            else:
                other = self._var_classes[val.index]
                self.unify_bit_size(other, val,
                    lambda other_bit_size, bit_size:
                        'Variable {} has conflicting bit size requirements: ' \
                        'it must have bit size {} and {}'.format(
                            val.var_name, other_bit_size, bit_size))
        elif isinstance(val, Expression):
            for src in val.sources:
                self.merge_variables(src)

    def validate_value(self, val):
        """Validate an expression by performing classic Hindley-Milner
        type inference on bitsizes.  This will detect if there are any conflicting
        requirements, and unify variables so that we know which variables must
        have the same bitsize. If we're operating on the replace expression, we
        will refuse to merge different variables together or merge a variable
        with a constant, in order to prevent surprises due to rules unexpectedly
        not matching at runtime.
        """
        if not isinstance(val, Expression):
            return

        # Generic conversion ops are special in that they have a single unsized
        # source and an unsized destination and the two don't have to match.
        # This means there's no validation or unioning to do here besides the
        # len(val.sources) check.
        if val.opcode in conv_opcode_types:
            assert len(val.sources) == 1, \
                "Expression {} has {} sources, expected 1".format(
                    val, len(val.sources))
            self.validate_value(val.sources[0])
            return

        nir_op = opcodes[val.opcode]
        assert len(val.sources) == nir_op.num_inputs, \
            "Expression {} has {} sources, expected {}".format(
                val, len(val.sources), nir_op.num_inputs)

        for src in val.sources:
            self.validate_value(src)

        dst_type_bits = type_bits(nir_op.output_type)

        # First, unify all the sources. That way, an error coming up because two
        # sources have an incompatible bit-size won't produce an error message
        # involving the destination.
        first_unsized_src = None
        for src_type, src in zip(nir_op.input_types, val.sources):
            src_type_bits = type_bits(src_type)
            if src_type_bits == 0:
                if first_unsized_src is None:
                    first_unsized_src = src
                    continue

                if self.is_search:
                    self.unify_bit_size(first_unsized_src, src,
                        lambda first_unsized_src_bit_size, src_bit_size:
                            'Source {} of {} must have bit size {}, while source {} ' \
                            'must have incompatible bit size {}'.format(
                                first_unsized_src, val, first_unsized_src_bit_size,
                                src, src_bit_size))
                else:
                    self.unify_bit_size(first_unsized_src, src,
                        lambda first_unsized_src_bit_size, src_bit_size:
                            'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                            'of {} may not have the same bit size when building the ' \
                            'replacement expression.'.format(
                                first_unsized_src, first_unsized_src_bit_size, src,
                                src_bit_size, val))
            else:
                if self.is_search:
                    self.unify_bit_size(src, src_type_bits,
                        lambda src_bit_size, unused:
                            '{} must have {} bits, but as a source of nir_op_{} '\
                            'it must have {} bits'.format(
                                src, src_bit_size, nir_op.name, src_type_bits))
                else:
                    self.unify_bit_size(src, src_type_bits,
                        lambda src_bit_size, unused:
                            '{} has the bit size of {}, but as a source of ' \
                            'nir_op_{} it must have {} bits, which may not be the ' \
                            'same'.format(
                                src, src_bit_size, nir_op.name, src_type_bits))

        if dst_type_bits == 0:
            if first_unsized_src is not None:
                if self.is_search:
                    self.unify_bit_size(val, first_unsized_src,
                        lambda val_bit_size, src_bit_size:
                            '{} must have the bit size of {}, while its source {} ' \
                            'must have incompatible bit size {}'.format(
                                val, val_bit_size, first_unsized_src, src_bit_size))
                else:
                    self.unify_bit_size(val, first_unsized_src,
                        lambda val_bit_size, src_bit_size:
                            '{} must have {} bits, but its source {} ' \
                            '(bit size of {}) may not have that bit size ' \
                            'when building the replacement.'.format(
                                val, val_bit_size, first_unsized_src, src_bit_size))
        else:
            self.unify_bit_size(val, dst_type_bits,
                lambda dst_bit_size, unused:
                    '{} must have {} bits, but as a destination of nir_op_{} ' \
                    'it must have {} bits'.format(
                        val, dst_bit_size, nir_op.name, dst_type_bits))

    def validate_replace(self, val, search):
        """Check, recursively, that every value of the replacement has a bit
        size that is a concrete int, a Variable, or the bit size of
        ``search``."""
        bit_size = val.get_bit_size()
        assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
            bit_size == search.get_bit_size(), \
            'Ambiguous bit size for replacement value {}: ' \
            'it cannot be deduced from a variable, a fixed bit size ' \
            'somewhere, or the search expression.'.format(val)

        if isinstance(val, Expression):
            for src in val.sources:
                self.validate_replace(src, search)

    def validate(self, search, replace):
        """Run the full validation of a (search, replace) pair; any
        inconsistency fails an assertion with a descriptive message."""
        self.is_search = True
        self.merge_variables(search)
        self.merge_variables(replace)
        self.validate_value(search)

        self.is_search = False
        self.validate_value(replace)

        # Check that search is always more specialized than replace. Note that
        # we're doing this in replace mode, disallowing merging variables.
        search_bit_size = search.get_bit_size()
        replace_bit_size = replace.get_bit_size()
        cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

        assert cmp_result is not None and cmp_result <= 0, \
            'The search expression bit size {} and replace expression ' \
            'bit size {} may not be the same'.format(
                search_bit_size, replace_bit_size)

        replace.set_bit_size(search)

        self.validate_replace(replace, search)


_optimization_ids = itertools.count()

condition_list = ['true']


class SearchAndReplace(object):
    """One transform: a search expression, its replacement and an optional
    condition string (``transform[2]``, defaulting to 'true')."""

    def __init__(self, transform):
        self.id = next(_optimization_ids)

        search = transform[0]
        replace = transform[1]
        if len(transform) > 2:
            self.condition = transform[2]
        else:
            self.condition = 'true'

        # Conditions are de-duplicated through the module-level condition_list.
        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        varset = VarSet()
        if isinstance(search, Expression):
            self.search = search
        else:
            self.search = Expression(search, "search{0}".format(self.id), varset)

        # The replacement may only use variables that the search introduced.
        varset.lock()

        if isinstance(replace, Value):
            self.replace = replace
        else:
            self.replace = Value.create(replace, "replace{0}".format(self.id), varset)

        BitSizeValidator(varset).validate(self.search, self.replace)


class TreeAutomaton(object):
    """This class calculates a bottom-up tree automaton to quickly search for
    the left-hand sides of transforms. Tree automatons are a generalization of
    classical NFA's and DFA's, where the transition function determines the
    state of the parent node based on the state of its children. We construct a
    deterministic automaton to match patterns, using a similar algorithm to the
    classical NFA to DFA construction. At the moment, it only matches opcodes
    and constants (without checking the actual value), leaving more detailed
    checking to the search function which actually checks the leaves. The
    automaton acts as a quick filter for the search function, requiring only n
    + 1 table lookups for each n-source operation. The implementation is based
    on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
    In the language of that reference, this is a frontier-to-root deterministic
    automaton using only symbol filtering. The filtering is crucial to reduce
    both the time taken to generate the tables and the size of the tables.
    """
    def __init__(self, transforms):
        self.patterns = [t.search for t in transforms]
        self._compute_items()
        self._build_table()
        #print('num items: {}'.format(len(set(self.items.values()))))
        #print('num states: {}'.format(len(self.states)))
        #for state, patterns in zip(self.states, self.patterns):
        #   print('{}: num patterns: {}'.format(state, len(patterns)))

    class IndexMap(object):
        """An indexed list of objects, where one can either lookup an object by
        index or find the index associated to an object quickly using a hash
        table. Compared to a list, it has a constant time index(). Compared to a
        set, it provides a stable iteration order.
        """
        def __init__(self, iterable=()):
            self.objects = []
            self.map = {}
            for obj in iterable:
                self.add(obj)

        def __getitem__(self, i):
            return self.objects[i]

        def __contains__(self, obj):
            return obj in self.map

        def __len__(self):
            return len(self.objects)

        def __iter__(self):
            return iter(self.objects)

        def clear(self):
            self.objects = []
            self.map.clear()

        def index(self, obj):
            return self.map[obj]

        def add(self, obj):
            """Insert ``obj`` if it is not present yet and return its index."""
            if obj in self.map:
                return self.map[obj]
            else:
                index = len(self.objects)
                self.objects.append(obj)
                self.map[obj] = index
                return index
__repr__(self): 765b8e80941Smrg return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])' 766b8e80941Smrg 767b8e80941Smrg class Item(object): 768b8e80941Smrg """This represents an "item" in the language of "Tree Automatons." This 769b8e80941Smrg is just a subtree of some pattern, which represents a potential partial 770b8e80941Smrg match at runtime. We deduplicate them, so that identical subtrees of 771b8e80941Smrg different patterns share the same object, and store some extra 772b8e80941Smrg information needed for the main algorithm as well. 773b8e80941Smrg """ 774b8e80941Smrg def __init__(self, opcode, children): 775b8e80941Smrg self.opcode = opcode 776b8e80941Smrg self.children = children 777b8e80941Smrg # These are the indices of patterns for which this item is the root node. 778b8e80941Smrg self.patterns = [] 779b8e80941Smrg # This the set of opcodes for parents of this item. Used to speed up 780b8e80941Smrg # filtering. 781b8e80941Smrg self.parent_ops = set() 782b8e80941Smrg 783b8e80941Smrg def __str__(self): 784b8e80941Smrg return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')' 785b8e80941Smrg 786b8e80941Smrg def __repr__(self): 787b8e80941Smrg return str(self) 788b8e80941Smrg 789b8e80941Smrg def _compute_items(self): 790b8e80941Smrg """Build a set of all possible items, deduplicating them.""" 791b8e80941Smrg # This is a map from (opcode, sources) to item. 792b8e80941Smrg self.items = {} 793b8e80941Smrg 794b8e80941Smrg # The set of all opcodes used by the patterns. Used later to avoid 795b8e80941Smrg # building and emitting all the tables for opcodes that aren't used. 
796b8e80941Smrg self.opcodes = self.IndexMap() 797b8e80941Smrg 798b8e80941Smrg def get_item(opcode, children, pattern=None): 799b8e80941Smrg commutative = len(children) == 2 \ 800b8e80941Smrg and "commutative" in opcodes[opcode].algebraic_properties 801b8e80941Smrg item = self.items.setdefault((opcode, children), 802b8e80941Smrg self.Item(opcode, children)) 803b8e80941Smrg if commutative: 804b8e80941Smrg self.items[opcode, (children[1], children[0])] = item 805b8e80941Smrg if pattern is not None: 806b8e80941Smrg item.patterns.append(pattern) 807b8e80941Smrg return item 808b8e80941Smrg 809b8e80941Smrg self.wildcard = get_item("__wildcard", ()) 810b8e80941Smrg self.const = get_item("__const", ()) 811b8e80941Smrg 812b8e80941Smrg def process_subpattern(src, pattern=None): 813b8e80941Smrg if isinstance(src, Constant): 814b8e80941Smrg # Note: we throw away the actual constant value! 815b8e80941Smrg return self.const 816b8e80941Smrg elif isinstance(src, Variable): 817b8e80941Smrg if src.is_constant: 818b8e80941Smrg return self.const 819b8e80941Smrg else: 820b8e80941Smrg # Note: we throw away which variable it is here! This special 821b8e80941Smrg # item is equivalent to nu in "Tree Automatons." 822b8e80941Smrg return self.wildcard 823b8e80941Smrg else: 824b8e80941Smrg assert isinstance(src, Expression) 825b8e80941Smrg opcode = src.opcode 826b8e80941Smrg stripped = opcode.rstrip('0123456789') 827b8e80941Smrg if stripped in conv_opcode_types: 828b8e80941Smrg # Matches that use conversion opcodes with a specific type, 829b8e80941Smrg # like f2b1, are tricky. Either we construct the automaton to 830b8e80941Smrg # match specific NIR opcodes like nir_op_f2b1, in which case we 831b8e80941Smrg # need to create separate items for each possible NIR opcode 832b8e80941Smrg # for patterns that have a generic opcode like f2b, or we 833b8e80941Smrg # construct it to match the search opcode, in which case we 834b8e80941Smrg # need to map f2b1 to f2b when constructing the automaton. 
Here 835b8e80941Smrg # we do the latter. 836b8e80941Smrg opcode = stripped 837b8e80941Smrg self.opcodes.add(opcode) 838b8e80941Smrg children = tuple(process_subpattern(c) for c in src.sources) 839b8e80941Smrg item = get_item(opcode, children, pattern) 840b8e80941Smrg for i, child in enumerate(children): 841b8e80941Smrg child.parent_ops.add(opcode) 842b8e80941Smrg return item 843b8e80941Smrg 844b8e80941Smrg for i, pattern in enumerate(self.patterns): 845b8e80941Smrg process_subpattern(pattern, i) 846b8e80941Smrg 847b8e80941Smrg def _build_table(self): 848b8e80941Smrg """This is the core algorithm which builds up the transition table. It 849b8e80941Smrg is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl . 850b8e80941Smrg Comp_a and Filt_{a,i} using integers to identify match sets." It 851b8e80941Smrg simultaneously builds up a list of all possible "match sets" or 852b8e80941Smrg "states", where each match set represents the set of Item's that match a 853b8e80941Smrg given instruction, and builds up the transition table between states. 854b8e80941Smrg """ 855b8e80941Smrg # Map from opcode + filtered state indices to transitioned state. 856b8e80941Smrg self.table = defaultdict(dict) 857b8e80941Smrg # Bijection from state to index. q in the original algorithm is 858b8e80941Smrg # len(self.states) 859b8e80941Smrg self.states = self.IndexMap() 860b8e80941Smrg # List of pattern matches for each state index. 861b8e80941Smrg self.state_patterns = [] 862b8e80941Smrg # Map from state index to filtered state index for each opcode. 863b8e80941Smrg self.filter = defaultdict(list) 864b8e80941Smrg # Bijections from filtered state to filtered state index for each 865b8e80941Smrg # opcode, called the "representor sets" in the original algorithm. 866b8e80941Smrg # q_{a,j} in the original algorithm is len(self.rep[op]). 
867b8e80941Smrg self.rep = defaultdict(self.IndexMap) 868b8e80941Smrg 869b8e80941Smrg # Everything in self.states with a index at least worklist_index is part 870b8e80941Smrg # of the worklist of newly created states. There is also a worklist of 871b8e80941Smrg # newly fitered states for each opcode, for which worklist_indices 872b8e80941Smrg # serves a similar purpose. worklist_index corresponds to p in the 873b8e80941Smrg # original algorithm, while worklist_indices is p_{a,j} (although since 874b8e80941Smrg # we only filter by opcode/symbol, it's really just p_a). 875b8e80941Smrg self.worklist_index = 0 876b8e80941Smrg worklist_indices = defaultdict(lambda: 0) 877b8e80941Smrg 878b8e80941Smrg # This is the set of opcodes for which the filtered worklist is non-empty. 879b8e80941Smrg # It's used to avoid scanning opcodes for which there is nothing to 880b8e80941Smrg # process when building the transition table. It corresponds to new_a in 881b8e80941Smrg # the original algorithm. 882b8e80941Smrg new_opcodes = self.IndexMap() 883b8e80941Smrg 884b8e80941Smrg # Process states on the global worklist, filtering them for each opcode, 885b8e80941Smrg # updating the filter tables, and updating the filtered worklists if any 886b8e80941Smrg # new filtered states are found. Similar to ComputeRepresenterSets() in 887b8e80941Smrg # the original algorithm, although that only processes a single state. 888b8e80941Smrg def process_new_states(): 889b8e80941Smrg while self.worklist_index < len(self.states): 890b8e80941Smrg state = self.states[self.worklist_index] 891b8e80941Smrg 892b8e80941Smrg # Calculate pattern matches for this state. Each pattern is 893b8e80941Smrg # assigned to a unique item, so we don't have to worry about 894b8e80941Smrg # deduplicating them here. However, we do have to sort them so 895b8e80941Smrg # that they're visited at runtime in the order they're specified 896b8e80941Smrg # in the source. 
897b8e80941Smrg patterns = list(sorted(p for item in state for p in item.patterns)) 898b8e80941Smrg assert len(self.state_patterns) == self.worklist_index 899b8e80941Smrg self.state_patterns.append(patterns) 900b8e80941Smrg 901b8e80941Smrg # calculate filter table for this state, and update filtered 902b8e80941Smrg # worklists. 903b8e80941Smrg for op in self.opcodes: 904b8e80941Smrg filt = self.filter[op] 905b8e80941Smrg rep = self.rep[op] 906b8e80941Smrg filtered = frozenset(item for item in state if \ 907b8e80941Smrg op in item.parent_ops) 908b8e80941Smrg if filtered in rep: 909b8e80941Smrg rep_index = rep.index(filtered) 910b8e80941Smrg else: 911b8e80941Smrg rep_index = rep.add(filtered) 912b8e80941Smrg new_opcodes.add(op) 913b8e80941Smrg assert len(filt) == self.worklist_index 914b8e80941Smrg filt.append(rep_index) 915b8e80941Smrg self.worklist_index += 1 916b8e80941Smrg 917b8e80941Smrg # There are two start states: one which can only match as a wildcard, 918b8e80941Smrg # and one which can match as a wildcard or constant. These will be the 919b8e80941Smrg # states of intrinsics/other instructions and load_const instructions, 920b8e80941Smrg # respectively. The indices of these must match the definitions of 921b8e80941Smrg # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can 922b8e80941Smrg # initialize things correctly. 
923b8e80941Smrg self.states.add(frozenset((self.wildcard,))) 924b8e80941Smrg self.states.add(frozenset((self.const,self.wildcard))) 925b8e80941Smrg process_new_states() 926b8e80941Smrg 927b8e80941Smrg while len(new_opcodes) > 0: 928b8e80941Smrg for op in new_opcodes: 929b8e80941Smrg rep = self.rep[op] 930b8e80941Smrg table = self.table[op] 931b8e80941Smrg op_worklist_index = worklist_indices[op] 932b8e80941Smrg if op in conv_opcode_types: 933b8e80941Smrg num_srcs = 1 934b8e80941Smrg else: 935b8e80941Smrg num_srcs = opcodes[op].num_inputs 936b8e80941Smrg 937b8e80941Smrg # Iterate over all possible source combinations where at least one 938b8e80941Smrg # is on the worklist. 939b8e80941Smrg for src_indices in itertools.product(range(len(rep)), repeat=num_srcs): 940b8e80941Smrg if all(src_idx < op_worklist_index for src_idx in src_indices): 941b8e80941Smrg continue 942b8e80941Smrg 943b8e80941Smrg srcs = tuple(rep[src_idx] for src_idx in src_indices) 944b8e80941Smrg 945b8e80941Smrg # Try all possible pairings of source items and add the 946b8e80941Smrg # corresponding parent items. This is Comp_a from the paper. 947b8e80941Smrg parent = set(self.items[op, item_srcs] for item_srcs in 948b8e80941Smrg itertools.product(*srcs) if (op, item_srcs) in self.items) 949b8e80941Smrg 950b8e80941Smrg # We could always start matching something else with a 951b8e80941Smrg # wildcard. This is Cl from the paper. 
952b8e80941Smrg parent.add(self.wildcard) 953b8e80941Smrg 954b8e80941Smrg table[src_indices] = self.states.add(frozenset(parent)) 955b8e80941Smrg worklist_indices[op] = len(rep) 956b8e80941Smrg new_opcodes.clear() 957b8e80941Smrg process_new_states() 958b8e80941Smrg 959b8e80941Smrg_algebraic_pass_template = mako.template.Template(""" 960b8e80941Smrg#include "nir.h" 961b8e80941Smrg#include "nir_builder.h" 962b8e80941Smrg#include "nir_search.h" 963b8e80941Smrg#include "nir_search_helpers.h" 964b8e80941Smrg 965b8e80941Smrg#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS 966b8e80941Smrg#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS 967b8e80941Smrg 968b8e80941Smrgstruct transform { 969b8e80941Smrg const nir_search_expression *search; 970b8e80941Smrg const nir_search_value *replace; 971b8e80941Smrg unsigned condition_offset; 972b8e80941Smrg}; 973b8e80941Smrg 974b8e80941Smrgstruct per_op_table { 975b8e80941Smrg const uint16_t *filter; 976b8e80941Smrg unsigned num_filtered_states; 977b8e80941Smrg const uint16_t *table; 978b8e80941Smrg}; 979b8e80941Smrg 980b8e80941Smrg/* Note: these must match the start states created in 981b8e80941Smrg * TreeAutomaton._build_table() 982b8e80941Smrg */ 983b8e80941Smrg 984b8e80941Smrg/* WILDCARD_STATE = 0 is set by zeroing the state array */ 985b8e80941Smrgstatic const uint16_t CONST_STATE = 1; 986b8e80941Smrg 987b8e80941Smrg#endif 988b8e80941Smrg 989b8e80941Smrg<% cache = {} %> 990b8e80941Smrg% for xform in xforms: 991b8e80941Smrg ${xform.search.render(cache)} 992b8e80941Smrg ${xform.replace.render(cache)} 993b8e80941Smrg% endfor 994b8e80941Smrg 995b8e80941Smrg% for state_id, state_xforms in enumerate(automaton.state_patterns): 996b8e80941Smrg% if state_xforms: # avoid emitting a 0-length array for MSVC 997b8e80941Smrgstatic const struct transform ${pass_name}_state${state_id}_xforms[] = { 998b8e80941Smrg% for i in state_xforms: 999b8e80941Smrg { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} }, 
1000b8e80941Smrg% endfor 1001b8e80941Smrg}; 1002b8e80941Smrg% endif 1003b8e80941Smrg% endfor 1004b8e80941Smrg 1005b8e80941Smrgstatic const struct per_op_table ${pass_name}_table[nir_num_search_ops] = { 1006b8e80941Smrg% for op in automaton.opcodes: 1007b8e80941Smrg [${get_c_opcode(op)}] = { 1008b8e80941Smrg .filter = (uint16_t []) { 1009b8e80941Smrg % for e in automaton.filter[op]: 1010b8e80941Smrg ${e}, 1011b8e80941Smrg % endfor 1012b8e80941Smrg }, 1013b8e80941Smrg <% 1014b8e80941Smrg num_filtered = len(automaton.rep[op]) 1015b8e80941Smrg %> 1016b8e80941Smrg .num_filtered_states = ${num_filtered}, 1017b8e80941Smrg .table = (uint16_t []) { 1018b8e80941Smrg <% 1019b8e80941Smrg num_srcs = len(next(iter(automaton.table[op]))) 1020b8e80941Smrg %> 1021b8e80941Smrg % for indices in itertools.product(range(num_filtered), repeat=num_srcs): 1022b8e80941Smrg ${automaton.table[op][indices]}, 1023b8e80941Smrg % endfor 1024b8e80941Smrg }, 1025b8e80941Smrg }, 1026b8e80941Smrg% endfor 1027b8e80941Smrg}; 1028b8e80941Smrg 1029b8e80941Smrgstatic void 1030b8e80941Smrg${pass_name}_pre_block(nir_block *block, uint16_t *states) 1031b8e80941Smrg{ 1032b8e80941Smrg nir_foreach_instr(instr, block) { 1033b8e80941Smrg switch (instr->type) { 1034b8e80941Smrg case nir_instr_type_alu: { 1035b8e80941Smrg nir_alu_instr *alu = nir_instr_as_alu(instr); 1036b8e80941Smrg nir_op op = alu->op; 1037b8e80941Smrg uint16_t search_op = nir_search_op_for_nir_op(op); 1038b8e80941Smrg const struct per_op_table *tbl = &${pass_name}_table[search_op]; 1039b8e80941Smrg if (tbl->num_filtered_states == 0) 1040b8e80941Smrg continue; 1041b8e80941Smrg 1042b8e80941Smrg /* Calculate the index into the transition table. Note the index 1043b8e80941Smrg * calculated must match the iteration order of Python's 1044b8e80941Smrg * itertools.product(), which was used to emit the transition 1045b8e80941Smrg * table. 
1046b8e80941Smrg */ 1047b8e80941Smrg uint16_t index = 0; 1048b8e80941Smrg for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) { 1049b8e80941Smrg index *= tbl->num_filtered_states; 1050b8e80941Smrg index += tbl->filter[states[alu->src[i].src.ssa->index]]; 1051b8e80941Smrg } 1052b8e80941Smrg states[alu->dest.dest.ssa.index] = tbl->table[index]; 1053b8e80941Smrg break; 1054b8e80941Smrg } 1055b8e80941Smrg 1056b8e80941Smrg case nir_instr_type_load_const: { 1057b8e80941Smrg nir_load_const_instr *load_const = nir_instr_as_load_const(instr); 1058b8e80941Smrg states[load_const->def.index] = CONST_STATE; 1059b8e80941Smrg break; 1060b8e80941Smrg } 1061b8e80941Smrg 1062b8e80941Smrg default: 1063b8e80941Smrg break; 1064b8e80941Smrg } 1065b8e80941Smrg } 1066b8e80941Smrg} 1067b8e80941Smrg 1068b8e80941Smrgstatic bool 1069b8e80941Smrg${pass_name}_block(nir_builder *build, nir_block *block, 1070b8e80941Smrg const uint16_t *states, const bool *condition_flags) 1071b8e80941Smrg{ 1072b8e80941Smrg bool progress = false; 1073b8e80941Smrg 1074b8e80941Smrg nir_foreach_instr_reverse_safe(instr, block) { 1075b8e80941Smrg if (instr->type != nir_instr_type_alu) 1076b8e80941Smrg continue; 1077b8e80941Smrg 1078b8e80941Smrg nir_alu_instr *alu = nir_instr_as_alu(instr); 1079b8e80941Smrg if (!alu->dest.dest.is_ssa) 1080b8e80941Smrg continue; 1081b8e80941Smrg 1082b8e80941Smrg switch (states[alu->dest.dest.ssa.index]) { 1083b8e80941Smrg% for i in range(len(automaton.state_patterns)): 1084b8e80941Smrg case ${i}: 1085b8e80941Smrg % if automaton.state_patterns[i]: 1086b8e80941Smrg for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_state${i}_xforms); i++) { 1087b8e80941Smrg const struct transform *xform = &${pass_name}_state${i}_xforms[i]; 1088b8e80941Smrg if (condition_flags[xform->condition_offset] && 1089b8e80941Smrg nir_replace_instr(build, alu, xform->search, xform->replace)) { 1090b8e80941Smrg progress = true; 1091b8e80941Smrg break; 1092b8e80941Smrg } 1093b8e80941Smrg } 1094b8e80941Smrg % 
endif 1095b8e80941Smrg break; 1096b8e80941Smrg% endfor 1097b8e80941Smrg default: assert(0); 1098b8e80941Smrg } 1099b8e80941Smrg } 1100b8e80941Smrg 1101b8e80941Smrg return progress; 1102b8e80941Smrg} 1103b8e80941Smrg 1104b8e80941Smrgstatic bool 1105b8e80941Smrg${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags) 1106b8e80941Smrg{ 1107b8e80941Smrg bool progress = false; 1108b8e80941Smrg 1109b8e80941Smrg nir_builder build; 1110b8e80941Smrg nir_builder_init(&build, impl); 1111b8e80941Smrg 1112b8e80941Smrg /* Note: it's important here that we're allocating a zeroed array, since 1113b8e80941Smrg * state 0 is the default state, which means we don't have to visit 1114b8e80941Smrg * anything other than constants and ALU instructions. 1115b8e80941Smrg */ 1116b8e80941Smrg uint16_t *states = calloc(impl->ssa_alloc, sizeof(*states)); 1117b8e80941Smrg 1118b8e80941Smrg nir_foreach_block(block, impl) { 1119b8e80941Smrg ${pass_name}_pre_block(block, states); 1120b8e80941Smrg } 1121b8e80941Smrg 1122b8e80941Smrg nir_foreach_block_reverse(block, impl) { 1123b8e80941Smrg progress |= ${pass_name}_block(&build, block, states, condition_flags); 1124b8e80941Smrg } 1125b8e80941Smrg 1126b8e80941Smrg free(states); 1127b8e80941Smrg 1128b8e80941Smrg if (progress) { 1129b8e80941Smrg nir_metadata_preserve(impl, nir_metadata_block_index | 1130b8e80941Smrg nir_metadata_dominance); 1131b8e80941Smrg } else { 1132b8e80941Smrg#ifndef NDEBUG 1133b8e80941Smrg impl->valid_metadata &= ~nir_metadata_not_properly_reset; 1134b8e80941Smrg#endif 1135b8e80941Smrg } 1136b8e80941Smrg 1137b8e80941Smrg return progress; 1138b8e80941Smrg} 1139b8e80941Smrg 1140b8e80941Smrg 1141b8e80941Smrgbool 1142b8e80941Smrg${pass_name}(nir_shader *shader) 1143b8e80941Smrg{ 1144b8e80941Smrg bool progress = false; 1145b8e80941Smrg bool condition_flags[${len(condition_list)}]; 1146b8e80941Smrg const nir_shader_compiler_options *options = shader->options; 1147b8e80941Smrg const shader_info *info = &shader->info; 
1148b8e80941Smrg (void) options; 1149b8e80941Smrg (void) info; 1150b8e80941Smrg 1151b8e80941Smrg % for index, condition in enumerate(condition_list): 1152b8e80941Smrg condition_flags[${index}] = ${condition}; 1153b8e80941Smrg % endfor 1154b8e80941Smrg 1155b8e80941Smrg nir_foreach_function(function, shader) { 1156b8e80941Smrg if (function->impl) 1157b8e80941Smrg progress |= ${pass_name}_impl(function->impl, condition_flags); 1158b8e80941Smrg } 1159b8e80941Smrg 1160b8e80941Smrg return progress; 1161b8e80941Smrg} 1162b8e80941Smrg""") 1163b8e80941Smrg 1164b8e80941Smrg 1165b8e80941Smrg 1166b8e80941Smrgclass AlgebraicPass(object): 1167b8e80941Smrg def __init__(self, pass_name, transforms): 1168b8e80941Smrg self.xforms = [] 1169b8e80941Smrg self.opcode_xforms = defaultdict(lambda : []) 1170b8e80941Smrg self.pass_name = pass_name 1171b8e80941Smrg 1172b8e80941Smrg error = False 1173b8e80941Smrg 1174b8e80941Smrg for xform in transforms: 1175b8e80941Smrg if not isinstance(xform, SearchAndReplace): 1176b8e80941Smrg try: 1177b8e80941Smrg xform = SearchAndReplace(xform) 1178b8e80941Smrg except: 1179b8e80941Smrg print("Failed to parse transformation:", file=sys.stderr) 1180b8e80941Smrg print(" " + str(xform), file=sys.stderr) 1181b8e80941Smrg traceback.print_exc(file=sys.stderr) 1182b8e80941Smrg print('', file=sys.stderr) 1183b8e80941Smrg error = True 1184b8e80941Smrg continue 1185b8e80941Smrg 1186b8e80941Smrg self.xforms.append(xform) 1187b8e80941Smrg if xform.search.opcode in conv_opcode_types: 1188b8e80941Smrg dst_type = conv_opcode_types[xform.search.opcode] 1189b8e80941Smrg for size in type_sizes(dst_type): 1190b8e80941Smrg sized_opcode = xform.search.opcode + str(size) 1191b8e80941Smrg self.opcode_xforms[sized_opcode].append(xform) 1192b8e80941Smrg else: 1193b8e80941Smrg self.opcode_xforms[xform.search.opcode].append(xform) 1194b8e80941Smrg 1195b8e80941Smrg self.automaton = TreeAutomaton(self.xforms) 1196b8e80941Smrg 1197b8e80941Smrg if error: 1198b8e80941Smrg sys.exit(1) 
1199b8e80941Smrg 1200b8e80941Smrg 1201b8e80941Smrg def render(self): 1202b8e80941Smrg return _algebraic_pass_template.render(pass_name=self.pass_name, 1203b8e80941Smrg xforms=self.xforms, 1204b8e80941Smrg opcode_xforms=self.opcode_xforms, 1205b8e80941Smrg condition_list=condition_list, 1206b8e80941Smrg automaton=self.automaton, 1207b8e80941Smrg get_c_opcode=get_c_opcode, 1208b8e80941Smrg itertools=itertools) 1209