Skip to content

Instantly share code, notes, and snippets.

@sueszli
Last active December 27, 2025 12:01
Show Gist options
  • Select an option

  • Save sueszli/35b6296b9e40d7c11af969ce18476be3 to your computer and use it in GitHub Desktop.

Select an option

Save sueszli/35b6296b9e40d7c11af969ce18476be3 to your computer and use it in GitHub Desktop.
canonicalization patterns
from xdsl.builder import Builder, InsertPoint
from xdsl.context import Context
from xdsl.dialects import riscv
from xdsl.dialects.builtin import ModuleOp
from xdsl.transforms.canonicalize import CanonicalizePass
from xdsl.printer import Printer
ctx = Context()
ctx.load_dialect(riscv.RISCV)


def _print_canonicalized(context, mod):
    """Print `mod`, run CanonicalizePass on it in place, then print it again."""
    print("\nbefore:\n")
    print(mod)
    CanonicalizePass().apply(context, mod)
    print("\nafter:\n")
    print(mod)


# OriImmediate: li a0, 2 ; ori a0, a0, 1  ==>  li a0, 3   (0b10 | 0b01 = 0b11)
module = ModuleOp([])
builder = Builder(InsertPoint.at_end(module.body.blocks[0]))
print("------ OriImmediate")
li_op_1 = builder.insert(riscv.LiOp(2, rd=riscv.Registers.A0))
ori_op = builder.insert(riscv.OriOp(li_op_1, 1, rd=riscv.Registers.A0))
# Store the result to itself so it has a use and is not removed as dead code.
ori_result = ori_op.results[0]
builder.insert(riscv.SwOp(ori_result, ori_result, 0))
_print_canonicalized(ctx, module)

# XoriImmediate: li a0, 3 ; xori a0, a0, 1  ==>  li a0, 2   (0b11 ^ 0b01 = 0b10)
print("------ XoriImmediate")
module = ModuleOp([])
builder = Builder(InsertPoint.at_end(module.body.blocks[0]))
li_op_2 = builder.insert(riscv.LiOp(3, rd=riscv.Registers.A0))
xori_op = builder.insert(riscv.XoriOp(li_op_2, 1, rd=riscv.Registers.A0))
xori_result = xori_op.results[0]
builder.insert(riscv.SwOp(xori_result, xori_result, 0))
_print_canonicalized(ctx, module)
"""
------ OriImmediate
before:
builtin.module {
%0 = riscv.li 2 : !riscv.reg<a0>
%1 = riscv.ori %0, 1 : (!riscv.reg<a0>) -> !riscv.reg<a0>
riscv.sw %1, %1, 0 : (!riscv.reg<a0>, !riscv.reg<a0>) -> ()
}
after:
builtin.module {
%0 = riscv.li 3 : !riscv.reg<a0>
riscv.sw %0, %0, 0 : (!riscv.reg<a0>, !riscv.reg<a0>) -> ()
}
------ XoriImmediate
before:
builtin.module {
%0 = riscv.li 3 : !riscv.reg<a0>
%1 = riscv.xori %0, 1 : (!riscv.reg<a0>) -> !riscv.reg<a0>
riscv.sw %1, %1, 0 : (!riscv.reg<a0>, !riscv.reg<a0>) -> ()
}
after:
builtin.module {
%0 = riscv.li 2 : !riscv.reg<a0>
riscv.sw %0, %0, 0 : (!riscv.reg<a0>, !riscv.reg<a0>) -> ()
}
"""
from xdsl.context import Context
from xdsl.dialects import scf, arith, func, builtin
from xdsl.dialects.builtin import ModuleOp
from xdsl.ir import Block, Region
from xdsl.transforms.canonicalize import CanonicalizePass
ctx = Context()
for _dialect in (scf.Scf, arith.Arith, func.Func, builtin.Builtin):
    ctx.load_dialect(_dialect)


def _if_demo_module(cond_bit, func_name):
    """Build a module whose single function returns `scf.if <cond> {1} else {2}`.

    `cond_bit` (0 or 1) becomes an i1 `arith.constant` used as the condition;
    `func_name` names the `func.func`, which takes no arguments and returns i32.
    """
    cond = arith.ConstantOp.from_int_and_width(cond_bit, 1)
    then_val = arith.ConstantOp.from_int_and_width(1, 32)  # then-branch yields 1
    else_val = arith.ConstantOp.from_int_and_width(2, 32)  # else-branch yields 2
    if_op = scf.IfOp(
        cond,
        [builtin.i32],
        Region(Block([then_val, scf.YieldOp(then_val)])),
        Region(Block([else_val, scf.YieldOp(else_val)])),
    )
    body = Region(Block([cond, if_op, func.ReturnOp(if_op)]))
    return ModuleOp([func.FuncOp(func_name, ([], [builtin.i32]), body)])


for label, cond_bit, func_name in (
    ("true condition", 1, "test_true"),
    ("false condition", 0, "test_false"),
):
    print(f"{'-' * 20} scf.if canonicalization ({label})")
    module = _if_demo_module(cond_bit, func_name)
    print("\nbefore:\n")
    print(module)
    CanonicalizePass().apply(ctx, module)
    print("\nafter:\n")
    print(module)
"""
-------------------- scf.if canonicalization (true condition)
before:
builtin.module {
func.func @test_true() -> i32 {
%0 = arith.constant true
%1 = scf.if %0 -> (i32) {
%2 = arith.constant 1 : i32
scf.yield %2 : i32
} else {
%3 = arith.constant 2 : i32
scf.yield %3 : i32
}
func.return %1 : i32
}
}
after:
builtin.module {
func.func @test_true() -> i32 {
%0 = arith.constant 1 : i32
func.return %0 : i32
}
}
-------------------- scf.if canonicalization (false condition)
before:
builtin.module {
func.func @test_false() -> i32 {
%0 = arith.constant false
%1 = scf.if %0 -> (i32) {
%2 = arith.constant 1 : i32
scf.yield %2 : i32
} else {
%3 = arith.constant 2 : i32
scf.yield %3 : i32
}
func.return %1 : i32
}
}
after:
builtin.module {
func.func @test_false() -> i32 {
%0 = arith.constant 2 : i32
func.return %0 : i32
}
}
"""
from xdsl.context import Context
from xdsl.dialects import scf, arith, func, builtin
from xdsl.dialects.builtin import ModuleOp, i32, IndexType, IntegerAttr
from xdsl.ir import Block, Region
from xdsl.transforms.canonicalize import CanonicalizePass
def get_ctx():
    """Return a fresh Context with the scf, arith, func and builtin dialects loaded."""
    ctx = Context()
    for dialect in (scf.Scf, arith.Arith, func.Func, builtin.Builtin):
        ctx.load_dialect(dialect)
    return ctx
def get_const_int(op):
    """Return the Python int held by an integer `arith.constant` op.

    Returns None when `op` is not an `arith.ConstantOp`, or when its value
    attribute is not an `IntegerAttr`.
    """
    if isinstance(op, arith.ConstantOp):
        attr = op.value
        if isinstance(attr, IntegerAttr):
            return attr.value.data
    return None
def test_scf_if_true():
    """A constant-true scf.if folds to the value of its then-branch.

    Builds ``func @test_true() -> i32`` returning ``scf.if true {1} else {2}``.
    After canonicalization the body must be exactly the i32 constant 1
    followed by a return of that constant.
    """
    ctx = get_ctx()
    true_val = arith.ConstantOp.from_int_and_width(1, 1)
    val1 = arith.ConstantOp.from_int_and_width(1, 32)
    true_region = Region(Block([val1, scf.YieldOp(val1)]))
    val2 = arith.ConstantOp.from_int_and_width(2, 32)
    false_region = Region(Block([val2, scf.YieldOp(val2)]))
    if_op = scf.IfOp(true_val, [i32], true_region, false_region)
    func_body = Region(Block([true_val, if_op, func.ReturnOp(if_op)]))
    func_op = func.FuncOp("test_true", ([], [i32]), func_body)
    module = ModuleOp([func_op])
    CanonicalizePass().apply(ctx, module)
    ops = list(func_op.body.block.ops)
    # Exact shape: the i1 condition and the scf.if must both be gone, leaving
    # only the folded constant and the return.
    assert len(ops) == 2, ops
    const, ret = ops
    assert get_const_int(const) == 1
    assert isinstance(ret, func.ReturnOp)
    # The function must return the folded constant itself.
    assert ret.operands[0] is const.results[0]
def test_scf_if_false():
    """A constant-false scf.if folds to the value of its else-branch.

    Builds ``func @test_false() -> i32`` returning ``scf.if false {1} else {2}``.
    After canonicalization the body must be exactly the i32 constant 2
    followed by a return of that constant.
    """
    ctx = get_ctx()
    false_val = arith.ConstantOp.from_int_and_width(0, 1)
    val1 = arith.ConstantOp.from_int_and_width(1, 32)
    true_region = Region(Block([val1, scf.YieldOp(val1)]))
    val2 = arith.ConstantOp.from_int_and_width(2, 32)
    false_region = Region(Block([val2, scf.YieldOp(val2)]))
    if_op = scf.IfOp(false_val, [i32], true_region, false_region)
    func_body = Region(Block([false_val, if_op, func.ReturnOp(if_op)]))
    func_op = func.FuncOp("test_false", ([], [i32]), func_body)
    module = ModuleOp([func_op])
    CanonicalizePass().apply(ctx, module)
    ops = list(func_op.body.block.ops)
    # Exact shape: the i1 condition and the scf.if must both be gone, leaving
    # only the folded constant and the return.
    assert len(ops) == 2, ops
    const, ret = ops
    assert get_const_int(const) == 2
    assert isinstance(ret, func.ReturnOp)
    # The function must return the folded constant itself.
    assert ret.operands[0] is const.results[0]
def test_scf_if_no_else_true():
    """A constant-true, result-less scf.if with no else region disappears.

    Its then-region only yields, so after canonicalization the function body
    should contain nothing but the return.
    """
    ctx = get_ctx()
    cond = arith.ConstantOp.from_int_and_width(1, 1)
    if_op = scf.IfOp(cond, [], Region(Block([scf.YieldOp()])))
    body = Region(Block([cond, if_op, func.ReturnOp()]))
    func_op = func.FuncOp("test_no_else_true", ([], []), body)
    CanonicalizePass().apply(ctx, ModuleOp([func_op]))
    remaining = list(func_op.body.block.ops)
    assert len(remaining) == 1
    assert isinstance(remaining[0], func.ReturnOp)
def test_scf_if_no_else_false():
    """A constant-false, result-less scf.if with no else region disappears.

    If the pass raises ValueError on this input, the error is printed and
    the test returns without checking the body.
    """
    ctx = get_ctx()
    cond = arith.ConstantOp.from_int_and_width(0, 1)
    if_op = scf.IfOp(cond, [], Region(Block([scf.YieldOp()])))
    body = Region(Block([cond, if_op, func.ReturnOp()]))
    func_op = func.FuncOp("test_no_else_false", ([], []), body)
    module = ModuleOp([func_op])
    try:
        CanonicalizePass().apply(ctx, module)
    except ValueError as e:
        # NOTE(review): tolerating ValueError means this test also passes when
        # the pass fails on this input; confirm that is still intended.
        print(f"Caught expected ValueError in test_scf_if_no_else_false: {e}")
    else:
        remaining = list(func_op.body.block.ops)
        assert len(remaining) == 1
        assert isinstance(remaining[0], func.ReturnOp)
def test_scf_for_zero_iters():
    """An scf.for with equal bounds (0 to 0, step 1) is removed entirely."""
    ctx = get_ctx()
    index = IndexType()
    lb, ub, step = (arith.ConstantOp.from_int_and_width(v, index) for v in (0, 0, 1))
    body = Region(Block([scf.YieldOp()], arg_types=[index]))
    for_op = scf.ForOp(lb, ub, step, [], body)
    entry = Block([lb, ub, step, for_op, func.ReturnOp()])
    func_op = func.FuncOp("test_for_zero_iters", ([], []), Region(entry))
    module = ModuleOp([func_op])
    CanonicalizePass().apply(ctx, module)
    assert all(not isinstance(op, scf.ForOp) for op in func_op.body.block.ops)
def test_scf_for_negative_iters():
    """An scf.for whose upper bound is below its lower bound (10 to 5) is removed."""
    ctx = get_ctx()
    index = IndexType()
    lb, ub, step = (arith.ConstantOp.from_int_and_width(v, index) for v in (10, 5, 1))
    body = Region(Block([scf.YieldOp()], arg_types=[index]))
    for_op = scf.ForOp(lb, ub, step, [], body)
    entry = Block([lb, ub, step, for_op, func.ReturnOp()])
    func_op = func.FuncOp("test_for_negative_iters", ([], []), Region(entry))
    module = ModuleOp([func_op])
    CanonicalizePass().apply(ctx, module)
    assert all(not isinstance(op, scf.ForOp) for op in func_op.body.block.ops)
def test_scf_for_one_iter():
    """A single-trip scf.for (0 to 1, step 1) is replaced by its body.

    The loop carries one i32 value, initialized to 0, and yields the constant
    42. Afterwards no scf.for remains and a constant 42 is present in the
    function body.
    """
    ctx = get_ctx()
    index = IndexType()
    lb = arith.ConstantOp.from_int_and_width(0, index)
    ub = arith.ConstantOp.from_int_and_width(1, index)
    step = arith.ConstantOp.from_int_and_width(1, index)
    init_val = arith.ConstantOp.from_int_and_width(0, 32)
    forty_two = arith.ConstantOp.from_int_and_width(42, 32)
    body = Region(Block([forty_two, scf.YieldOp(forty_two)], arg_types=[index, i32]))
    for_op = scf.ForOp(lb, ub, step, [init_val], body)
    entry = Block([lb, ub, step, init_val, for_op, func.ReturnOp(for_op)])
    func_op = func.FuncOp("test_for_one_iter", ([], [i32]), Region(entry))
    module = ModuleOp([func_op])
    CanonicalizePass().apply(ctx, module)
    remaining = list(func_op.body.block.ops)
    assert any(get_const_int(op) == 42 for op in remaining)
    assert not any(isinstance(op, scf.ForOp) for op in remaining)
def test_scf_execute_region():
    """An scf.execute_region that just yields a constant is inlined away.

    Afterwards the constant 123 sits directly in the function body and no
    scf.execute_region remains.
    """
    ctx = get_ctx()
    const_123 = arith.ConstantOp.from_int_and_width(123, 32)
    exec_op = scf.ExecuteRegionOp([i32], Region(Block([const_123, scf.YieldOp(const_123)])))
    entry = Block([exec_op, func.ReturnOp(exec_op)])
    func_op = func.FuncOp("test_execute_region", ([], [i32]), Region(entry))
    module = ModuleOp([func_op])
    CanonicalizePass().apply(ctx, module)
    remaining = list(func_op.body.block.ops)
    assert any(get_const_int(op) == 123 for op in remaining)
    assert not any(isinstance(op, scf.ExecuteRegionOp) for op in remaining)
def test_scf_for_rehoist():
    """A constant defined inside an scf.for body ends up before the loop.

    The loop runs 0 to 10 (step 1), carries one i32 value, and yields the
    constant 7. After canonicalization an scf.for must still exist, and a
    constant 7 must appear earlier in the function body than that loop.
    """
    ctx = get_ctx()
    index = IndexType()
    lb = arith.ConstantOp.from_int_and_width(0, index)
    ub = arith.ConstantOp.from_int_and_width(10, index)
    step = arith.ConstantOp.from_int_and_width(1, index)
    init_val = arith.ConstantOp.from_int_and_width(0, 32)
    seven = arith.ConstantOp.from_int_and_width(7, 32)
    body = Region(Block([seven, scf.YieldOp(seven)], arg_types=[index, i32]))
    for_op = scf.ForOp(lb, ub, step, [init_val], body)
    entry = Block([lb, ub, step, init_val, for_op, func.ReturnOp(for_op)])
    func_op = func.FuncOp("test_for_rehoist", ([], [i32]), Region(entry))
    module = ModuleOp([func_op])
    CanonicalizePass().apply(ctx, module)
    ops = list(func_op.body.block.ops)
    # Position of the LAST matching op in the body, or -1 when none matches.
    for_op_idx = max((i for i, op in enumerate(ops) if isinstance(op, scf.ForOp)), default=-1)
    const_7_idx = max((i for i, op in enumerate(ops) if get_const_int(op) == 7), default=-1)
    assert const_7_idx != -1
    assert for_op_idx != -1
    assert const_7_idx < for_op_idx
def test_scf_if_unused():
    """A result-less scf.if on a non-constant condition is deleted.

    The condition is the function's i1 block argument, which keeps
    IfPropagateConstantCondition from applying. Both regions only yield.
    """
    ctx = get_ctx()
    entry = Block(arg_types=[builtin.i1])
    cond = entry.args[0]
    if_op = scf.IfOp(cond, [], Region(Block([scf.YieldOp()])), Region(Block([scf.YieldOp()])))
    entry.add_op(if_op)
    entry.add_op(func.ReturnOp())
    func_op = func.FuncOp("test_unused", ([builtin.i1], []), Region(entry))
    module = ModuleOp([func_op])
    CanonicalizePass().apply(ctx, module)
    # NOTE(review): the original reasoning was that scf.if is Pure and unused.
    # Here it has no results and its regions only yield; confirm which
    # pattern or trait actually performs the erase.
    assert not any(isinstance(op, scf.IfOp) for op in func_op.body.block.ops)
if __name__ == "__main__":
    # Run every test in the order below; any failing assert aborts the run.
    for _test in (
        test_scf_if_true,
        test_scf_if_false,
        test_scf_if_no_else_true,
        test_scf_if_no_else_false,
        test_scf_for_zero_iters,
        test_scf_for_negative_iters,
        test_scf_for_one_iter,
        test_scf_execute_region,
        test_scf_for_rehoist,
        test_scf_if_unused,
    ):
        _test()
    print("all tests passed!")
@sueszli
Copy link
Author

sueszli commented Dec 26, 2025

uv run xdsl-opt --split-input-file -p canonicalize tests/filecheck/backend/riscv/canonicalize.mlir | uv run filecheck tests/filecheck/backend/riscv/canonicalize.mlir && echo "✓ all good yo"
uv run lit -v ./tests/filecheck/dialects/riscv_cf/canonicalize.mlir          

@sueszli
Copy link
Author

sueszli commented Dec 26, 2025

uv run xdsl-opt --split-input-file -p canonicalize tests/filecheck/dialects/scf/canonicalize.mlir | uv run filecheck tests/filecheck/dialects/scf/canonicalize.mlir && echo "✓ all good yo"
uv run lit -v tests/filecheck/dialects/scf/canonicalize.mlir

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment