# Copyright 2016-2017 Tobias Grosser
#
# Use of this software is governed by the MIT license
#
# Written by Tobias Grosser, Weststrasse 47, CH-8003, Zurich

import sys
import isl

# Check that isl objects can be created through each kind of constructor.
#
# Covered here:
#  - construction from a string
#  - construction from an integer
#  - static constructor without a parameter
#  - conversion construction
#  - construction of empty union set
#
# Construction from integers and strings is also exercised by the
# parameter type tests; what matters here is that overload resolution
# works when several overloaded constructors are present.
#
def test_constructors():
    zero_from_str = isl.val("0")
    assert zero_from_str.is_zero()

    zero_from_int = isl.val(0)
    assert zero_from_int.is_zero()

    zero_static = isl.val.zero()
    assert zero_static.is_zero()

    # Conversion construction: a basic_set is turned into a set.
    basic = isl.basic_set("{ [1] }")
    expected = isl.set("{ [1] }")
    converted = isl.set(basic)
    assert converted.is_equal(expected)

    # Taking the union with an empty union set must not change anything.
    non_empty = isl.union_set("{ A[1]; B[2, 3] }")
    nothing = isl.union_set.empty()
    assert non_empty.is_equal(non_empty.union(nothing))

# Check one particular integer value as a function parameter.
#
# The value built from the integer must equal the value built from
# its decimal string representation.
#
def test_int(i):
    from_int = isl.val(i)
    from_str = isl.val(str(i))
    assert from_int.eq(from_str)

# Check integer function parameters.
#
# The extreme values and zero are exercised.
#
def test_parameters_int():
    for value in (sys.maxsize, -sys.maxsize - 1, 0):
        test_int(value)

# Test isl objects parameters.
#
# Verify that isl objects can be passed as lvalue and rvalue parameters.
# Also verify that isl object parameters are automatically type converted if
# there is an inheritance relation. Finally, test function calls without
# any additional parameters, apart from the isl object on which
# the method is called.
#
def test_parameters_obj():
    zero_set = isl.set("{ [0] }")
    one_set = isl.set("{ [1] }")
    two_set = isl.set("{ [2] }")
    expected = isl.set("{ [i] : 0 <= i <= 2 }")

    # "lvalue" case: the intermediate result is stored in a variable first.
    partial = zero_set.union(one_set)
    from_named = partial.union(two_set)
    assert from_named.is_equal(expected)

    # "rvalue" case: the intermediate result is used directly.
    from_chained = zero_set.union(one_set).union(two_set)
    assert from_chained.is_equal(expected)

    # A basic_set is accepted where a set is expected.
    zero_basic = isl.basic_set("{ [0] }")
    assert zero_set.is_equal(zero_basic)

    # Method that takes nothing but the object it is called on.
    half = isl.val("1/2")
    inverse_of_two = isl.val(2).inv()
    assert inverse_of_two.eq(half)

# Exercise the different kinds of parameters that can be passed to functions.
#
# Both integer and isl object parameters are covered.
#
def test_parameters():
    test_parameters_int()
    test_parameters_obj()

# Check that isl objects are returned correctly.
#
# Only checks that the result of combining two objects is
# returned successfully.
#
def test_return_obj():
    total = isl.val("1").add(isl.val("2"))

    assert total.eq(isl.val("3"))

# Test that integer values are returned correctly.
110 1.1 mrg # 111 1.1 mrg def test_return_int(): 112 1.1 mrg one = isl.val("1") 113 1.1 mrg neg_one = isl.val("-1") 114 1.1 mrg zero = isl.val("0") 115 1.1 mrg 116 1.1 mrg assert(one.sgn() > 0) 117 1.1 mrg assert(neg_one.sgn() < 0) 118 1.1 mrg assert(zero.sgn() == 0) 119 1.1 mrg 120 1.1 mrg # Test that isl_bool values are returned correctly. 121 1.1 mrg # 122 1.1 mrg # In particular, check the conversion to bool in case of true and false. 123 1.1 mrg # 124 1.1 mrg def test_return_bool(): 125 1.1 mrg empty = isl.set("{ : false }") 126 1.1 mrg univ = isl.set("{ : }") 127 1.1 mrg 128 1.1 mrg b_true = empty.is_empty() 129 1.1 mrg b_false = univ.is_empty() 130 1.1 mrg 131 1.1 mrg assert(b_true) 132 1.1 mrg assert(not b_false) 133 1.1 mrg 134 1.1 mrg # Test that strings are returned correctly. 135 1.1 mrg # Do so by calling overloaded isl.ast_build.from_expr methods. 136 1.1 mrg # 137 1.1 mrg def test_return_string(): 138 1.1 mrg context = isl.set("[n] -> { : }") 139 1.1 mrg build = isl.ast_build.from_context(context) 140 1.1 mrg pw_aff = isl.pw_aff("[n] -> { [n] }") 141 1.1 mrg set = isl.set("[n] -> { : n >= 0 }") 142 1.1 mrg 143 1.1 mrg expr = build.expr_from(pw_aff) 144 1.1 mrg expected_string = "n" 145 1.1 mrg assert(expected_string == expr.to_C_str()) 146 1.1 mrg 147 1.1 mrg expr = build.expr_from(set) 148 1.1 mrg expected_string = "n >= 0" 149 1.1 mrg assert(expected_string == expr.to_C_str()) 150 1.1 mrg 151 1.1 mrg # Test that return values are handled correctly. 152 1.1 mrg # 153 1.1 mrg # Test that isl objects, integers, boolean values, and strings are 154 1.1 mrg # returned correctly. 155 1.1 mrg # 156 1.1 mrg def test_return(): 157 1.1 mrg test_return_obj() 158 1.1 mrg test_return_int() 159 1.1 mrg test_return_bool() 160 1.1 mrg test_return_string() 161 1.1 mrg 162 1.1 mrg # A class that is used to test isl.id.user. 163 1.1 mrg # 164 1.1 mrg class S: 165 1.1 mrg def __init__(self): 166 1.1 mrg self.value = 42 167 1.1 mrg 168 1.1 mrg # Test isl.id.user. 
#
# In particular, check that the object attached to an identifier
# can be retrieved again.
#
def test_user():
    # Locals are named so that the builtin "id" is not shadowed.
    with_int = isl.id("test", 5)
    without_user = isl.id("test2")
    with_object = isl.id("S", S())
    assert with_int.user() == 5, f"unexpected user object {with_int.user()}"
    assert without_user.user() is None, \
        f"unexpected user object {without_user.user()}"
    s = with_object.user()
    assert isinstance(s, S), f"unexpected user object {s}"
    assert s.value == 42, f"unexpected user object {s}"

# Test that foreach functions are modeled correctly.
#
# Verify that closures are correctly called as callback of a 'foreach'
# function and that variables captured by the closure work correctly. Also
# check that the foreach function handles exceptions thrown from
# the closure and that it propagates the exception.
#
def test_foreach():
    s = isl.set("{ [0]; [1]; [2] }")

    # Collect every basic set handed to the callback
    # (named so that the builtin "list" is not shadowed).
    pieces = []
    def add(bs):
        pieces.append(bs)
    s.foreach_basic_set(add)

    assert len(pieces) == 3
    assert pieces[0].is_subset(s)
    assert pieces[1].is_subset(s)
    assert pieces[2].is_subset(s)
    assert not pieces[0].is_equal(pieces[1])
    assert not pieces[0].is_equal(pieces[2])
    assert not pieces[1].is_equal(pieces[2])

    def fail(bs):
        raise Exception("fail")

    # Only catch Exception: a bare "except" would also swallow
    # KeyboardInterrupt/SystemExit and let the test pass for
    # the wrong reason.
    caught = False
    try:
        s.foreach_basic_set(fail)
    except Exception:
        caught = True
    assert caught

# Test the functionality of "foreach_scc" functions.
#
# In particular, test it on a list of elements that can be completely sorted
# but where two of the elements ("a" and "b") are incomparable.
#
def test_foreach_scc():
    ids = isl.id_list(3)
    # Single-element list so that the callback below can replace
    # the accumulated result.
    ordered = [isl.id_list(3)]
    data = {
        'a' : isl.map("{ [0] -> [1] }"),
        'b' : isl.map("{ [1] -> [0] }"),
        'c' : isl.map("{ [i = 0:1] -> [i] }"),
    }
    for name in data:
        ids = ids.add(name)
    identity = data['a'].space().domain().identity_multi_pw_aff_on_domain()

    # Ordering predicate handed to foreach_scc.
    def follows(a, b):
        combined = data[b.name()].apply_domain(data[a.name()])
        return not combined.lex_ge_at(identity).is_empty()

    # Every component is expected to hold exactly one element;
    # append it to the result.
    def add_single(scc):
        assert scc.size() == 1
        ordered[0] = ordered[0].concat(scc)

    ids.foreach_scc(follows, add_single)
    assert ordered[0].size() == 3
    for position, expected_name in enumerate(("b", "c", "a")):
        assert ordered[0].at(position).name() == expected_name

# Test the functionality of "every" functions.
#
# In particular, test the generic functionality and
# test that exceptions are properly propagated.
#
def test_every():
    us = isl.union_set("{ A[i]; B[j] }")

    def is_empty(s):
        return s.is_empty()
    assert not us.every_set(is_empty)

    def is_non_empty(s):
        return not s.is_empty()
    assert us.every_set(is_non_empty)

    def in_A(s):
        return s.is_subset(isl.set("{ A[x] }"))
    assert not us.every_set(in_A)

    def not_in_A(s):
        return not s.is_subset(isl.set("{ A[x] }"))
    assert not us.every_set(not_in_A)

    def fail(s):
        raise Exception("fail")

    # The callback must actually be invoked through every_set.
    # A misspelled method name would raise AttributeError instead, which
    # a bare "except" would silently accept as the expected failure,
    # so also restrict the handler to Exception.
    caught = False
    try:
        us.every_set(fail)
    except Exception:
        caught = True
    assert caught

# Check basic construction of spaces.
#
# A set space and a map space are built up from the unit space by
# adding named tuples, and the universes over them are compared
# against their textual descriptions.
#
def test_space():
    unit = isl.space.unit()
    set_space = unit.add_named_tuple("A", 3)
    map_space = set_space.add_named_tuple("B", 2)

    # Locals are named so that the builtins "set" and "map"
    # are not shadowed.
    universe_set = isl.set.universe(set_space)
    universe_map = isl.map.universe(map_space)
    assert universe_set.is_equal(isl.set("{ A[*,*,*] }"))
    assert universe_map.is_equal(isl.map("{ A[*,*,*] -> B[*,*] }"))

# Construct a simple schedule tree with an outer sequence node and
# a single-dimensional band node in each branch, with one of them
# marked coincident.
#
def construct_schedule_tree():
    A = isl.union_set("{ A[i] : 0 <= i < 10 }")
    B = isl.union_set("{ B[i] : 0 <= i < 20 }")

    # Domain node over both statements, with a sequence node
    # (one filter per statement) inserted below it.
    domain_node = isl.schedule_node.from_domain(A.union(B))
    below_domain = domain_node.child(0)
    sequence = below_domain.insert_sequence(isl.union_set_list(A).add(B))

    # First branch: band for A, marked coincident.
    f_A = isl.multi_union_pw_aff("[ { A[i] -> [i] } ]")
    band_A = sequence.child(0).child(0).insert_partial_schedule(f_A)
    band_A = band_A.member_set_coincident(0, True)
    sequence = band_A.ancestor(2)

    # Second branch: band for B, left unmarked.
    f_B = isl.multi_union_pw_aff("[ { B[i] -> [i] } ]")
    band_B = sequence.child(1).child(0).insert_partial_schedule(f_B)
    sequence = band_B.ancestor(2)

    return sequence.schedule()

# Test basic schedule tree functionality.
#
# In particular, create a simple schedule tree and
#  - check that the root node is a domain node
#  - test map_descendant_bottom_up
#  - test foreach_descendant_top_down
#  - test every_descendant
#
def test_schedule_tree():
    schedule = construct_schedule_tree()
    root = schedule.root()

    # Exact type comparison: the root must be reported as a domain node.
    assert type(root) is isl.schedule_node_domain

    # Counters live in one-element lists so the callbacks can update them.
    mapped = [0]
    def count_and_keep(node):
        mapped[0] += 1
        return node
    root = root.map_descendant_bottom_up(count_and_keep)
    assert mapped[0] == 8

    def fail_map(node):
        raise Exception("fail")
    # Only catch Exception so that KeyboardInterrupt/SystemExit
    # are not mistaken for the expected failure.
    caught = False
    try:
        root.map_descendant_bottom_up(fail_map)
    except Exception:
        caught = True
    assert caught

    # Callback returning True: all 8 nodes end up being counted.
    visited = [0]
    def count_and_continue(node):
        visited[0] += 1
        return True
    root.foreach_descendant_top_down(count_and_continue)
    assert visited[0] == 8

    # Callback returning False: only a single node ends up being counted.
    stopped = [0]
    def count_and_stop(node):
        stopped[0] += 1
        return False
    root.foreach_descendant_top_down(count_and_stop)
    assert stopped[0] == 1

    def is_not_domain(node):
        return type(node) is not isl.schedule_node_domain
    assert root.child(0).every_descendant(is_not_domain)
    assert not root.every_descendant(is_not_domain)

    def fail(node):
        raise Exception("fail")
    caught = False
    try:
        root.every_descendant(fail)
    except Exception:
        caught = True
    assert caught

    # The union of all filters in the tree must equal the domain.
    domain = root.domain()
    filters = [isl.union_set("{}")]
    def collect_filters(node):
        if type(node) is isl.schedule_node_filter:
            filters[0] = filters[0].union(node.filter())
        return True
    root.every_descendant(collect_filters)
    assert domain.is_equal(filters[0])

# Test marking band members for unrolling.
# "schedule" is the schedule created by construct_schedule_tree.
# It schedules two statements, with 10 and 20 instances, respectively.
# Unrolling all band members therefore results in 30 at-domain calls
# by the AST generator.
#
def test_ast_build_unroll(schedule):
    root = schedule.root()
    def mark_unroll(node):
        if type(node) is isl.schedule_node_band:
            node = node.member_set_ast_loop_unroll(0)
        return node
    root = root.map_descendant_bottom_up(mark_unroll)
    schedule = root.schedule()

    count_ast = [0]
    def inc_count_ast(node, build):
        count_ast[0] += 1
        return node

    build = isl.ast_build()
    build = build.set_at_each_domain(inc_count_ast)
    # Only the number of callback invocations matters;
    # the generated AST itself is not inspected.
    build.node_from(schedule)
    assert count_ast[0] == 30

# Test basic AST generation from a schedule tree.
#
# In particular, create a simple schedule tree and
#  - generate an AST from the schedule tree
#  - test at_each_domain
#  - test unrolling
#
def test_ast_build():
    schedule = construct_schedule_tree()

    # The counter lives in a one-element list so the callback can update it.
    count_ast = [0]
    def inc_count_ast(node, build):
        count_ast[0] += 1
        return node

    # Only the number of callback invocations is checked below;
    # the generated ASTs themselves are discarded.
    build = isl.ast_build()
    build_copy = build.set_at_each_domain(inc_count_ast)
    # The original build must not have picked up the callback.
    build.node_from(schedule)
    assert count_ast[0] == 0
    count_ast[0] = 0
    build_copy.node_from(schedule)
    assert count_ast[0] == 2
    build = build_copy
    count_ast[0] = 0
    build.node_from(schedule)
    assert count_ast[0] == 2

    # "do_fail" is read by the callback each time it is called, so
    # clearing it further down turns the failing callback into
    # a succeeding one.
    do_fail = True
    count_ast_fail = [0]
    def fail_inc_count_ast(node, build):
        count_ast_fail[0] += 1
        if do_fail:
            raise Exception("fail")
        return node
    build = isl.ast_build()
    build = build.set_at_each_domain(fail_inc_count_ast)
    # Only catch Exception so that KeyboardInterrupt/SystemExit
    # are not mistaken for the expected failure.
    caught = False
    try:
        build.node_from(schedule)
    except Exception:
        caught = True
    assert caught
    assert count_ast_fail[0] > 0
    # The build must still be usable after the failed attempt,
    # both with a different callback and with the original one.
    build_copy = build
    build_copy = build_copy.set_at_each_domain(inc_count_ast)
    count_ast[0] = 0
    build_copy.node_from(schedule)
    assert count_ast[0] == 2
    count_ast_fail[0] = 0
    do_fail = False
    build.node_from(schedule)
    assert count_ast_fail[0] == 2

    test_ast_build_unroll(schedule)

# Test basic AST expression generation from an affine expression.
#
def test_ast_build_expr():
    affine = isl.pw_aff("[n] -> { [n + 1] }")
    build = isl.ast_build.from_context(affine.domain())

    # "n + 1" is expected to come back as an addition with two arguments.
    sum_expr = build.expr_from(affine)
    assert type(sum_expr) is isl.ast_expr_op_add
    assert sum_expr.n_arg() == 2

# Run the tests of the isl Python interface.
#
# They cover:
#  - Object construction
#  - Different parameter types
#  - Different return types
#  - isl.id.user
#  - Foreach functions
#  - Foreach SCC function
#  - Every functions
#  - Spaces
#  - Schedule trees
#  - AST generation
#  - AST expression generation
#
test_constructors()
test_parameters()
test_return()
test_user()
test_foreach()
test_foreach_scc()
test_every()
test_space()
test_schedule_tree()
test_ast_build()
test_ast_build_expr()