Skip to content

Instantly share code, notes, and snippets.

@tripplyons
Created February 11, 2026 18:11
Show Gist options
  • Select an option

  • Save tripplyons/295ed53887a797756a06150c366478da to your computer and use it in GitHub Desktop.

Select an option

Save tripplyons/295ed53887a797756a06150c366478da to your computer and use it in GitHub Desktop.
Issue with expected Gated DeltaNet prefill output for MLSys26 FlashInfer-Bench Contest
# run with `uv run reprex.py`
# /// script
# requires-python = ">=3.10"
# dependencies = ["numpy", "packaging", "safetensors", "torch"]
# ///
# Reproduction script: prints the per-token std of the Gated DeltaNet prefill
# reference output for one MLSys26 FlashInfer-Bench contest workload.
# The block above is PEP 723 inline script metadata consumed by `uv run`.
import json
import subprocess
from pathlib import Path
import torch
from safetensors.torch import load_file
# Git URL of the contest dataset repository; cloned below via `git clone`.
DATASET_REPO = "https://huggingface.co/datasets/flashinfer-ai/mlsys26-contest"
# Identifier used to select one workload line from the workloads JSONL file.
UUID = "c4349ccd-2d92-471e-8c7e-8ae218ea77b1"
def ensure_dataset() -> Path:
target = Path.cwd() / "mlsys26-contest"
if not target.exists():
subprocess.run(["git", "clone", DATASET_REPO, str(target)], check=True)
subprocess.run(["git", "-C", str(target), "lfs", "pull"], check=True)
return target
# Obtain a local dataset checkout, then parse the kernel definition JSON.
root = ensure_dataset()
definition_text = (
    root / "definitions/gdn/gdn_prefill_qk4_v8_d128_k_last.json"
).read_text()
defn = json.loads(definition_text)

# The reference implementation ships as Python source text inside the
# definition; executing it exposes a `run` callable.
# NOTE(review): exec of downloaded code — acceptable only because the dataset
# repo is trusted.
namespace: dict = {}
exec(defn["reference"], namespace)
run = namespace["run"]

# Select the single workload whose JSONL line contains the target UUID.
workload_file = root / "workloads/gdn/gdn_prefill_qk4_v8_d128_k_last.jsonl"
wl = next(
    json.loads(line)["workload"]
    for line in workload_file.read_text().splitlines()
    if UUID in line
)
# Synthesize random q/k/v/state with the workload's axis sizes: 4 q/k heads,
# 8 v heads, head dim 128, bf16 activations and an fp32 recurrent state.
device = "cuda" if torch.cuda.is_available() else "cpu"
T = wl["axes"]["total_seq_len"]
N = wl["axes"]["num_seqs"]


def _rand(shape, dtype):
    # One helper for all random inputs; call order matches the original, so
    # the torch RNG stream is consumed identically.
    return torch.randn(shape, dtype=dtype, device=device)


q = _rand((T, 4, 128), torch.bfloat16)
k = _rand((T, 4, 128), torch.bfloat16)
v = _rand((T, 8, 128), torch.bfloat16)
state = _rand((N, 8, 128, 128), torch.float32)

# The remaining inputs are fixed tensors shipped with the workload as a
# safetensors file (path is given relative to the dataset root).
weights = load_file(root / wl["inputs"]["A_log"]["path"].replace("./", ""))
A_log, a, dt_bias, b, cu_seqlens = (
    weights[name].to(device=device, dtype=dtype)
    for name, dtype in (
        ("A_log", torch.float32),
        ("a", torch.bfloat16),
        ("dt_bias", torch.float32),
        ("b", torch.bfloat16),
        ("cu_seqlens", torch.int64),
    )
)
scale = float(wl["inputs"]["scale"]["value"])

# Run the reference kernel and report the per-token standard deviation of the
# flattened output as CSV (population std: unbiased=False).
out, _ = run(q, k, v, state, A_log, a, dt_bias, b, cu_seqlens, scale)
stds = out.float().flatten(1).std(dim=1, unbiased=False)
print("token,std")
for token, value in enumerate(stds.tolist()):
    print(f"{token},{value:.6e}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment