Skip to content

Instantly share code, notes, and snippets.

@chrislloyd
Created January 29, 2026 04:05
Show Gist options
  • Select an option

  • Save chrislloyd/ee29626ead44c96e4e2cb2c50e42cba8 to your computer and use it in GitHub Desktop.

Select an option

Save chrislloyd/ee29626ead44c96e4e2cb2c50e42cba8 to your computer and use it in GitHub Desktop.
# BF interpreter with pointer propagation
# Includes concepts from (c) 2011 Andrew Brown
# Major rewrite (c) 2024 Corbin Simpson
#
# Replaced the BF mixin / makePeephole / AbstractDomain architecture with
# pointer propagation (overlap + prop_concat). Straight-line BF code merges
# into Prop(adj, off, diffs) tuples during parsing. Loop idiom recognition
# (Zero, ZeroScatter, Scan) works directly on the propagation diffs.
#
# JIT work: r_uint on tape indices eliminates negative-index guard_false
# checks from traces (8341 -> 3497 guards). promote(len(tape)) makes the
# bounds check in Loop.runOn a single constant comparison. Compiled and
# benchmarked via RPython in Docker (pypy:2-slim + PyPy 7.3.20 source).
#
# mandel.b runtime (hyperfine, 20 runs, RPython JIT, aarch64-linux):
# original: 1.196s +/- 0.007s
# this: 0.858s +/- 0.004s (1.39x faster)
#
# mandel.b code generation:
# original: 995us +/- 167us
# this: 630us +/- 79us (1.58x faster)
import os
import sys
from rpython.jit.codewriter.policy import JitPolicy
from rpython.rlib.jit import JitDriver, unroll_safe, promote
from rpython.rlib.rarithmetic import r_uint
# https://esolangs.org/wiki/Algebraic_Brainfuck#Pointer_propagation
def overlap(ds1, o, ds2):
    """Add two diff vectors, with ds2 shifted right by `o` slots.

    Returns a new list just long enough to hold both inputs, where slot k
    holds ds1[k] (if present) plus ds2[k - o] (if present). `o` is expected
    to be non-negative; prop_concat always calls it that way.
    """
    size = len(ds1)
    if o + len(ds2) > size:
        size = o + len(ds2)
    merged = [0] * size
    k = 0
    for d in ds1:
        merged[k] += d
        k += 1
    k = o
    for d in ds2:
        merged[k] += d
        k += 1
    return merged
def prop_concat(a1, o1, ds1, a2, o2, ds2):
    """Compose propagation step (a1, o1, ds1) followed by (a2, o2, ds2).

    Each step adds its diffs starting at pointer + offset, then moves the
    pointer by its adj. The result is a single (adj, off, diffs) triple with
    the same effect. When one side has no diffs, the other side's list is
    returned as-is (not copied).
    """
    adj = a1 + a2
    # Where the second step's diffs start, relative to the first step's pointer.
    start2 = a1 + o2
    if not ds2:
        return adj, o1, ds1
    if not ds1:
        return adj, start2, ds2
    if start2 > o1:
        return adj, o1, overlap(ds1, start2 - o1, ds2)
    return adj, start2, overlap(ds2, o1 - start2, ds1)
def _bf(c, n):
if n > 0: return c * n
if n < 0: return {'+': '-', '>': '<'}[c] * -n
return ''
# --
# JIT driver for the interpreter loop (used by Loop.runOn).
# `op` is the only green variable: the position in the op tree, which the
# JIT treats as constant within a trace. The pointer `position` and the
# `tape` are reds, i.e. ordinary runtime values.
jitdriver = JitDriver(greens=['op'], reds=['position', 'tape'])
class Op(object):
    """Base class for compiled BF operations.

    Every subclass in this file implements runOn(tape, position), which
    executes the op against `tape` and returns the new pointer position.
    """
    # RPython hint: instances are never mutated after construction.
    _immutable_ = True
class _Input(Op):
    """BF ',': read one byte from stdin (fd 0) into the current cell.

    At end of input os.read returns an empty string. The cell is then left
    unchanged instead of indexing the empty string, which used to crash the
    interpreter with IndexError.
    """
    _immutable_ = True

    def runOn(self, tape, position):
        data = os.read(0, 1)
        if data:
            tape[r_uint(position)] = ord(data[0])
        return position


# Shared singleton; parse() appends it for ','.
Input = _Input()
class _Output(Op):
    """BF '.': write the current cell to stdout (fd 1) as one byte."""
    _immutable_ = True

    def runOn(self, tape, position):
        byte = chr(tape[r_uint(position)])
        os.write(1, byte)
        return position


# Shared singleton; parse() appends it for '.'.
Output = _Output()
class Prop(Op):
    """A straight-line run of '+', '-', '<', '>' merged into one op.

    Adds diffs[k] to the cell at position + off + k for each k (modulo 256),
    then returns position + adj as the new pointer.
    """
    _immutable_ = True
    _immutable_fields_ = "adj", "off", "diffs[*]"

    def __init__(self, adj, off, diffs):
        self.adj = adj      # net pointer movement
        self.off = off      # offset of diffs[0] from the incoming pointer
        self.diffs = diffs  # deltas for consecutive cells starting at off

    @unroll_safe
    def runOn(self, tape, position):
        p = position + self.off
        for d in self.diffs:
            idx = r_uint(p)
            # Wrap to a byte explicitly. A diff can exceed 255 or take the
            # cell below zero, and a plain bytearray store rejects such
            # values (ValueError when running untranslated).
            tape[idx] = (tape[idx] + d) & 0xFF
            p += 1
        return position + self.adj
class _Zero(Op):
    """Clear the current cell: the compiled form of '[-]' and '[+]'."""
    _immutable_ = True

    def runOn(self, tape, position):
        cell = r_uint(position)
        tape[cell] = 0
        return position


# Singleton returned by _recognize_prop_loop.
Zero = _Zero()
class ZeroScatter(Op):
    """Multiply-and-clear loop such as '[->+>++<<]'.

    Adds cell * scales[i] (modulo 256) to the cell at
    position + offsets[i] for each i, then zeroes the current cell.
    The pointer does not move.
    """
    _immutable_ = True
    _immutable_fields_ = "offsets[*]", "scales[*]"

    def __init__(self, offsets, scales):
        self.offsets = offsets  # target offsets relative to the current cell
        self.scales = scales    # per-target multipliers, parallel to offsets

    @unroll_safe
    def runOn(self, tape, position):
        v = tape[r_uint(position)]
        for i in range(len(self.offsets)):
            target = r_uint(position + self.offsets[i])
            # Wrap to a byte explicitly. v * scale easily leaves 0..255, and
            # a plain bytearray store rejects such values (ValueError when
            # running untranslated).
            tape[target] = (tape[target] + v * self.scales[i]) & 0xFF
        tape[r_uint(position)] = 0
        return position
class Scan(Op):
    """Move the pointer by `stride` until it reaches a zero cell (e.g. '[>]').

    If the pointer leaves [0, len(tape)), scanning stops without reading
    the tape and the out-of-range position is returned unchanged.
    """
    _immutable_ = True
    _immutable_fields_ = "stride",

    def __init__(self, stride):
        self.stride = stride

    def runOn(self, tape, position):
        step = self.stride
        limit = len(tape)
        while position >= 0 and position < limit:
            if not tape[r_uint(position)]:
                break
            position += step
        return position
class Loop(Op):
    """A general BF loop '[...]' around a single body op.

    Runs the body while the pointer is inside the tape and the current cell
    is nonzero. In this file, this is the only place the JIT merge point
    appears.
    """
    _immutable_ = True
    _immutable_fields_ = "op",

    def __init__(self, op):
        self.op = op  # loop body (often a Seq)

    def runOn(self, tape, position):
        op = self.op
        # The bounds checks stop the loop if the body moved the pointer off the tape.
        while position >= 0 and position < len(tape) and tape[r_uint(position)]:
            # `op` is the green key, so each distinct loop body gets its own trace.
            jitdriver.jit_merge_point(op=op, position=position, tape=tape)
            # Make the tape length a trace constant (see the header note on
            # turning the bounds check into a constant comparison).
            promote(len(tape))
            position = op.runOn(tape, position)
        return position
class Seq(Op):
    """Run a fixed list of ops in order, passing the pointer from one to the next."""
    _immutable_ = True
    _immutable_fields_ = "ops[*]",

    def __init__(self, ops):
        self.ops = ops

    @unroll_safe
    def runOn(self, tape, position):
        pos = position
        for child in self.ops:
            pos = child.runOn(tape, pos)
        return pos
# --
def _recognize_prop_loop(p):
    """Choose a specialised op for a loop whose whole body is the Prop `p`.

    - Pointer moves and no cells change ('[>]', '[<<]'): Scan.
    - Pointer stays put, the loop cell changes by +/-1, nothing else changes: Zero.
    - Pointer stays put, the loop cell changes by -1, other cells change:
      ZeroScatter (each other cell gains loop-cell * diff).
    - Anything else: a generic Loop around `p`.
    """
    assert isinstance(p, Prop)
    if p.adj != 0:
        if p.diffs:
            return Loop(p)
        return Scan(p.adj)
    diffs = p.diffs
    here = -p.off  # index of the loop's own cell inside diffs
    if here < 0 or here >= len(diffs) or diffs[here] not in (1, -1):
        return Loop(p)
    step = diffs[here]
    targets = [k for k in range(len(diffs)) if k != here and diffs[k] != 0]
    if not targets:
        return Zero
    if step != -1:
        return Loop(p)
    offsets = [p.off + k for k in targets]
    scales = [diffs[k] for k in targets]
    return ZeroScatter(offsets[:], scales[:])
def _wrap_loop(ops):
    """Wrap a parsed loop body in a Loop, adding a Seq only for multi-op bodies."""
    if len(ops) != 1:
        return Loop(Seq(ops[:]))
    return Loop(ops[0])
def flush(ops, adj, off, diffs):
    """Emit the pending propagation state as a Prop, then reset it.

    Zero entries at both ends of `diffs` are trimmed; leading ones shift
    `off` forward. A Prop is appended to `ops` only if something is left
    to do (a nonzero diff or pointer movement). Returns the empty state
    (0, 0, []).

    `diffs` itself is never modified. It may be a list shared with other
    state, because prop_concat can return one of the prebuilt _PROP lists
    unchanged.
    """
    n = len(diffs)
    lo = 0
    while lo < n and diffs[lo] == 0:
        lo += 1
    # One slice instead of repeated diffs[1:] copies. It is also a fresh
    # list, so trimming its tail below cannot touch the caller's list.
    trimmed = diffs[lo:]
    while trimmed and trimmed[-1] == 0:
        trimmed.pop()
    if trimmed or adj:
        # Copy as the original did, presumably so RPython sees a list that
        # is never resized once stored in Prop's immutable `diffs[*]` field.
        ops.append(Prop(adj, off + lo, trimmed[:]))
    return 0, 0, []
_PROP = {'+': (0, 0, [1]), '-': (0, 0, [-1]), '>': (1, 0, []), '<': (-1, 0, [])}
def parse(s):
stk = [[]]
adj, off, diffs = 0, 0, []
i = 0
while i < len(s) and s[i] == '[':
depth = 1
i += 1
while i < len(s) and depth:
if s[i] == '[': depth += 1
elif s[i] == ']': depth -= 1
i += 1
while i < len(s):
c = s[i]
if c in _PROP:
a2, o2, d2 = _PROP[c]
adj, off, diffs = prop_concat(adj, off, diffs, a2, o2, d2)
elif c == '.' or c == ',':
adj, off, diffs = flush(stk[-1], adj, off, diffs)
stk[-1].append(Output if c == '.' else Input)
elif c == '[':
adj, off, diffs = flush(stk[-1], adj, off, diffs)
stk.append([])
elif c == ']':
adj, off, diffs = flush(stk[-1], adj, off, diffs)
body = stk.pop()
if len(body) == 1 and isinstance(body[0], Prop):
stk[-1].append(_recognize_prop_loop(body[0]))
else:
stk[-1].append(_wrap_loop(body))
i += 1
flush(stk[-1], adj, off, diffs)
ops = stk[0]
if not ops: return Prop(0, 0, [])
if len(ops) == 1: return ops[0]
return Seq(ops[:])
# --
def op_to_str(op):
    """Render an op tree as BF source text (the -o output).

    The text is rebuilt from the optimised ops, so it is equivalent to, but
    not necessarily identical with, the original program. Unknown ops
    render as ''.
    """
    if isinstance(op, Prop):
        diffs = op.diffs
        if not diffs:
            return _bf('>', op.adj)
        last = len(diffs) - 1
        out = [_bf('>', op.off)]
        for k in range(len(diffs)):
            out.append(_bf('+', diffs[k]))
            if k != last:
                out.append('>')
        # The pointer now sits at off + last; move it to its final adj.
        out.append(_bf('>', op.adj - (op.off + last)))
        return ''.join(out)
    if isinstance(op, _Zero):
        return '[-]'
    if isinstance(op, ZeroScatter):
        out = ['[-']
        here = 0
        for k in range(len(op.offsets)):
            target = op.offsets[k]
            out.append(_bf('>', target - here) + _bf('+', op.scales[k]))
            here = target
        out.append(_bf('>', -here))
        out.append(']')
        return ''.join(out)
    if isinstance(op, Scan):
        return '[%s]' % _bf('>', op.stride)
    if isinstance(op, Loop):
        return '[%s]' % op_to_str(op.op)
    if isinstance(op, Seq):
        return ''.join([op_to_str(child) for child in op.ops])
    if op is Input:
        return ','
    if op is Output:
        return '.'
    return ''
# --
def entryPoint(argv):
    """Command-line entry point: bf [-c <cells>] [-h] [-o] <program.bf>.

    -c N  use a tape of N cells (default 30000; N must be a positive integer)
    -o    print the optimised program as BF source instead of running it
    -h    print usage and exit
    Options may appear in any order relative to the program path; the first
    non-option argument is the path. Returns the process exit status:
    0 on success, 1 on usage or parse errors.
    """
    usage = "Usage: bf [-c <number of cells>] [-h] [-o] <program.bf>"
    if len(argv) < 2 or "-h" in argv:
        print(usage)
        return 1
    cells = 30000
    show = False
    path = ""
    i = 1
    while i < len(argv):
        arg = argv[i]
        if arg == "-c":
            # Previously `bf -c` indexed past the end of argv.
            if i + 1 >= len(argv):
                print(usage)
                return 1
            try:
                cells = int(argv[i + 1])
            except ValueError:
                print(usage)
                return 1
            i += 2
        elif arg == "-o":
            show = True
            i += 1
        else:
            if not path:
                path = arg
            i += 1
    # A non-positive tape would make every cell access out of range.
    if not path or cells <= 0:
        print(usage)
        return 1
    with open(path) as handle:
        text = handle.read()
    try:
        program = parse(text)
    except ValueError:
        print("bf: malformed program")
        return 1
    if show:
        print(op_to_str(program))
        return 0
    tape = bytearray("\x00" * cells)
    program.runOn(tape, 0)
    return 0
def target(*args):
    """RPython translation hook: hand back the entry point.

    The second tuple element is left as None.
    """
    return entryPoint, None


def jitpolicy(driver):
    """RPython hook: return a default JitPolicy."""
    return JitPolicy()


if __name__ == "__main__":
    # Untranslated run under a regular Python interpreter.
    status = entryPoint(sys.argv)
    sys.exit(status)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment