@speedcell4
Created December 21, 2025 13:26
from typing import Any
import torch
import triton
import triton.language as tl
from torch import Tensor
from torch.nn import functional as F
from torch.testing import assert_close
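
# Triton implementation of a "gather log-sum-exp" contrastive score.
# For each query i and candidate code y[j], the score is
#     z[i, j] = scale * sum_k x[i, k, bit_k(y[j])],
# where bit_k(y[j]) is the k-th bit of the integer y[j]. Candidates with a
# different label (a[i] != b[j]) enter a log-sum-exp over +z, candidates with
# the same label (a[i] == b[j]) enter a log-sum-exp over -z, and the output is
#     softplus(logsumexp_neg + logsumexp_pos - log(#positives)).
# The kernels below stream over the candidates in blocks of BLOCK_Y, so the
# full [X, Y] score matrix is never materialised.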


@triton.jit
def add_lse(m, l, mask, z):
    # Online (streaming) log-sum-exp update: `m` is the running maximum and
    # `l` the running sum of exp(z - m); masked-out entries contribute nothing.
    z = tl.where(mask, z, -float("inf"))
    m, mp = tl.maximum(tl.max(z, axis=-1), m), m  # [x]
    l = l * tl.exp(mp - m) + tl.sum(tl.exp(z - m[:, None]), axis=-1)  # [x]
    return m, l


@triton.jit
def softplus_fwd(x):
    # Numerically stable softplus: log(1 + exp(x)) = max(x, 0) + log(1 + exp(-|x|)).
    return tl.maximum(x, 0.) + tl.log(1. + tl.exp(-tl.abs(x)))


@triton.jit
def softplus_bwd(x):
    # Derivative of softplus expressed in terms of its *output* o = softplus(x):
    # sigmoid(x) = 1 - exp(-softplus(x)).
    return 1. - tl.exp(-x)
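

# Both kernels below expand every candidate integer into a one-hot pair per bit,
# so the bit-indexed gather over x turns into a single tl.dot between x reshaped
# to [BLOCK_X, BLOCK_K * 2] and the expanded codes of shape [BLOCK_K * 2, BLOCK_Y].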
@triton.autotune(
    configs=[
        triton.Config(dict(BLOCK_X=BLOCK_X, BLOCK_Y=BLOCK_Y), num_stages=num_stages)
        for BLOCK_X in [16]
        for BLOCK_Y in [128]
        for num_stages in [3]
    ],
    key=[],
)
@triton.jit
def gather_lse_fwd_kernel(
        x_ptr: tl.tensor, x_s0: int, x_s1: int, x_s2: int,
        y_ptr: tl.tensor, y_s0: int,
        a_ptr: tl.tensor, a_s0: int,
        b_ptr: tl.tensor, b_s0: int,
        n_ptr: tl.tensor, n_s0: int,
        p_ptr: tl.tensor, p_s0: int,
        o_ptr: tl.tensor, o_s0: int,
        X: int, Y: int, K: int, scale: float, use_count: tl.constexpr,
        BLOCK_X: tl.constexpr,
        BLOCK_Y: tl.constexpr,
        BLOCK_K: tl.constexpr,
):
    x_block = tl.make_block_ptr(
        base=x_ptr,
        shape=(X, K, 2),
        strides=(x_s0, x_s1, x_s2),
        offsets=(BLOCK_X * tl.program_id(0), 0, 0),
        block_shape=(BLOCK_X, BLOCK_K, 2),
        order=(2, 1, 0),
    )
    y_block = tl.make_block_ptr(
        base=y_ptr,
        shape=(Y,),
        strides=(y_s0,),
        offsets=(0,),
        block_shape=(BLOCK_Y,),
        order=(0,),
    )
    a_block = tl.make_block_ptr(
        base=a_ptr,
        shape=(X,),
        strides=(a_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )
    b_block = tl.make_block_ptr(
        base=b_ptr,
        shape=(Y,),
        strides=(b_s0,),
        offsets=(0,),
        block_shape=(BLOCK_Y,),
        order=(0,),
    )
    n_block = tl.make_block_ptr(
        base=n_ptr,
        shape=(X,),
        strides=(n_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )
    p_block = tl.make_block_ptr(
        base=p_ptr,
        shape=(X,),
        strides=(p_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )
    o_block = tl.make_block_ptr(
        base=o_ptr,
        shape=(X,),
        strides=(o_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )

    x = tl.load(x_block, boundary_check=(0, 1, 2), padding_option="zero")  # [x, k, 2]
    x = tl.reshape(x, (BLOCK_X, BLOCK_K * 2)) * scale
    a = tl.load(a_block, boundary_check=(0,))  # [x]

    # Per-bit selectors: inside the loop, xor turns each bit b of a candidate into the
    # one-hot pair (1 - b, b) over the last axis of x; for padded bits (k >= K) the
    # corresponding x entries are zero-padded, so they contribute nothing.
    bit = tl.arange(0, BLOCK_K).to(dtype=tl.int32)
    xor = tl.arange(0, 2).to(dtype=tl.int32) ^ 1
    xor = (bit[:, None] < K) & xor[None, :]

    # Running (max, sum) state of the streaming log-sum-exp over negatives (n) and positives (p).
    p_max = tl.full((BLOCK_X,), dtype=tl.float32, value=-scale * K)
    n_max = tl.full((BLOCK_X,), dtype=tl.float32, value=-scale * K)
    n = tl.zeros((BLOCK_X,), dtype=tl.float32)
    p = tl.zeros((BLOCK_X,), dtype=tl.float32)
    if use_count:
        c = tl.zeros((BLOCK_X,), dtype=tl.float32)

    xi = tl.arange(0, BLOCK_X)
    yi = tl.arange(0, BLOCK_Y)

    for _ in tl.range(tl.cdiv(Y, BLOCK_Y)):
        y = tl.load(y_block, boundary_check=(0,), padding_option="zero").to(dtype=tl.int32)
        y = ((y[None, None, :] >> bit[:, None, None]) & 1) ^ xor[:, :, None]  # [k, 2, y]
        y = tl.reshape(y, (BLOCK_K * 2, BLOCK_Y)).to(dtype=x.dtype)
        b = tl.load(b_block, boundary_check=(0,))  # [y]

        # z[i, j] = scale * sum_k x[i, k, bit_k(y[j])]
        z = tl.dot(x, y, input_precision="ieee")  # [x, y]

        mask = (xi[:, None] < X) & (yi[None, :] < Y)
        n_mask = mask & (a[:, None] != b[None, :])
        p_mask = mask & (a[:, None] == b[None, :])

        # Different-label candidates accumulate +z, same-label candidates accumulate -z.
        n_max, n = add_lse(n_max, n, n_mask, z)
        p_max, p = add_lse(p_max, p, p_mask, -z)
        if use_count:
            c += tl.sum(p_mask.to(dtype=tl.float32), axis=-1)

        y_block = tl.advance(y_block, (BLOCK_Y,))
        b_block = tl.advance(b_block, (BLOCK_Y,))
        yi += BLOCK_Y

    n = tl.log(n) + n_max
    p = tl.log(p) + p_max

    tl.store(n_block, n.to(dtype=n_ptr.dtype.element_ty), boundary_check=(0,))
    tl.store(p_block, p.to(dtype=p_ptr.dtype.element_ty), boundary_check=(0,))
    if use_count:
        tl.store(o_block, softplus_fwd(n + p - tl.log(c)).to(dtype=o_ptr.dtype.element_ty), boundary_check=(0,))
    else:
        tl.store(o_block, softplus_fwd(n + p).to(dtype=o_ptr.dtype.element_ty), boundary_check=(0,))


def gather_lse_fwd(x: Tensor, y: Tensor, a: Tensor, b: Tensor, scale: float, use_count: bool):
    n = torch.empty_like(x[:, 0, 0])
    p = torch.empty_like(x[:, 0, 0])
    o = torch.empty_like(x[:, 0, 0])

    def grid(meta):
        return (
            triton.cdiv(x.size(0), meta['BLOCK_X']),
        )

    gather_lse_fwd_kernel[grid](
        x, *x.stride(),
        y, *y.stride(),
        a, *a.stride(),
        b, *b.stride(),
        n, *n.stride(),
        p, *p.stride(),
        o, *o.stride(),
        x.size(0),
        y.size(0),
        x.size(1),
        scale, use_count,
        BLOCK_K=triton.next_power_of_2(x.size(1)),
    )

    return n, p, o
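

# Backward pass: with o = softplus(n + p - log c), the chain rule gives
#     dL/dz[i, j] = g[i] * sigmoid(n + p - log c)[i]
#                   * (exp(z[i, j] - n[i])    for different-label candidates,
#                      -exp(-z[i, j] - p[i])  for same-label candidates),
# and dz[i, j]/dx is the (scaled) one-hot expansion of y[j], so the gradient is
# again accumulated with block-wise tl.dot calls. log c does not depend on x,
# so it contributes no extra term.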
@triton.autotune(
    configs=[
        triton.Config(dict(BLOCK_X=BLOCK_X, BLOCK_Y=BLOCK_Y), num_stages=num_stages)
        for BLOCK_X in [16]
        for BLOCK_Y in [128]
        for num_stages in [3]
    ],
    key=[],
)
@triton.jit
def gather_lse_bwd_kernel(
        x_ptr: tl.tensor, x_s0: int, x_s1: int, x_s2: int,
        d_ptr: tl.tensor, d_s0: int, d_s1: int, d_s2: int,
        y_ptr: tl.tensor, y_s0: int,
        a_ptr: tl.tensor, a_s0: int,
        b_ptr: tl.tensor, b_s0: int,
        n_ptr: tl.tensor, n_s0: int,
        p_ptr: tl.tensor, p_s0: int,
        o_ptr: tl.tensor, o_s0: int,
        g_ptr: tl.tensor, g_s0: int,
        X: int, Y: int, K: int, scale: float,
        BLOCK_X: tl.constexpr,
        BLOCK_Y: tl.constexpr,
        BLOCK_K: tl.constexpr,
):
    x_block = tl.make_block_ptr(
        base=x_ptr,
        shape=(X, K, 2),
        strides=(x_s0, x_s1, x_s2),
        offsets=(BLOCK_X * tl.program_id(0), 0, 0),
        block_shape=(BLOCK_X, BLOCK_K, 2),
        order=(2, 1, 0),
    )
    d_block = tl.make_block_ptr(
        base=d_ptr,
        shape=(X, K, 2),
        strides=(d_s0, d_s1, d_s2),
        offsets=(BLOCK_X * tl.program_id(0), 0, 0),
        block_shape=(BLOCK_X, BLOCK_K, 2),
        order=(2, 1, 0),
    )
    y_block = tl.make_block_ptr(
        base=y_ptr,
        shape=(Y,),
        strides=(y_s0,),
        offsets=(0,),
        block_shape=(BLOCK_Y,),
        order=(0,),
    )
    a_block = tl.make_block_ptr(
        base=a_ptr,
        shape=(X,),
        strides=(a_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )
    b_block = tl.make_block_ptr(
        base=b_ptr,
        shape=(Y,),
        strides=(b_s0,),
        offsets=(0,),
        block_shape=(BLOCK_Y,),
        order=(0,),
    )
    n_block = tl.make_block_ptr(
        base=n_ptr,
        shape=(X,),
        strides=(n_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )
    p_block = tl.make_block_ptr(
        base=p_ptr,
        shape=(X,),
        strides=(p_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )
    o_block = tl.make_block_ptr(
        base=o_ptr,
        shape=(X,),
        strides=(o_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )
    g_block = tl.make_block_ptr(
        base=g_ptr,
        shape=(X,),
        strides=(g_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )

    x = tl.load(x_block, boundary_check=(0, 1, 2), padding_option="zero")  # [x, k, 2]
    x = tl.reshape(x, (BLOCK_X, BLOCK_K * 2))
    a = tl.load(a_block, boundary_check=(0,))  # [x]

    bit = tl.arange(0, BLOCK_K).to(dtype=tl.int32)
    xor = tl.arange(0, 2).to(dtype=tl.int32) ^ 1
    xor = (bit[:, None] < K) & xor[None, :]

    n = tl.load(n_block, boundary_check=(0,), padding_option="zero")  # [x]
    p = tl.load(p_block, boundary_check=(0,), padding_option="zero")  # [x]
    g = tl.load(g_block, boundary_check=(0,), padding_option="zero")  # [x]
    o = tl.load(o_block, boundary_check=(0,), padding_option="zero")  # [x]
    # Fold the softplus derivative into the incoming gradient:
    # d softplus / d input = sigmoid(input) = 1 - exp(-o), with o the softplus output.
    g = g * softplus_bwd(o)

    d = tl.zeros((BLOCK_X, BLOCK_K * 2), dtype=tl.float32)
    xi = tl.arange(0, BLOCK_X)
    yi = tl.arange(0, BLOCK_Y)
    for _ in tl.range(tl.cdiv(Y, BLOCK_Y)):
        y = tl.load(y_block, boundary_check=(0,), padding_option="zero").to(dtype=tl.int32)
        y = ((y[None, None, :] >> bit[:, None, None]) & 1) ^ xor[:, :, None]  # [k, 2, y]
        y = tl.reshape(y, (BLOCK_K * 2, BLOCK_Y)).to(dtype=x.dtype) * scale
        b = tl.load(b_block, boundary_check=(0,))  # [y]

        mask = (xi[:, None] < X) & (yi[None, :] < Y)
        z = tl.dot(x, y, input_precision="ieee")  # [x, y]
        z1 = tl.where(mask & (a[:, None] != b[None, :]), z, -float("inf"))
        z2 = tl.where(mask & (a[:, None] == b[None, :]), -z, -float("inf"))

        # Softmax weights of the two log-sum-exp terms, pushed back through the dot product.
        d = tl.dot(g[:, None] * tl.exp(z1 - n[:, None]), tl.trans(y), acc=d, input_precision="ieee")  # [x, k * 2]
        d = tl.dot(g[:, None] * tl.exp(z2 - p[:, None]), -tl.trans(y), acc=d, input_precision="ieee")  # [x, k * 2]

        y_block = tl.advance(y_block, (BLOCK_Y,))
        b_block = tl.advance(b_block, (BLOCK_Y,))
        yi += BLOCK_Y

    d = tl.reshape(d, (BLOCK_X, BLOCK_K, 2))
    tl.store(d_block, d.to(dtype=d_ptr.dtype.element_ty), boundary_check=(0, 1, 2))


def gather_lse_bwd(
        x: Tensor, y: Tensor, a: Tensor, b: Tensor, scale: float,
        n: Tensor, p: Tensor, o: Tensor, g: Tensor):
    d = torch.empty_like(x)

    def grid(meta):
        return (
            triton.cdiv(x.size(0), meta['BLOCK_X']),
        )

    gather_lse_bwd_kernel[grid](
        x, *x.stride(),
        d, *d.stride(),
        y, *y.stride(),
        a, *a.stride(),
        b, *b.stride(),
        n, *n.stride(),
        p, *p.stride(),
        o, *o.stride(),
        g, *g.stride(),
        x.size(0),
        y.size(0),
        x.size(1),
        scale,
        BLOCK_K=triton.next_power_of_2(x.size(1)),
    )

    return d


class GatherLse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor, a: Tensor, b: Tensor, scale: float) -> Tensor:
        n, p, o = gather_lse_fwd(x, y, a, b, scale, True)
        ctx.save_for_backward(x, y, a, b, n, p, o)
        ctx.scale = scale
        return o

    @staticmethod
    def backward(ctx: Any, g: Tensor):
        x, y, a, b, n, p, o = ctx.saved_tensors
        # Only x receives a gradient; y, a, b are integer codes/labels and scale is a constant.
        return gather_lse_bwd(x, y, a, b, ctx.scale, n, p, o, g), None, None, None, None


def gather_lse(x: Tensor, y: Tensor, a: Tensor, b: Tensor, scale: float) -> Tensor:
    return GatherLse.apply(x, y, a, b, scale)


def expected_gather_lse(x: Tensor, y: Tensor, a: Tensor, b: Tensor, scale: float) -> Tensor:
    # Dense PyTorch reference used for testing: it materialises the full [X, Y] score matrix.
    index = torch.arange(x.size(1), device=x.device)  # [k]
    y = (y[:, None] >> index) & 1  # [y, k]
    s = torch.gather(
        input=x[:, None, :, :].expand((-1, y.size(0), -1, -1)),
        index=y[None, :, :, None].expand((x.size(0), -1, -1, -1)),
        dim=-1,
    )
    s = s.sum(dim=[-1, -2]).mul(scale)
    n = torch.masked_fill(s, a[:, None] == b[None, :], -float("inf"))
    p = torch.masked_fill(-s, a[:, None] != b[None, :], -float("inf"))
    c = (a[:, None] == b[None, :]).float().sum(dim=-1)
    return F.softplus(n.logsumexp(dim=-1) + p.logsumexp(dim=-1) - c.log())


if __name__ == '__main__':
    X = 1000
    Y = 20000
    K = 31
    V = 1024
    scale = 1

    x = torch.rand((X, K, 2), requires_grad=True, device='cuda')
    y = torch.randint(1 << K, (Y,), device='cuda')
    a = torch.randint(V, (X,), device='cuda')
    # Prepending a guarantees every query has at least one same-label candidate (c >= 1).
    b = torch.cat([a, torch.randint(V, (Y - X,), device='cuda')], dim=0)

    actual = gather_lse(x, y, a, b, scale)
    expected = expected_gather_lse(x, y, a, b, scale)
    assert_close(actual, expected)

    grad = torch.rand_like(actual)
    actual_grad, = torch.autograd.grad(actual, x, grad)
    expected_grad, = torch.autograd.grad(expected, x, grad)
    assert_close(actual_grad, expected_grad)
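
# Running this file directly (e.g. `python gather_lse.py`, assuming that filename) checks
# both the forward output and the gradient w.r.t. x against the dense PyTorch reference
# via torch.testing.assert_close; a CUDA device with Triton installed is required.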