Last active
December 27, 2025 12:01
-
-
Save sueszli/35b6296b9e40d7c11af969ce18476be3 to your computer and use it in GitHub Desktop.
canonicalization patterns
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from xdsl.builder import Builder, InsertPoint | |
| from xdsl.context import Context | |
| from xdsl.dialects import riscv | |
| from xdsl.dialects.builtin import ModuleOp | |
| from xdsl.transforms.canonicalize import CanonicalizePass | |
| from xdsl.printer import Printer | |
# Shared context for both demos; only the RISC-V dialect is needed here.
ctx = Context()
ctx.load_dialect(riscv.RISCV)


def _new_module_and_builder():
    """Return an empty ``ModuleOp`` and a ``Builder`` that appends to its body block."""
    module = ModuleOp([])
    return module, Builder(InsertPoint.at_end(module.body.blocks[0]))


def _canonicalize_and_show(context: Context, module: ModuleOp) -> None:
    """Print *module*, run ``CanonicalizePass`` on it in place, then print it again."""
    print("\nbefore:\n")
    print(module)
    CanonicalizePass().apply(context, module)
    print("\nafter:\n")
    print(module)


module, builder = _new_module_and_builder()
print("------ OriImmediate")
# a0 = 2; a0 = a0 | 1  -> expected to fold to `li a0, 3`  (0b10 | 0b01 = 0b11)
li_op_1 = builder.insert(riscv.LiOp(2, rd=riscv.Registers.A0))
ori_op = builder.insert(riscv.OriOp(li_op_1, 1, rd=riscv.Registers.A0))
# Store the result so it has a use and is not removed as dead code.
builder.insert(riscv.SwOp(ori_op.results[0], ori_op.results[0], 0))
_canonicalize_and_show(ctx, module)

print("------ XoriImmediate")
module, builder = _new_module_and_builder()
# a0 = 3; a0 = a0 ^ 1  -> expected to fold to `li a0, 2`  (0b11 ^ 0b01 = 0b10)
li_op_2 = builder.insert(riscv.LiOp(3, rd=riscv.Registers.A0))
xori_op = builder.insert(riscv.XoriOp(li_op_2, 1, rd=riscv.Registers.A0))
# Store the result so it has a use and is not removed as dead code.
builder.insert(riscv.SwOp(xori_op.results[0], xori_op.results[0], 0))
_canonicalize_and_show(ctx, module)

# Expected output of this script (kept for reference; the string is a no-op).
"""
------ OriImmediate
before:
builtin.module {
  %0 = riscv.li 2 : !riscv.reg<a0>
  %1 = riscv.ori %0, 1 : (!riscv.reg<a0>) -> !riscv.reg<a0>
  riscv.sw %1, %1, 0 : (!riscv.reg<a0>, !riscv.reg<a0>) -> ()
}
after:
builtin.module {
  %0 = riscv.li 3 : !riscv.reg<a0>
  riscv.sw %0, %0, 0 : (!riscv.reg<a0>, !riscv.reg<a0>) -> ()
}
------ XoriImmediate
before:
builtin.module {
  %0 = riscv.li 3 : !riscv.reg<a0>
  %1 = riscv.xori %0, 1 : (!riscv.reg<a0>) -> !riscv.reg<a0>
  riscv.sw %1, %1, 0 : (!riscv.reg<a0>, !riscv.reg<a0>) -> ()
}
after:
builtin.module {
  %0 = riscv.li 2 : !riscv.reg<a0>
  riscv.sw %0, %0, 0 : (!riscv.reg<a0>, !riscv.reg<a0>) -> ()
}
"""
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from xdsl.context import Context | |
| from xdsl.dialects import scf, arith, func, builtin | |
| from xdsl.dialects.builtin import ModuleOp | |
| from xdsl.ir import Block, Region | |
| from xdsl.transforms.canonicalize import CanonicalizePass | |
# Context with every dialect the demo modules use.
ctx = Context()
ctx.load_dialect(scf.Scf)
ctx.load_dialect(arith.Arith)
ctx.load_dialect(func.Func)
ctx.load_dialect(builtin.Builtin)


def _build_if_module(cond_bit: int, func_name: str) -> ModuleOp:
    """Build a module holding ``func @<func_name>() -> i32`` around a constant ``scf.if``.

    The condition is an ``i1`` constant equal to *cond_bit*. The then-branch
    yields ``1 : i32`` and the else-branch yields ``2 : i32``; the function
    returns the ``scf.if`` result.
    """
    cond = arith.ConstantOp.from_int_and_width(cond_bit, 1)
    then_val = arith.ConstantOp.from_int_and_width(1, 32)  # true branch (just return 1)
    else_val = arith.ConstantOp.from_int_and_width(2, 32)  # false branch (just return 2)
    if_op = scf.IfOp(
        cond,
        [builtin.i32],
        Region(Block([then_val, scf.YieldOp(then_val)])),
        Region(Block([else_val, scf.YieldOp(else_val)])),
    )
    body = Region(Block([cond, if_op, func.ReturnOp(if_op)]))
    return ModuleOp([func.FuncOp(func_name, ([], [builtin.i32]), body)])


def _canonicalize_and_show(context: Context, module: ModuleOp) -> None:
    """Print *module*, run ``CanonicalizePass`` on it in place, then print it again."""
    print("\nbefore:\n")
    print(module)
    CanonicalizePass().apply(context, module)
    print("\nafter:\n")
    print(module)


print(f"{'-' * 20} scf.if canonicalization (true condition)")
module = _build_if_module(1, "test_true")
_canonicalize_and_show(ctx, module)

print(f"{'-' * 20} scf.if canonicalization (false condition)")
module = _build_if_module(0, "test_false")
_canonicalize_and_show(ctx, module)

# Expected output of this script (kept for reference; the string is a no-op).
"""
-------------------- scf.if canonicalization (true condition)
before:
builtin.module {
  func.func @test_true() -> i32 {
    %0 = arith.constant true
    %1 = scf.if %0 -> (i32) {
      %2 = arith.constant 1 : i32
      scf.yield %2 : i32
    } else {
      %3 = arith.constant 2 : i32
      scf.yield %3 : i32
    }
    func.return %1 : i32
  }
}
after:
builtin.module {
  func.func @test_true() -> i32 {
    %0 = arith.constant 1 : i32
    func.return %0 : i32
  }
}
-------------------- scf.if canonicalization (false condition)
before:
builtin.module {
  func.func @test_false() -> i32 {
    %0 = arith.constant false
    %1 = scf.if %0 -> (i32) {
      %2 = arith.constant 1 : i32
      scf.yield %2 : i32
    } else {
      %3 = arith.constant 2 : i32
      scf.yield %3 : i32
    }
    func.return %1 : i32
  }
}
after:
builtin.module {
  func.func @test_false() -> i32 {
    %0 = arith.constant 2 : i32
    func.return %0 : i32
  }
}
"""
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from xdsl.context import Context | |
| from xdsl.dialects import scf, arith, func, builtin | |
| from xdsl.dialects.builtin import ModuleOp, i32, IndexType, IntegerAttr | |
| from xdsl.ir import Block, Region | |
| from xdsl.transforms.canonicalize import CanonicalizePass | |
def get_ctx():
    """Return a fresh ``Context`` with the scf, arith, func and builtin dialects loaded."""
    context = Context()
    # Load order matches the dialects' use in the tests below.
    for dialect in (scf.Scf, arith.Arith, func.Func, builtin.Builtin):
        context.load_dialect(dialect)
    return context
def get_const_int(op):
    """Return the integer payload of an integer ``arith.constant`` op, else ``None``.

    ``None`` is returned for any op that is not an ``arith.ConstantOp`` and for
    constants whose value attribute is not an ``IntegerAttr``.
    """
    if isinstance(op, arith.ConstantOp):
        attr = op.value
        if isinstance(attr, IntegerAttr):
            return attr.value.data
    return None
def test_scf_if_true():
    """A constant-true ``scf.if`` folds to its then-branch value, ``1 : i32``."""
    ctx = get_ctx()
    cond = arith.ConstantOp.from_int_and_width(1, 1)
    then_const = arith.ConstantOp.from_int_and_width(1, 32)
    else_const = arith.ConstantOp.from_int_and_width(2, 32)
    if_op = scf.IfOp(
        cond,
        [i32],
        Region(Block([then_const, scf.YieldOp(then_const)])),
        Region(Block([else_const, scf.YieldOp(else_const)])),
    )
    func_op = func.FuncOp(
        "test_true",
        ([], [i32]),
        Region(Block([cond, if_op, func.ReturnOp(if_op)])),
    )
    CanonicalizePass().apply(ctx, ModuleOp([func_op]))

    remaining = list(func_op.body.block.ops)
    assert get_const_int(remaining[0]) == 1
    assert isinstance(remaining[1], func.ReturnOp)
def test_scf_if_false():
    """A constant-false ``scf.if`` folds to its else-branch value, ``2 : i32``."""
    ctx = get_ctx()
    cond = arith.ConstantOp.from_int_and_width(0, 1)
    then_const = arith.ConstantOp.from_int_and_width(1, 32)
    else_const = arith.ConstantOp.from_int_and_width(2, 32)
    if_op = scf.IfOp(
        cond,
        [i32],
        Region(Block([then_const, scf.YieldOp(then_const)])),
        Region(Block([else_const, scf.YieldOp(else_const)])),
    )
    func_op = func.FuncOp(
        "test_false",
        ([], [i32]),
        Region(Block([cond, if_op, func.ReturnOp(if_op)])),
    )
    CanonicalizePass().apply(ctx, ModuleOp([func_op]))

    remaining = list(func_op.body.block.ops)
    assert get_const_int(remaining[0]) == 2
    assert isinstance(remaining[1], func.ReturnOp)
def test_scf_if_no_else_true():
    """A constant-true ``scf.if`` whose then-body only yields, with no else, is removed."""
    ctx = get_ctx()
    cond = arith.ConstantOp.from_int_and_width(1, 1)
    if_op = scf.IfOp(cond, [], Region(Block([scf.YieldOp()])))
    body = Region(Block([cond, if_op, func.ReturnOp()]))
    func_op = func.FuncOp("test_no_else_true", ([], []), body)
    CanonicalizePass().apply(ctx, ModuleOp([func_op]))

    # Only the terminator may survive: the condition constant becomes dead too.
    remaining = list(func_op.body.block.ops)
    assert len(remaining) == 1
    assert isinstance(remaining[0], func.ReturnOp)
def test_scf_if_no_else_false():
    """A constant-false ``scf.if`` with no else region is erased entirely.

    Mirrors ``test_scf_if_no_else_true``: after canonicalization the function
    body must contain only the ``func.return``.
    """
    ctx = get_ctx()
    false_val = arith.ConstantOp.from_int_and_width(0, 1)
    true_region = Region(Block([scf.YieldOp()]))
    if_op = scf.IfOp(false_val, [], true_region)
    func_body = Region(Block([false_val, if_op, func.ReturnOp()]))
    func_op = func.FuncOp("test_no_else_false", ([], []), func_body)
    module = ModuleOp([func_op])
    # Run the pass unguarded: previously a ValueError here was caught and the
    # test returned early, so a crashing canonicalizer passed silently.
    CanonicalizePass().apply(ctx, module)
    ops = list(func_op.body.block.ops)
    assert len(ops) == 1
    assert isinstance(ops[0], func.ReturnOp)
def test_scf_for_zero_iters():
    """An ``scf.for`` with equal bounds (0 to 0, step 1) is removed."""
    ctx = get_ctx()
    index = IndexType()
    lb, ub, step = (
        arith.ConstantOp.from_int_and_width(bound, index) for bound in (0, 0, 1)
    )
    loop_region = Region(Block(arg_types=[index]))
    loop_region.block.add_op(scf.YieldOp())
    for_op = scf.ForOp(lb, ub, step, [], loop_region)
    func_op = func.FuncOp(
        "test_for_zero_iters",
        ([], []),
        Region(Block([lb, ub, step, for_op, func.ReturnOp()])),
    )
    CanonicalizePass().apply(ctx, ModuleOp([func_op]))

    assert all(not isinstance(op, scf.ForOp) for op in func_op.body.block.ops)
def test_scf_for_negative_iters():
    """An ``scf.for`` whose lower bound exceeds its upper bound (10 to 5) is removed."""
    ctx = get_ctx()
    index = IndexType()
    lb, ub, step = (
        arith.ConstantOp.from_int_and_width(bound, index) for bound in (10, 5, 1)
    )
    loop_region = Region(Block(arg_types=[index]))
    loop_region.block.add_op(scf.YieldOp())
    for_op = scf.ForOp(lb, ub, step, [], loop_region)
    func_op = func.FuncOp(
        "test_for_negative_iters",
        ([], []),
        Region(Block([lb, ub, step, for_op, func.ReturnOp()])),
    )
    CanonicalizePass().apply(ctx, ModuleOp([func_op]))

    assert all(not isinstance(op, scf.ForOp) for op in func_op.body.block.ops)
def test_scf_for_one_iter():
    """A single-trip ``scf.for`` (0 to 1, step 1) is inlined; its yielded 42 survives."""
    ctx = get_ctx()
    lb = arith.ConstantOp.from_int_and_width(0, IndexType())
    ub = arith.ConstantOp.from_int_and_width(1, IndexType())
    step = arith.ConstantOp.from_int_and_width(1, IndexType())
    seed = arith.ConstantOp.from_int_and_width(0, 32)

    # Loop body: (iv: index, acc: i32) -> yield 42
    loop_region = Region(Block(arg_types=[IndexType(), i32]))
    answer = arith.ConstantOp.from_int_and_width(42, 32)
    for inner_op in (answer, scf.YieldOp(answer)):
        loop_region.block.add_op(inner_op)

    for_op = scf.ForOp(lb, ub, step, [seed], loop_region)
    body = Region(Block([lb, ub, step, seed, for_op, func.ReturnOp(for_op)]))
    func_op = func.FuncOp("test_for_one_iter", ([], [i32]), body)
    CanonicalizePass().apply(ctx, ModuleOp([func_op]))

    remaining = list(func_op.body.block.ops)
    assert any(get_const_int(op) == 42 for op in remaining)
    assert all(not isinstance(op, scf.ForOp) for op in remaining)
def test_scf_execute_region():
    """An ``scf.execute_region`` that just yields a constant is inlined; 123 survives."""
    ctx = get_ctx()
    region = Region(Block())
    payload = arith.ConstantOp.from_int_and_width(123, 32)
    for inner_op in (payload, scf.YieldOp(payload)):
        region.block.add_op(inner_op)

    exec_op = scf.ExecuteRegionOp([i32], region)
    func_op = func.FuncOp(
        "test_execute_region",
        ([], [i32]),
        Region(Block([exec_op, func.ReturnOp(exec_op)])),
    )
    CanonicalizePass().apply(ctx, ModuleOp([func_op]))

    remaining = list(func_op.body.block.ops)
    assert any(get_const_int(op) == 123 for op in remaining)
    assert all(not isinstance(op, scf.ExecuteRegionOp) for op in remaining)
def test_scf_for_rehoist():
    """A constant defined inside an ``scf.for`` body ends up before the loop."""
    ctx = get_ctx()
    lb = arith.ConstantOp.from_int_and_width(0, IndexType())
    ub = arith.ConstantOp.from_int_and_width(10, IndexType())
    step = arith.ConstantOp.from_int_and_width(1, IndexType())
    seed = arith.ConstantOp.from_int_and_width(0, 32)

    # Loop body: (iv: index, acc: i32) -> yield 7
    loop_region = Region(Block(arg_types=[IndexType(), i32]))
    seven = arith.ConstantOp.from_int_and_width(7, 32)
    for inner_op in (seven, scf.YieldOp(seven)):
        loop_region.block.add_op(inner_op)

    for_op = scf.ForOp(lb, ub, step, [seed], loop_region)
    body = Region(Block([lb, ub, step, seed, for_op, func.ReturnOp(for_op)]))
    func_op = func.FuncOp("test_for_rehoist", ([], [i32]), body)
    CanonicalizePass().apply(ctx, ModuleOp([func_op]))

    remaining = list(func_op.body.block.ops)
    # Positions of every matching op; the last one of each kind is compared.
    loop_positions = [pos for pos, op in enumerate(remaining) if isinstance(op, scf.ForOp)]
    seven_positions = [pos for pos, op in enumerate(remaining) if get_const_int(op) == 7]
    assert seven_positions
    assert loop_positions
    assert seven_positions[-1] < loop_positions[-1]
def test_scf_if_unused():
    """A result-less ``scf.if`` on a non-constant condition is removed when unused."""
    ctx = get_ctx()
    # Take the condition from a function argument so constant-condition folding
    # cannot be what removes the scf.if.
    entry = Block(arg_types=[builtin.i1])
    func_op = func.FuncOp("test_unused", ([builtin.i1], []), Region(entry))
    cond = entry.args[0]

    if_op = scf.IfOp(
        cond,
        [],
        Region(Block([scf.YieldOp()])),
        Region(Block([scf.YieldOp()])),
    )
    for top_op in (if_op, func.ReturnOp()):
        entry.add_op(top_op)
    CanonicalizePass().apply(ctx, ModuleOp([func_op]))

    # if_op should be removed because it is Pure and its results are unused
    assert all(not isinstance(op, scf.IfOp) for op in func_op.body.block.ops)
if __name__ == "__main__":
    # Run every test in declaration order; any failing assert aborts the run.
    for test_fn in (
        test_scf_if_true,
        test_scf_if_false,
        test_scf_if_no_else_true,
        test_scf_if_no_else_false,
        test_scf_for_zero_iters,
        test_scf_for_negative_iters,
        test_scf_for_one_iter,
        test_scf_execute_region,
        test_scf_for_rehoist,
        test_scf_if_unused,
    ):
        test_fn()
    print("all tests passed!")
Author
uv run xdsl-opt --split-input-file -p canonicalize tests/filecheck/dialects/scf/canonicalize.mlir | uv run filecheck tests/filecheck/dialects/scf/canonicalize.mlir && echo "✓ all good yo"
uv run lit -v tests/filecheck/dialects/scf/canonicalize.mlir
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Uh oh!
There was an error while loading. Please reload this page.