Last active
February 11, 2026 15:28
-
-
Save PragmaTwice/ced5fb2b50223e8f913f236b05274333 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # RUN: %PYTHON %s 2>&1 | FileCheck %s | |
| # REQUIRES: host-supports-jit | |
| from mlir.ir import * | |
| from mlir.dialects.ext import * | |
| from mlir.rewrite import * | |
| from mlir.passmanager import * | |
| from mlir.execution_engine import * | |
| from mlir.dialects import llvm, scf, func | |
# Dialect "bf" holding the Brainfuck program representation built by parse().
# Every op class below derives from BfDialect.Operation; the dialect is
# registered at runtime via BfDialect.load() in the __main__ block.
class BfDialect(Dialect, name="bf"):
    pass
# Brainfuck ">": takes the current cell pointer and produces the pointer to
# the next cell. convert_bf_to_llvm lowers it to a GEP of +1 over i8.
class NextOp(BfDialect.Operation, name="next"):
    in_: Operand[llvm.PointerType]  # current cell pointer
    out: Result[llvm.PointerType[()]]  # moved pointer, consumed by later ops
# Brainfuck "<": takes the current cell pointer and produces the pointer to
# the previous cell. convert_bf_to_llvm lowers it to a GEP of -1 over i8.
class PrevOp(BfDialect.Operation, name="prev"):
    in_: Operand[llvm.PointerType]  # current cell pointer
    out: Result[llvm.PointerType[()]]  # moved pointer, consumed by later ops
# Brainfuck "+": increments the i8 cell that in_ points to. It has no result;
# convert_bf_to_llvm lowers it to load / add 1 / store.
class IncOp(BfDialect.Operation, name="inc"):
    in_: Operand[llvm.PointerType]  # pointer to the cell to modify
# Brainfuck "-": decrements the i8 cell that in_ points to. It has no result;
# convert_bf_to_llvm lowers it to load / add -1 / store.
class DecOp(BfDialect.Operation, name="dec"):
    in_: Operand[llvm.PointerType]  # pointer to the cell to modify
# Brainfuck ",": reads one byte and stores it into the cell that in_ points
# to. convert_bf_to_llvm lowers it to a call of bf_input (a getchar wrapper).
class InputOp(BfDialect.Operation, name="input"):
    in_: Operand[llvm.PointerType]  # pointer to the destination cell
# Brainfuck ".": writes the byte in the cell that in_ points to.
# convert_bf_to_llvm lowers it to a load plus a call of bf_output (a
# putchar wrapper).
class OutputOp(BfDialect.Operation, name="output"):
    in_: Operand[llvm.PointerType]  # pointer to the cell to print
# Brainfuck "[ ... ]". parse() gives the body region a single block whose one
# argument is the cell pointer on entry; the block ends with a bf.yield of the
# pointer at the end of the body. `out` is the pointer after the loop.
# convert_bf_to_llvm lowers it to an scf.while that re-enters the body while
# the current cell is non-zero.
class WhileOp(BfDialect.Operation, name="while"):
    in_: Operand[llvm.PointerType]  # cell pointer before the loop
    out: Result[llvm.PointerType[()]]  # cell pointer after the loop
    body: Region  # loop body (one block, ptr argument, bf.yield terminator)
# Terminator that hands the current cell pointer back to the enclosing
# bf.while body or bf.main. convert_bf_to_llvm lowers it to scf.yield inside
# a loop and to func.return otherwise.
class YieldOp(BfDialect.Operation, name="yield", traits=[IsTerminatorTrait]):
    in_: Operand[llvm.PointerType]  # pointer value being yielded/returned
# Top-level container for a whole Brainfuck program. parse() gives its region
# one block taking the tape pointer as an argument and ending with bf.yield.
# convert_bf_to_llvm turns it into func.func @bf_main(ptr) -> ptr.
class MainOp(BfDialect.Operation, name="main"):
    body: Region  # program body (one block, ptr argument, bf.yield terminator)
def _check_brackets(code: str) -> None:
    """Raise ValueError unless every '[' in ``code`` has a matching ']'."""
    depth = 0
    for pos, ch in enumerate(code):
        if ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
            if depth < 0:
                raise ValueError(f"unmatched ']' at position {pos}")
    if depth > 0:
        raise ValueError(f"{depth} unclosed '[' in program")


def parse(code: str):
    """Build a module containing one ``bf.main`` op from Brainfuck source.

    Each of the eight commands ``> < + - . , [ ]`` becomes a bf op. Every other
    character is ignored (treated as a comment). The cell pointer is threaded
    through the IR as an SSA value. ``>``/``<`` and each ``[...]`` loop
    produce a new pointer value that later commands consume. Each region
    (``bf.main`` and every ``bf.while`` body) ends with a ``bf.yield`` of the
    pointer at that point.

    Expects an active MLIR ``Context`` with a default ``Location`` (the
    ``__main__`` block sets both up).

    Args:
        code: Brainfuck program text.

    Returns:
        The newly created ``Module``.

    Raises:
        ValueError: if ``code`` contains a ``]`` without a matching ``[`` or a
            ``[`` that is never closed.
    """
    # Validate first. Otherwise an unmatched "]" fails deep in the builder
    # (the enclosing op is bf.main, which has no `out`), and an unclosed "["
    # makes the final bf.yield land in the loop body, leaving bf.main
    # without a terminator.
    _check_brackets(code)

    ptr_type = llvm.PointerType.get()
    module = Module.create()
    with InsertionPoint(module.body):
        main = MainOp()
        main.body.blocks.append()
        current_val = main.body.blocks[0].add_argument(ptr_type, Location.unknown())
    ip = InsertionPoint(main.body.blocks[0])
    for c in code:
        with ip:
            if c == ">":
                current_val = NextOp(current_val).out
            elif c == "<":
                current_val = PrevOp(current_val).out
            elif c == "+":
                IncOp(current_val)
            elif c == "-":
                DecOp(current_val)
            elif c == ".":
                OutputOp(current_val)
            elif c == ",":
                InputOp(current_val)
            elif c == "[":
                loop = WhileOp(current_val)
                loop.body.blocks.append()
                # Inside the body the pointer is the block argument.
                current_val = loop.body.blocks[0].add_argument(
                    ptr_type, Location.unknown()
                )
                # Subsequent ops go into the loop body until the matching "]".
                ip = InsertionPoint(loop.body.blocks[0])
            elif c == "]":
                YieldOp(current_val)
                # Resume right after the loop, using its result as the pointer.
                current_val = ip.block.owner.opview.out
                ip = InsertionPoint.after(current_val.owner)
    with ip:
        YieldOp(current_val)
    return module
def convert_bf_to_llvm(op, pass_, *, tape_size=1024):
    """Lower the bf dialect inside ``op`` and append a small runtime.

    Used as a Python pass (``pm.add(convert_bf_to_llvm)``). The bf ops are
    rewritten greedily:

    * ``bf.next``/``bf.prev`` -> ``llvm.getelementptr`` by +1/-1 over i8;
    * ``bf.inc``/``bf.dec``   -> load, add +1/-1 as i8, store;
    * ``bf.output``/``bf.input`` -> calls to ``bf_output``/``bf_input``;
    * ``bf.while`` -> ``scf.while`` that loops while the current cell != 0;
    * ``bf.main``  -> ``func.func @bf_main(ptr) -> ptr``;
    * ``bf.yield`` -> ``scf.yield`` inside a loop, ``func.return`` otherwise.

    The pass then adds private ``putchar``/``getchar`` declarations, the
    ``bf_output``/``bf_input`` wrappers around them, and ``bf_init``.
    ``bf_init`` allocates a zero-filled tape of ``tape_size`` bytes with
    ``llvm.alloca`` and calls ``bf_main`` on it. ``bf_init`` carries
    ``llvm.emit_c_interface`` so it can be looked up from the execution
    engine.

    Args:
        op: the module operation the pass runs on.
        pass_: the pass object supplied by the pass manager (unused).
        tape_size: number of i8 cells in the tape (keyword-only, default
            1024). The generated code does not bounds-check the pointer.

    Raises:
        ValueError: if ``tape_size`` is not in ``[1, 2**31 - 1]``. The size is
            emitted as an i32 constant.
    """
    if not 0 < tape_size < 2**31:
        raise ValueError(f"tape_size must be in [1, 2**31 - 1], got {tape_size!r}")

    patterns = RewritePatternSet()
    ptr = llvm.PointerType.get()
    i8 = IntegerType.get_signless(8)
    i32 = IntegerType.get_signless(32)

    # bf.next / bf.prev: step the cell pointer by `offset` bytes.
    def convert_next(op, rewriter, offset=1):
        with rewriter.ip:
            gep = llvm.GEPOp(ptr, op.in_, [], [offset], i8, [])
        rewriter.replace_op(op, gep)

    # bf.inc / bf.dec: *in_ += cst. No overflow flags are set, so the i8
    # addition wraps.
    def convert_inc(op, rewriter, cst=1):
        with rewriter.ip:
            loaded = llvm.load(i8, op.in_)
            delta = llvm.mlir_constant(IntegerAttr.get(i8, cst))
            added = llvm.add(loaded, delta, [])
            store = llvm.StoreOp(added, op.in_)
        rewriter.replace_op(op, store)

    # bf.main: a func.func @bf_main takes over bf.main's body block.
    def convert_main(op, rewriter):
        with rewriter.ip:
            fn = func.FuncOp("bf_main", FunctionType.get([ptr], [ptr]))
            op.body.blocks[0].append_to(fn.body)
        rewriter.replace_op(op, fn)

    # bf.yield terminates either a loop body or the program body. Depending on
    # the order in which the greedy driver visits ops, the enclosing loop may
    # already have become scf.while, so both loop forms are accepted.
    def convert_yield(op, rewriter):
        parent = op.parent.opview
        with rewriter.ip:
            if isinstance(parent, (WhileOp, scf.WhileOp)):
                terminator = scf.YieldOp([op.in_])
            else:
                terminator = func.ReturnOp([op.in_])
        rewriter.replace_op(op, terminator)

    # bf.while: the scf.while "before" region loads the current cell and
    # continues while it is non-zero; bf.while's body block becomes "after".
    def convert_while(op, rewriter):
        with rewriter.ip:
            loop = scf.WhileOp([ptr], [op.in_])
            loop.before.blocks.append()
            arg = loop.before.blocks[0].add_argument(ptr, Location.unknown())
            with InsertionPoint(loop.before.blocks[0]):
                cell = llvm.load(i8, arg)
                zero = llvm.mlir_constant(IntegerAttr.get(i8, 0))
                cond = llvm.icmp(llvm.ICmpPredicate.ne, cell, zero)
                scf.ConditionOp(cond, [arg])
            op.body.blocks[0].append_to(loop.after)
        rewriter.replace_op(op, loop)

    def convert_output(op, rewriter):
        with rewriter.ip:
            cell = llvm.load(i8, op.in_)
            call = func.CallOp([], "bf_output", [cell])
        rewriter.replace_op(op, call)

    def convert_input(op, rewriter):
        with rewriter.ip:
            byte = func.call([i8], "bf_input", [])
            store = llvm.StoreOp(byte, op.in_)
        rewriter.replace_op(op, store)

    patterns.add(NextOp, convert_next)
    patterns.add(PrevOp, lambda op, rewriter: convert_next(op, rewriter, offset=-1))
    patterns.add(IncOp, convert_inc)
    patterns.add(DecOp, lambda op, rewriter: convert_inc(op, rewriter, cst=-1))
    patterns.add(MainOp, convert_main)
    patterns.add(YieldOp, convert_yield)
    patterns.add(WhileOp, convert_while)
    patterns.add(OutputOp, convert_output)
    patterns.add(InputOp, convert_input)
    apply_patterns_and_fold_greedily(op, patterns.freeze())

    with InsertionPoint(op.opview.body):
        func.FuncOp("putchar", FunctionType.get([i32], [i32]), visibility="private")
        func.FuncOp("getchar", FunctionType.get([], [i32]), visibility="private")

        # bf_output(i8): sign-extend the byte to i32 and pass it to putchar.
        output_fn = func.FuncOp("bf_output", FunctionType.get([i8], []))
        output_fn.body.blocks.append()
        out_byte = output_fn.body.blocks[0].add_argument(i8, Location.unknown())
        with InsertionPoint(output_fn.body.blocks[0]):
            widened = llvm.sext(i32, out_byte)
            func.call([i32], "putchar", [widened])
            func.ReturnOp([])

        # bf_input() -> i8: call getchar and keep the low byte.
        input_fn = func.FuncOp("bf_input", FunctionType.get([], [i8]))
        input_fn.body.blocks.append()
        with InsertionPoint(input_fn.body.blocks[0]):
            ch = func.call([i32], "getchar", [])
            in_byte = llvm.trunc(i8, ch, [])
            func.ReturnOp([in_byte])

        # bf_init(): zero-filled tape of tape_size bytes, then run bf_main.
        init = func.FuncOp("bf_init", FunctionType.get([], []))
        init.attributes["llvm.emit_c_interface"] = UnitAttr.get()
        init.body.blocks.append()
        with InsertionPoint(init.body.blocks[0]):
            size = llvm.mlir_constant(IntegerAttr.get(i32, tape_size))
            zero = llvm.mlir_constant(IntegerAttr.get(i8, 0))
            tape = llvm.alloca(ptr, size, i8)
            llvm.intr_memset(tape, zero, size, False)
            func.call([ptr], "bf_main", [tape])
            func.ReturnOp([])
def execute(code):
    """Compile Brainfuck ``code`` with MLIR and run it in-process via the JIT.

    Parses ``code`` into the bf dialect and verifies the result. It then
    lowers the module with ``convert_bf_to_llvm`` followed by
    ``convert-scf-to-cf, convert-to-llvm``, JIT-compiles it, and calls
    ``bf_init``. The program writes through C ``putchar``, i.e. the C-level
    stdout of the process, not Python's ``sys.stdout``.

    Expects an active MLIR ``Context`` with the bf dialect loaded (the
    ``__main__`` block sets this up).

    Raises:
        RuntimeError: if ``verify()`` reports the parsed module as invalid.
            NOTE(review): recent bindings may raise ``MLIRError`` from
            ``verify()`` itself; confirm for the pinned MLIR version.
    """
    module = parse(code)
    # Explicit check instead of `assert`, which `python -O` strips out.
    if not module.operation.verify():
        raise RuntimeError("bf module failed verification")
    pm = PassManager()
    pm.add(convert_bf_to_llvm)
    pm.add("convert-scf-to-cf, convert-to-llvm")
    pm.run(module.operation)
    ee = ExecutionEngine(module)
    # bf_init takes no arguments; invoke() handles the packed-argument call.
    ee.invoke("bf_init")
if __name__ == "__main__":
    # Classic Brainfuck "Hello World!" program; the CHECK line below is the
    # FileCheck expectation for this test's output.
    hello_world = "++++++++[>++++[>++>+++>+++>+<<<<-]>+>+>->>+[<]<-]>>.>---.+++++++..+++.>>.<-.<.+++.------.--------.>>+.>++."
    with Context(), Location.unknown():
        BfDialect.load()
        # CHECK: Hello World!
        execute(hello_world)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment