Skip to content

Instantly share code, notes, and snippets.

@sklam
Created December 16, 2025 18:14
Show Gist options
  • Select an option

  • Save sklam/f141efe69bf86182ed185274b6648e93 to your computer and use it in GitHub Desktop.

Select an option

Save sklam/f141efe69bf86182ed185274b6648e93 to your computer and use it in GitHub Desktop.
egglog-python example: (ab)use PyObject so that equivalences can be generated from Python functions.
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterator
@dataclass(frozen=True)
class Value:
    """Common immutable base class for expression nodes (Atom, Const, BinOp)."""
@dataclass(frozen=True)
class Atom(Value):
    """A named leaf symbol; its ``repr`` is ``$`` followed by the uid."""

    # Identifier of the symbol, e.g. "x" -> printed as $x.
    uid: str

    def __repr__(self):
        name = self.uid
        return f"${name}"
@dataclass(frozen=True)
class Const(Value):
    """An integer literal leaf, such as the ``2`` in ``x * 2``."""

    # Literal value; equality/hash come from the generated dataclass methods.
    value: int
@dataclass(frozen=True)
class BinOp(Value):
    """A binary operation ``lhs <opname> rhs`` over other ``Value`` nodes."""

    opname: str
    lhs: Value
    rhs: Value

    def generate_alternatives(self) -> Iterator[BinOp]:
        """Yield expressions equivalent to ``self``.

        Two rewrites are known: ``x + x`` -> ``x * 2`` and ``x * 2`` -> ``x + x``.
        Any other operation yields nothing.
        """
        op, left, right = self.opname, self.lhs, self.rhs
        if op == "+" and left == right:
            yield BinOp("*", left, Const(2))
        elif op == "*" and isinstance(right, Const) and right.value == 2:
            yield BinOp("+", left, left)

    def __repr__(self):
        # Operands are rendered via format(), i.e. their own __repr__.
        return "BinOp({}, {}, {})".format(self.opname, self.lhs, self.rhs)
@dataclass(frozen=True)
class Alternatives:
    """Immutable, indexable bundle of values equivalent to some original value."""

    # The equivalent values, in the order they were produced.
    items: tuple[Value, ...]

    def __len__(self) -> int:
        """Return how many alternatives are held."""
        return len(self.items)

    def get(self, i: int) -> Value:
        """Return the ``i``-th alternative (standard tuple indexing rules)."""
        picked = self.items[i]
        return picked
"""
Tested With:
- egglog == 12
"""
from __future__ import annotations
from pprint import pprint
from egglog import (
Bool,
EGraph,
Expr,
PyObject,
String,
delete,
get_callable_args,
get_callable_fn,
i64,
i64Like,
method,
rewrite,
rule,
ruleset,
union,
)
from eggcall_mod import Alternatives, Atom, BinOp, Const, Value
# Pure-Python sanity check of the rewrite generator before involving egglog.
x = Atom("x")
y = Atom("y")  # not referenced elsewhere in this script
xplusx = BinOp("+", x, x)
xtimes2 = BinOp("*", x, Const(2))
print(xplusx)
print(xtimes2)
# Each form should generate the other one.
print(list(xplusx.generate_alternatives()))
print(list(xtimes2.generate_alternatives()))
# Union of both forms and everything they generate; since the two forms
# generate each other, this set is closed under generate_alternatives.
total = {xplusx, xtimes2}
total.update(xplusx.generate_alternatives())
total.update(xtimes2.generate_alternatives())
print(total)
# -----------------------------------------------------------------------------
class EggValue(Expr):
    """egglog sort for values backed by Python ``Value`` objects.

    ``EggValue(py)`` wraps a Python object held as a ``PyObject``.
    ``EggValue.binop`` is the structural, egglog-level view of a binary
    operation; it is marked ``unextractable`` so extraction never selects it
    and instead prefers the ``EggValue(PyObject)`` form.
    """

    def __init__(self, py: PyObject): ...

    @method(unextractable=True)
    @classmethod
    def binop(
        cls, opname: String, lhs: EggValue, rhs: EggValue
    ) -> EggValue: ...
class EggAltList(Expr):
    """Cursor over the Python-generated alternatives of ``orig``.

    ``lst`` holds the Python ``Alternatives`` object returned by
    ``egg_generate_alternatives`` and ``nitems`` its length.  The rules in
    ``rules_generate_alternatives`` walk ``_expand(idx)`` from 0 upward,
    unioning ``orig`` with each alternative.
    """

    def __init__(self, orig: EggValue, lst: PyObject, nitems: i64): ...

    # Cursor position: while idx < nitems the rules union ``orig`` with the
    # idx-th alternative and advance to idx + 1.
    def _expand(self, idx: i64Like) -> EggAltList: ...
# -----------------------------------------------------------------------------
@PyObject
def egg_generate_alternatives(value: Value) -> Alternatives:
    """Collect the Python-generated equivalents of ``value``.

    Called from egglog through ``PyObject``.  Values without a
    ``generate_alternatives`` method (``Atom``, ``Const``) yield an empty
    ``Alternatives``.
    """
    gen = getattr(value, "generate_alternatives", None)
    if gen is None:
        return Alternatives(())
    # The Value dataclasses are not declared with ``order=True``, so a plain
    # ``sorted()`` raises TypeError as soon as two alternatives are produced.
    # Sorting by repr keeps the order deterministic (the egglog rules index
    # into this tuple) without requiring an ordering on Value.
    return Alternatives(tuple(sorted(gen(), key=repr)))
@PyObject
def egg_alternatives_get(equiv_set: Alternatives, i: int) -> Value:
    """Return the ``i``-th entry of ``equiv_set`` (Python-side indexing)."""
    chosen = equiv_set.get(i)
    return chosen
@PyObject
def egg_alternatives_length(equiv_set: Alternatives) -> int:
    """Report how many alternatives ``equiv_set`` holds."""
    count = len(equiv_set)
    return count
@PyObject
def egg_binop_typecheck(binop) -> bool:
    """Tell whether the wrapped Python object is a ``BinOp``."""
    is_binop = isinstance(binop, BinOp)
    return is_binop
@PyObject
def egg_atom_typecheck(binop) -> bool:
    """Tell whether the wrapped Python object is an ``Atom``.

    NOTE(review): the parameter is named ``binop`` although any value may be
    passed; the name is kept for interface compatibility.  This helper is not
    referenced by the rulesets in this script.
    """
    is_atom = isinstance(binop, Atom)
    return is_atom
@PyObject
def egg_binop_get_lhs(binop: BinOp) -> Value:
    """Return the left operand of a Python ``BinOp``."""
    left = binop.lhs
    return left
@PyObject
def egg_binop_get_rhs(binop: BinOp) -> Value:
    """Return the right operand of a Python ``BinOp``."""
    right = binop.rhs
    return right
@PyObject
def egg_binop_get_opname(binop: BinOp) -> str:
    """Return the operator name (e.g. ``"+"``) of a Python ``BinOp``."""
    op = binop.opname
    return op
@PyObject
def egg_rebuild_binop(opname: str, lhs: Value, rhs: Value) -> Value:
    """Construct a Python ``BinOp`` from an operator name and two operands."""
    rebuilt = BinOp(opname=opname, lhs=lhs, rhs=rhs)
    return rebuilt
@ruleset
def rules_generate_alternatives(
    orig: EggValue,
    pyobj: PyObject,
    altlist: PyObject,
    idx: i64,
    nitems: i64,
):
    """Union each ``EggValue(pyobj)`` with its Python-generated alternatives.

    For every wrapped Python object, ``egg_generate_alternatives`` is called
    to obtain an ``Alternatives`` tuple.  An ``EggAltList(...)._expand(idx)``
    term then acts as a cursor over indices ``0 .. nitems - 1``; at each index
    ``orig`` is unioned with ``EggValue(<idx-th alternative>)``.

    The parameters act as pattern variables in the rules below.
    """
    # Seed: call the Python generator on the wrapped object (binding its
    # result to ``altlist``) and start a cursor at index 0.
    yield rule(
        orig == EggValue(pyobj),
        altlist == egg_generate_alternatives(pyobj),
    ).then(
        EggAltList(
            orig, altlist, egg_alternatives_length(altlist).to_int()
        )._expand(0)
    )
    # Advance: create the cursor for idx + 1 while idx is still in range.
    # Being a rewrite, this also unions the two cursor terms.
    yield rewrite(
        EggAltList(orig, altlist, nitems)._expand(idx),
    ).to(
        EggAltList(orig, altlist, nitems)._expand(idx + 1),
        idx < nitems,
    )
    # Emit: union ``orig`` with the idx-th alternative, then delete the cursor
    # row that produced it.
    # NOTE(review): the delete presumably stops spent cursors from
    # accumulating / re-firing — confirm against egglog's delete semantics.
    yield rule(
        EggAltList(orig, altlist, nitems)._expand(idx),
        idx < nitems,
    ).then(
        union(orig).with_(
            EggValue(egg_alternatives_get(altlist, PyObject.from_int(idx)))
        ),
        delete(EggAltList(orig, altlist, nitems)._expand(idx)),
    )
    # Stop: discard the cursor once it has moved past the last index.
    yield rule(
        EggAltList(orig, altlist, nitems)._expand(idx),
        idx >= nitems,
    ).then(
        delete(EggAltList(orig, altlist, nitems)._expand(idx)),
    )
@ruleset
def rules_reflection(
    binop: EggValue,
    pyobj: PyObject,
    opname: String,
    lhs_pyobj: PyObject,
    rhs_pyobj: PyObject,
):
    """Relate Python ``BinOp`` objects and structural ``EggValue.binop`` terms.

    Python -> egglog: an ``EggValue`` wrapping a Python ``BinOp`` is unioned
    with ``EggValue.binop(opname, EggValue(lhs), EggValue(rhs))``.

    egglog -> Python: an ``EggValue.binop`` whose operands both have an
    ``EggValue(PyObject)`` form is unioned with ``EggValue`` of a rebuilt
    Python ``BinOp``.  The rule fires once per matching operand pair, so every
    combination of known operand forms gets rebuilt.

    The parameters act as pattern variables in the rules below.
    """
    # Python -> egglog: decompose a wrapped Python BinOp into its parts.
    yield rule(
        binop == EggValue(pyobj),
        egg_binop_typecheck(pyobj).to_bool() == Bool(True),
    ).then(
        _opname := egg_binop_get_opname(pyobj).to_string(),
        _lhs := EggValue(egg_binop_get_lhs(pyobj)),
        _rhs := EggValue(egg_binop_get_rhs(pyobj)),
        union(binop).with_(EggValue.binop(_opname, _lhs, _rhs)),
    )
    # egglog -> Python: rebuild a Python BinOp from Python-backed operands.
    yield rule(
        binop
        == EggValue.binop(opname, EggValue(lhs_pyobj), EggValue(rhs_pyobj)),
    ).then(
        union(binop).with_(
            EggValue(
                egg_rebuild_binop(
                    PyObject.from_string(opname), lhs_pyobj, rhs_pyobj
                )
            )
        )
    )
# -----------------------------------------------------------------------------
# Run both rulesets together, repeatedly, until saturation.
schedule = (rules_generate_alternatives + rules_reflection).saturate()
egraph = EGraph()
# Alternative starting point built from the Python value directly:
# root = EggValue(xplusx)
# NOTE: this rebinds the module-level ``x`` (previously the Python Atom) to
# its egglog wrapper.
x = EggValue(Atom("x"))
add = EggValue.binop("+", x, x)
# root is ($x + $x) + ($x + $x), built structurally, so rules_reflection must
# assemble the Python-backed EggValue(BinOp) forms before extraction.
root = EggValue.binop("+", add, add)
egraph.let("root", root)
egraph.run(schedule)
print(egraph)
# Uncomment to render the e-graph visually:
# egraph.display()
def unwrap(eg_node):
match eg_node:
case PyObject(value):
return value
case _:
raise ValueError(type(eg_node))
# Extraction: ask egglog for up to 10 variants of ``root`` and collect the
# Python values behind them.  Because ``EggValue.binop`` is unextractable,
# extracted terms are expected to be ``EggValue(PyObject)``.
outputs = set()
for out in egraph.extract_multiple(root, n=10):
    print("---- extracted option :=", out)
    fn = get_callable_fn(out)
    args = get_callable_args(out)
    # A bare ``case EggValue:`` is a *capture* pattern: it matches anything
    # and rebinds the global ``EggValue``.  Compare explicitly instead.
    if fn == EggValue:
        [arg] = args
        outputs.add(unwrap(arg))
# Set iteration order depends on (randomized) string hashing, so print in
# repr order to keep the output reproducible between runs.
for expr in sorted(outputs, key=repr):
    print("---")
    pprint(expr)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment