Modded-NanoGPT attention benchmark
@Birch-san · Created December 28, 2025 01:06
# based on MIT-licensed code from https://github.com/KellerJordan/modded-nanogpt
from typing import Callable, Optional
import math
from pathlib import Path
from dataclasses import dataclass
from functools import partial
import torch
from torch import nn, Tensor, FloatTensor, IntTensor
import torch.nn.functional as F
from kernels import get_kernel, get_local_kernel
from triton.testing import do_bench
from torch.testing import assert_close
if use_fa3 := False:
    get_attn_out: Callable[[Tensor|tuple[Tensor, ...]], Tensor]
    if use_local_fa3_kernel := True:
        flash_attn_interface = get_local_kernel(repo_path=Path('hf-kernels/flash-attention-3'), package_name='flash_attention_3').flash_attn_interface
        get_attn_out = lambda out: out
    elif use_community_fa3_kernel := False:
        flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface
        get_attn_out = lambda out: out
    elif use_dist_fa3 := False:
        import flash_attn_interface
        def get_attn_out(out):
            assert isinstance(out, tuple)
            out, lse = out
            return out
    else:
        raise ValueError("well you have to pick one")

    def varlen_attn(
        q: FloatTensor,
        k: FloatTensor,
        v: FloatTensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        cum_seq_q: Optional[IntTensor] = None,
        cum_seq_k: Optional[IntTensor] = None,
        causal=False,
        scale: Optional[float] = None,
        window_size_left: Optional[int] = None,
        window_size_right: Optional[int] = None,
    ) -> FloatTensor:
        assert (window_size_left is None) == (window_size_right is None)
        window_size = None if window_size_left is None else (window_size_left, window_size_right)
        out = flash_attn_interface.flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cum_seq_q,
            cu_seqlens_k=cum_seq_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            causal=causal,
            softmax_scale=scale,
            window_size=window_size,
        )
        return get_attn_out(out)
else:
    print("[WARN] falling back to torch builtin private varlen attn API, which is probably FA2")

    def varlen_attn(
        q: FloatTensor,
        k: FloatTensor,
        v: FloatTensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        cum_seq_q: Optional[IntTensor] = None,
        cum_seq_k: Optional[IntTensor] = None,
        causal=False,
        scale: Optional[float] = None,
        window_size_left: Optional[int] = None,
        window_size_right: Optional[int] = None,
    ) -> FloatTensor:
        out, _, _, _, _ = torch.ops.aten._flash_attention_forward(
            q, k, v,
            cum_seq_q=cum_seq_q,
            cum_seq_k=cum_seq_k,
            max_q=max_seqlen_q,
            max_k=max_seqlen_k,
            dropout_p=0.0,
            is_causal=causal,
            return_debug_mask=False,
            scale=scale,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
        )
        return out
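# Illustrative sketch (not part of the benchmark): varlen attention packs several
# documents into one flat token stream and uses a cumulative-sequence-length tensor
# so queries never attend across document boundaries. With two hypothetical documents
# of lengths 3 and 5 (8 tokens total), a call to the wrapper above would look roughly like:
#   q = k = v = torch.randn(8, 6, 128, device='cuda', dtype=torch.bfloat16)  # (total_tokens, heads, head_dim)
#   cu = torch.tensor([0, 3, 8], dtype=torch.int32, device='cuda')           # document boundaries
#   y = varlen_attn(q, k, v, max_seqlen_q=5, max_seqlen_k=5, cum_seq_q=cu, cum_seq_k=cu,
#                   causal=True, window_size_left=128, window_size_right=0)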
def next_multiple_of_n(v: float | int, *, n: int):
    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
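# Worked examples (illustrative only): next_multiple_of_n(54, n=128) -> 128,
# next_multiple_of_n(128, n=128) -> 128, next_multiple_of_n(130, n=128) -> 256.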
@dataclass
class Hyperparameters:
    train_bs_schedule: tuple = (8 * 2048 * 8, 16 * 2048 * 8, 24 * 2048 * 8)
    train_max_seq_len: int = 128 * 16
    val_batch_size: int = 4 * 64 * 1024 * 8
args = Hyperparameters()
rank = 0
world_size = 8
grad_accum_steps = 8 // world_size
max_seq_len=args.val_batch_size // (grad_accum_steps * world_size)
device = torch.device("cuda", 0)
torch.cuda.set_device(device)
@dataclass
class AttnArgs:
    ve: torch.Tensor
    sa_lambdas: torch.Tensor
    seqlens: torch.Tensor
    bm_size: int
    cos: torch.Tensor
    sin: torch.Tensor
    attn_scale: float
    key_shift: bool

def norm(x: Tensor):
    return F.rms_norm(x, (x.size(-1),))
class Yarn(nn.Module):
    def __init__(self, head_dim, max_seq_len):
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.reset()

    def reset(self):
        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device)
        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)])
        t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device)
        theta = torch.outer(t, angular_freq)
        self.cos = nn.Buffer(
            theta.cos().to(torch.bfloat16), persistent=False
        )
        self.sin = nn.Buffer(
            theta.sin().to(torch.bfloat16), persistent=False
        )
        self.angular_freq = angular_freq
        # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
        self.attn_scale = 0.1

    def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32):
        # NOTE: the trimmed Hyperparameters above has no block_size field, so this line
        # references an attribute that does not exist here. apply() is never called in
        # this benchmark, so it only matters if you reuse Yarn elsewhere.
        rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi)
        scaling_factor = old_window / new_window
        interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1)
        self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor)
        t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device)
        theta = torch.outer(t, self.angular_freq)
        self.cos.copy_(theta.cos())
        self.sin.copy_(theta.sin())
        self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1
def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor):
    assert cos.size(0) >= x_BTHD.size(-3)
    cos, sin = (
        cos[None, : x_BTHD.size(-3), None, :],
        sin[None, : x_BTHD.size(-3), None, :],
    )
    x1, x2 = x_BTHD.chunk(2, dim=-1)
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat((y1, y2), 3)
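# Note on the rotation above (derived from Yarn.reset): the second half of angular_freq
# is zeroed, so those frequency columns get theta = 0, i.e. cos = 1 and sin = 0, and the
# corresponding components pass through unrotated. In effect only the first head_dim//4
# frequency pairs are rotated ("half-truncate RoPE"); the unrotated dims are the
# "stationary" dims targeted by key_shift in the attention forward below.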
class CastedLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0):
        super().__init__(in_features, out_features, bias=False)
        self.use_fp8 = use_fp8
        self.x_s = x_s
        self.w_s = w_s
        self.grad_s = grad_s

    def reset_parameters(self) -> None:
        with torch.no_grad():
            self.weight.zero_() # @Grad62304977 and others

    def forward(self, x: Tensor):
        if self.use_fp8 and self.training:
            _x = x.flatten(0, -2)
            out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0]
            return out.reshape(*x.shape[:-1], -1)
        else:
            return F.linear(x, self.weight.type_as(x))
class CausalSelfAttentionBase(nn.Module):
    def __init__(self, dim: int, head_dim: int, num_heads: int):
        super().__init__()
        self.call_super_init = True
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dim = dim
        self.hdim = num_heads * head_dim
        assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim"
        std = self.dim ** -0.5
        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
        # merged QKVO weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
        # https://x.com/hi_tysam/status/1879699187107033311
        # Simplified layout by @chrisjmccormick
        self.qkvo_w = nn.Parameter(torch.empty(self.dim * 4, self.hdim))
        # label all modules for explicit optimizer grouping
        self.qkvo_w.label = 'attn'
        with torch.no_grad():
            self.qkvo_w[:self.dim * 3].uniform_(-bound, bound) # init QKV weights
            self.qkvo_w[self.dim * 3:].zero_() # init O weights to zero
        # sparse gated attention to enable context based no-op by @classiclarryd
        self.attn_gate = CastedLinear(12, num_heads)
        self.attn_gate.weight.label = 'attn_gate'
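# Shape note (using the benchmark config below: dim = hdim = 768, 6 heads of 128):
# qkvo_w is (dim*4, hdim) = (3072, 768). Rows [0, dim*3) hold the fused Q, K, V
# projections and rows [dim*3, dim*4) hold the output projection, so one parameter
# tensor covers all four attention matmuls and gets a single optimizer group label.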
class CausalSelfAttentionOrig(CausalSelfAttentionBase):
    def forward(self, x: Tensor, attn_args: AttnArgs):
        B, T = x.size(0), x.size(1) # batch size, sequence length
        assert B == 1, "varlen sequences require B == 1"
        assert T % 16 == 0
        # unpack attention args
        cos, sin = attn_args.cos, attn_args.sin
        ve, sa_lambdas, key_shift = attn_args.ve, attn_args.sa_lambdas, attn_args.key_shift
        seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size
        q, k, v = F.linear(x, sa_lambdas[0] * self.qkvo_w[:self.dim * 3].type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
        q, k = norm(q), norm(k) # QK norm @Grad62304977
        q, k = rotary(q, cos, sin), rotary(k, cos, sin)
        if key_shift:
            # shift keys forward for the stationary head dims. Enables 1-layer induction.
            k[:, 1:, :, self.head_dim//4:self.head_dim//2] = k[:, :-1, :, self.head_dim//4:self.head_dim//2]
            k[:, 1:, :, self.head_dim//4+self.head_dim//2:] = k[:, :-1, :, self.head_dim//4+self.head_dim//2:]
        if ve is not None:
            v = v + ve.view_as(v) # @KoszarskyB & @Grad62304977
        max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))
        # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng
        y: Tensor = varlen_attn(
            q[0],
            k[0],
            v[0],
            max_seqlen_q=max_len,
            max_seqlen_k=max_len,
            cum_seq_q=seqlens,
            cum_seq_k=seqlens,
            causal=True,
            scale=attn_scale,
            window_size_left=bm_size,
            window_size_right=0,
        )
        y = y.view(B, T, self.num_heads, self.head_dim)
        y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1)
        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
        y = F.linear(y, sa_lambdas[1] * self.qkvo_w[self.dim * 3:].type_as(y)) # sa_lambdas[1] pre-multiplied to O @shenberg
        return y
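# Note on the varlen_attn call above (under FlashAttention's local-attention convention,
# which I assume here): window_size_left=bm_size with window_size_right=0 gives causal
# sliding-window attention where each query attends to at most the previous bm_size
# tokens plus itself, and the cu_seqlens boundaries additionally stop attention at
# document edges.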
class CausalSelfAttentionNext(CausalSelfAttentionOrig): ...
head_dim=128
with torch.device('cuda'):
    yarn = Yarn(head_dim, max_seq_len)
hp_dtype=torch.bfloat16
dim=768
# seqlens=torch.tensor((0, args.train_max_seq_len), dtype=torch.int32, device=device)
avg_seqlen=400 # median doc length is ~400
minibsz=args.train_bs_schedule[0]
pergpu_minibsz=minibsz//grad_accum_steps
microbsz = pergpu_minibsz // world_size
max_num_docs = next_multiple_of_n(microbsz // 300, n=128)
seqlens=torch.arange(0, avg_seqlen*max_num_docs, avg_seqlen, dtype=torch.int32, device=device).clamp_max_(microbsz)
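# With the defaults above: grad_accum_steps = 1, microbsz = 131072 // 8 = 16384 tokens per
# micro-batch, and max_num_docs = 128. seqlens is therefore a cu_seqlens-style boundary
# tensor [0, 400, 800, ..., 16000, 16384, 16384, ...]: forty 400-token "documents", one
# 384-token tail, then trailing zero-length entries once clamp_max_ kicks in.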
ve = torch.randn((microbsz, dim), device=device, dtype=hp_dtype, requires_grad=True)
sa_lambdas = torch.tensor((.5, 1.), device=device, requires_grad=True)
short_bm=128
num_heads=6
with torch.device('meta'):
    orig = CausalSelfAttentionOrig(
        dim=dim,
        head_dim=head_dim,
        num_heads=num_heads,
    )
    next = CausalSelfAttentionNext(
        dim=dim,
        head_dim=head_dim,
        num_heads=num_heads,
    )
seed=42
gen=torch.Generator(device)
loss_fn = nn.MSELoss()
for attn in (orig, next):
    attn.to_empty(device=device)
    attn.qkvo_w.data.normal_(std=attn.dim**-.5, generator=gen.manual_seed(seed))
    attn.attn_gate.weight.data.normal_(std=attn.attn_gate.in_features**-.5, generator=gen.manual_seed(seed))
orig.forward = torch.compile(orig.forward, dynamic=False, fullgraph=True)
input = torch.randn((1, microbsz, dim), device=device, dtype=hp_dtype, generator=gen.manual_seed(seed+1), requires_grad=True)
target = torch.randn((1, microbsz, dim), device=device, dtype=hp_dtype, generator=gen.manual_seed(seed+2))
def do_fwd(mod: CausalSelfAttentionBase):
    attn_args = AttnArgs(
        ve=ve.clone(),
        sa_lambdas=sa_lambdas.clone(),
        seqlens=seqlens,
        bm_size=short_bm,
        cos=yarn.cos,
        sin=yarn.sin,
        attn_scale=yarn.attn_scale,
        # attn_scale=head_dim**-.5,
        key_shift=False,
    )
    out: Tensor = mod(input.clone(), attn_args=attn_args)
    return out

def do_lossbwd(out: Tensor):
    loss: Tensor = loss_fn(out, target)
    loss.backward()
    return loss

def do_fwdbwd(mod: CausalSelfAttentionBase):
    out: Tensor = do_fwd(mod)
    do_lossbwd(out)

if test_correctness := True:
    out_orig = do_fwd(orig)
    out_next = do_fwd(next)
    assert_close(out_orig, out_next)
    do_lossbwd(out_orig)
    do_lossbwd(out_next)
    assert_close(orig.qkvo_w.grad, next.qkvo_w.grad)
    assert_close(orig.attn_gate.weight.grad, next.attn_gate.weight.grad)

if test_latency := True:
    orig_ms: float = do_bench(partial(do_fwdbwd, mod=orig))
    next_ms: float = do_bench(partial(do_fwdbwd, mod=next))
    orig_its: float = 1000 / orig_ms
    next_its: float = 1000 / next_ms
    print(f"""
orig: {orig_its:.2f} it/s
next: {next_its:.2f} it/s
""")
pass
@Birch-san (Author):
launch.json for modded-nanogpt (including this module, bench_attn.py):

{
    "configurations": [
        {
            "name": "Train GPT (single-device)",
            "type": "debugpy",
            "request": "launch",
            "module": "train_gpt",
            "env": {
                "RANK": "0",
                "LOCAL_RANK": "0",
                "WORLD_SIZE": "8",
                "TORCH_COMPILE_DISABLE": "1",
            }
        },
        {
            "name": "Cached Fineweb 10B (900M toks)",
            "type": "debugpy",
            "request": "launch",
            "module": "data.cached_fineweb10B",
            "args": ["9"],
        },
        {
            "name": "Benchmark Attention",
            "type": "debugpy",
            "request": "launch",
            "module": "bench_attn",
            "justMyCode": false,
            "env": {
                "PYTORCH_ALLOC_CONF": "expandable_segments:True",
            }
        },
    ]
}
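For reference, a roughly equivalent command-line invocation of the benchmark (assuming this gist is saved as bench_attn.py in the modded-nanogpt repo root, as the launch config implies) would be:

    PYTORCH_ALLOC_CONF=expandable_segments:True python -m bench_attn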
