Created
December 16, 2025 18:14
-
-
Save sklam/f141efe69bf86182ed185274b6648e93 to your computer and use it in GitHub Desktop.
egglog-python example: (ab)use PyObject to allow equivalences to be generated from Python functions.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Iterator | |
@dataclass(frozen=True)
class Value:
    """Base class for every node of the toy expression language.

    It declares no fields of its own; subclasses add them. Being a frozen
    dataclass, instances are immutable, compare by field values, and are
    hashable (so they can live in sets).
    """
@dataclass(frozen=True)
class Atom(Value):
    """A named leaf of the expression tree, e.g. ``Atom("x")``.

    Two atoms are equal exactly when their ``uid`` values are equal.
    """

    uid: str  # the leaf's name; also used for equality/hashing

    def __repr__(self):
        # Compact "$name" form (Atom("x") prints as "$x"); this explicit
        # method takes precedence over the dataclass-generated __repr__.
        return "${}".format(self.uid)
@dataclass(frozen=True)
class Const(Value):
    """An integer literal leaf, e.g. ``Const(2)``.

    Relies on the dataclass-generated ``__repr__`` (``Const(value=2)``) and
    field-based equality/hashing.
    """

    value: int  # the literal's value
@dataclass(frozen=True)
class BinOp(Value):
    """A binary operation node such as ``x + x`` or ``x * 2``.

    ``opname`` is the operator symbol; ``lhs`` and ``rhs`` are the operand
    subtrees.
    """

    opname: str
    lhs: Value
    rhs: Value

    def generate_alternatives(self) -> Iterator[BinOp]:
        """Yield expressions equivalent to this node.

        Two local identities are recognised:

        * ``a + a``  ->  ``a * 2``
        * ``a * 2``  ->  ``a + a``   (right operand must be ``Const`` 2)

        Any other node yields nothing.
        """
        op, left, right = self.opname, self.lhs, self.rhs
        if op == "+" and left == right:
            yield BinOp("*", left, Const(2))
        elif op == "*" and isinstance(right, Const) and right.value == 2:
            yield BinOp("+", left, left)

    def __repr__(self):
        # Operands are formatted with str(), which for these classes falls
        # back to their __repr__, so nested trees print inline.
        return "BinOp({}, {}, {})".format(self.opname, self.lhs, self.rhs)
@dataclass(frozen=True)
class Alternatives:
    """Immutable, indexable bundle of equivalent ``Value`` expressions.

    Exposes positional access via ``get`` and the item count via ``len()``.
    """

    items: tuple[Value, ...]  # the alternatives, in a fixed order

    def get(self, i: int) -> Value:
        """Return the alternative at position *i* (tuple indexing rules)."""
        picked = self.items[i]
        return picked

    def __len__(self) -> int:
        """Return the number of stored alternatives."""
        count = len(self.items)
        return count
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Tested With: | |
| - egglog == 12 | |
| """ | |
| from __future__ import annotations | |
| from pprint import pprint | |
| from egglog import ( | |
| Bool, | |
| EGraph, | |
| Expr, | |
| PyObject, | |
| String, | |
| delete, | |
| get_callable_args, | |
| get_callable_fn, | |
| i64, | |
| i64Like, | |
| method, | |
| rewrite, | |
| rule, | |
| ruleset, | |
| union, | |
| ) | |
| from eggcall_mod import Alternatives, Atom, BinOp, Const, Value | |
# Sanity-check the pure-Python rewrites before involving egglog.
x = Atom("x")
y = Atom("y")  # not referenced elsewhere in this script
xplusx = BinOp("+", x, x)  # x + x
xtimes2 = BinOp("*", x, Const(2))  # x * 2
print(xplusx)
print(xtimes2)
# Each form generates the other as its single alternative.
print(list(xplusx.generate_alternatives()))
print(list(xtimes2.generate_alternatives()))
# Close the pair under generate_alternatives; the set dedupes equal
# (structurally identical) nodes, so only the two forms remain.
total = {xplusx, xtimes2}
total |= set(xplusx.generate_alternatives())
total |= set(xtimes2.generate_alternatives())
print(total)
| # ----------------------------------------------------------------------------- | |
class EggValue(Expr):
    """egglog sort that wraps a Python expression object.

    ``EggValue(py)`` lifts a Python object into the e-graph.
    ``EggValue.binop`` is a structural view of a binary operation, so rules
    can match on the operator and operands.
    """

    def __init__(self, py: PyObject): ...

    # unextractable: extraction never chooses this form, so extracted
    # terms should be EggValue(<PyObject>) leaves that can be unwrapped.
    @method(unextractable=True)
    @classmethod
    def binop(
        cls, opname: String, lhs: EggValue, rhs: EggValue
    ) -> EggValue: ...
class EggAltList(Expr):
    """Work-list term used to enumerate the alternatives of ``orig``.

    Constructor arguments:
      orig   -- the EggValue whose alternatives are being added.
      lst    -- PyObject wrapping the Python-side collection of alternatives.
      nitems -- number of entries in ``lst``.
    """

    def __init__(self, orig: EggValue, lst: PyObject, nitems: i64): ...
    # ``_expand(idx)`` marks entry ``idx`` of ``lst`` as the next one to
    # process; rules_generate_alternatives advances and deletes these terms.
    def _expand(self, idx: i64Like) -> EggAltList: ...
| # ----------------------------------------------------------------------------- | |
@PyObject
def egg_generate_alternatives(value: Value) -> Alternatives:
    """Collect the Python-side alternatives of *value* (called from egglog).

    Objects with no ``generate_alternatives`` method produce an empty
    ``Alternatives``. Results are sorted by ``repr`` so each alternative's
    index (used by ``egg_alternatives_get``) is deterministic.
    """
    if gen := getattr(value, "generate_alternatives", None):
        # Sort by repr: the frozen dataclasses are not declared with
        # order=True, so a plain sorted() raises TypeError as soon as two
        # or more alternatives have to be compared.
        return Alternatives(tuple(sorted(gen(), key=repr)))
    return Alternatives(())
@PyObject
def egg_alternatives_get(equiv_set: Alternatives, i: int) -> Value:
    """Return entry *i* of *equiv_set* (invoked from egglog rules)."""
    chosen = equiv_set.get(i)
    return chosen
@PyObject
def egg_alternatives_length(equiv_set: Alternatives) -> int:
    """Return how many alternatives *equiv_set* holds."""
    count = len(equiv_set)
    return count
@PyObject
def egg_binop_typecheck(binop) -> bool:
    """Report whether the wrapped Python object is a ``BinOp`` node."""
    is_binop = isinstance(binop, BinOp)
    return is_binop
@PyObject
def egg_atom_typecheck(binop) -> bool:
    """Report whether the wrapped Python object is an ``Atom`` leaf.

    Despite its name, the ``binop`` parameter may be any object.
    """
    is_atom = isinstance(binop, Atom)
    return is_atom
@PyObject
def egg_binop_get_lhs(binop: BinOp) -> Value:
    """Return the left operand of *binop*."""
    left_operand = binop.lhs
    return left_operand
@PyObject
def egg_binop_get_rhs(binop: BinOp) -> Value:
    """Return the right operand of *binop*."""
    right_operand = binop.rhs
    return right_operand
@PyObject
def egg_binop_get_opname(binop: BinOp) -> str:
    """Return the operator symbol of *binop* (e.g. ``"+"``)."""
    symbol = binop.opname
    return symbol
@PyObject
def egg_rebuild_binop(opname: str, lhs: Value, rhs: Value) -> Value:
    """Build a Python ``BinOp`` from an operator symbol and two operands."""
    rebuilt = BinOp(opname, lhs, rhs)
    return rebuilt
@ruleset
def rules_generate_alternatives(
    orig: EggValue,
    pyobj: PyObject,
    altlist: PyObject,
    idx: i64,
    nitems: i64,
):
    """Union each Python-backed EggValue with its Python-generated alternatives.

    The parameters are rule variables bound by pattern matching; nobody
    passes them as arguments. Enumeration is driven by an
    ``EggAltList(orig, altlist, nitems)._expand(idx)`` work-list term that
    walks ``idx`` from 0 to ``nitems``.
    """
    # Seed: for every EggValue wrapping a Python object, ask Python for its
    # alternatives and start enumerating them at index 0.
    yield rule(
        orig == EggValue(pyobj),
        altlist == egg_generate_alternatives(pyobj),
    ).then(
        EggAltList(
            orig, altlist, egg_alternatives_length(altlist).to_int()
        )._expand(0)
    )
    # Advance: while idx is in range, the rewrite also creates the
    # _expand(idx + 1) term (unioned with _expand(idx)).
    yield rewrite(
        EggAltList(orig, altlist, nitems)._expand(idx),
    ).to(
        EggAltList(orig, altlist, nitems)._expand(idx + 1),
        idx < nitems,
    )
    # Consume: union orig with alternative number idx (idx is converted to a
    # Python int for the lookup), then delete the processed work-list term.
    yield rule(
        EggAltList(orig, altlist, nitems)._expand(idx),
        idx < nitems,
    ).then(
        union(orig).with_(
            EggValue(egg_alternatives_get(altlist, PyObject.from_int(idx)))
        ),
        delete(EggAltList(orig, altlist, nitems)._expand(idx)),
    )
    # Stop: once idx has reached nitems there is nothing left to add, so
    # just delete the work-list term.
    yield rule(
        EggAltList(orig, altlist, nitems)._expand(idx),
        idx >= nitems,
    ).then(
        delete(EggAltList(orig, altlist, nitems)._expand(idx)),
    )
@ruleset
def rules_reflection(
    binop: EggValue,
    pyobj: PyObject,
    opname: String,
    lhs_pyobj: PyObject,
    rhs_pyobj: PyObject,
):
    """Keep the opaque and structural views of binary operations in sync.

    * ``EggValue(<BinOp object>)`` is unioned with
      ``EggValue.binop(op, EggValue(lhs), EggValue(rhs))``, exposing its
      operands to the e-graph.
    * ``EggValue.binop(op, EggValue(l), EggValue(r))`` is unioned with
      ``EggValue(BinOp(op, l, r))``, so terms built structurally also get a
      Python object that the alternative-generation rules can match.

    The parameters are rule variables bound by pattern matching.
    """
    # Python object -> structural view (only when the object is a BinOp).
    yield rule(
        binop == EggValue(pyobj),
        egg_binop_typecheck(pyobj).to_bool() == Bool(True),
    ).then(
        # The walrus bindings let the union below reuse these sub-terms.
        _opname := egg_binop_get_opname(pyobj).to_string(),
        _lhs := EggValue(egg_binop_get_lhs(pyobj)),
        _rhs := EggValue(egg_binop_get_rhs(pyobj)),
        union(binop).with_(EggValue.binop(_opname, _lhs, _rhs)),
    )
    # Structural view -> Python object, once both operands are
    # Python-backed.
    yield rule(
        binop
        == EggValue.binop(opname, EggValue(lhs_pyobj), EggValue(rhs_pyobj)),
    ).then(
        union(binop).with_(
            EggValue(
                egg_rebuild_binop(
                    PyObject.from_string(opname), lhs_pyobj, rhs_pyobj
                )
            )
        )
    )
| # ----------------------------------------------------------------------------- | |
# Run both rulesets to saturation: reflection exposes BinOp structure and
# alternative generation adds the Python-side equivalences.
schedule = (rules_generate_alternatives + rules_reflection).saturate()
egraph = EGraph()
# Build (x + x) + (x + x) structurally. Starting from a Python object also
# works, e.g. root = EggValue(xplusx).
egg_x = EggValue(Atom("x"))  # distinct name: module-level `x` is an Atom
add = EggValue.binop("+", egg_x, egg_x)
root = EggValue.binop("+", add, add)
egraph.let("root", root)
egraph.run(schedule)
print(egraph)
# egraph.display() can be used to visualize the e-graph (e.g. in a notebook).
| def unwrap(eg_node): | |
| match eg_node: | |
| case PyObject(value): | |
| return value | |
| case _: | |
| raise ValueError(type(eg_node)) | |
# Extraction: pull up to 10 equivalent terms for root out of the e-graph and
# collect the distinct Python objects they wrap.
outputs = set()
for out in egraph.extract_multiple(root, n=10):
    print("---- extracted option :=", out)
    fn = get_callable_fn(out)
    args = get_callable_args(out)
    # A bare name in a `case` clause is a capture pattern: `case EggValue:`
    # matched every term AND rebound the global EggValue. Compare explicitly.
    if fn == EggValue:
        [arg] = args
        outputs.add(unwrap(arg))
    # Other heads (EggValue.binop is unextractable) are not expected; skip.
for expr in outputs:
    print("---")
    pprint(expr)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment