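"""Fused Triton kernels for a streaming "gather logsumexp" objective.

For each row i of x (shape [X, K, 2]) and each K-bit integer code y_j, the
score is

    z[i, j] = scale * sum_k x[i, k, bit_k(y_j)],

i.e. bit k of y_j selects one of the two entries of x at position k. Given
labels a (one per row of x) and b (one per code y), the kernels compute

    o[i] = softplus( LSE_{j: b_j != a_i}( z[i, j])
                   + LSE_{j: b_j == a_i}(-z[i, j])
                   - log #{j: b_j == a_i} ),

a contrastive objective in the spirit of circle loss / multi-similarity loss,
computed without ever materializing the full [X, Y] score matrix: both the
forward and backward passes stream over y in blocks, FlashAttention-style.
"""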
from typing import Any

import torch
import triton
import triton.language as tl
from torch import Tensor
from torch.nn import functional as F
from torch.testing import assert_close
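

# Online (streaming) logsumexp update, as in FlashAttention's online softmax.
# The pair (m, l) encodes the logsumexp of everything seen so far as
# log(l) + m, where m is a running maximum and l a running rescaled sum.
# Each call folds in one block of scores z:
#     m' = max(m, max_j z_j)
#     l' = l * exp(m - m') + sum_j exp(z_j - m')
# Masked-out entries are set to -inf and contribute exp(-inf) = 0.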
@triton.jit
def add_lse(m, l, mask, z):
    z = tl.where(mask, z, -float("inf"))
    m, mp = tl.maximum(tl.max(z, axis=-1), m), m  # [x]
    l = l * tl.exp(mp - m) + tl.sum(tl.exp(z - m[:, None]), axis=-1)  # [x]
    return m, l
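

# Numerically stable softplus: softplus(x) = log(1 + e^x)
#                                          = max(x, 0) + log(1 + e^(-|x|)),
# which avoids overflow for large positive x.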
@triton.jit
def softplus_fwd(x):
    return tl.maximum(x, 0.) + tl.log(1. + tl.exp(-tl.abs(x)))
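

# Derivative of softplus expressed through its *output*: if o = softplus(x),
# then e^o = 1 + e^x, so sigmoid(x) = e^x / (1 + e^x) = 1 - e^(-o).
# This lets the backward pass reuse the saved forward output o instead of
# recomputing the pre-activation.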
@triton.jit
def softplus_bwd(x):
    return 1. - tl.exp(-x)
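

# Forward kernel. Each program handles BLOCK_X rows of x and streams over all
# of y in BLOCK_Y-sized chunks, maintaining two online logsumexps per row:
#     n = LSE over negatives (b != a) of  z
#     p = LSE over positives (b == a) of -z
# plus, if use_count, c = number of positives, so the stored output is
# softplus(n + p - log(c)).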
@triton.autotune(
    configs=[
        triton.Config(dict(BLOCK_X=BLOCK_X, BLOCK_Y=BLOCK_Y), num_stages=num_stages)
        for BLOCK_X in [16]
        for BLOCK_Y in [128]
        for num_stages in [3]
    ],
    key=[],
)
@triton.jit
def gather_lse_fwd_kernel(
        x_ptr: tl.tensor, x_s0: int, x_s1: int, x_s2: int,
        y_ptr: tl.tensor, y_s0: int,
        a_ptr: tl.tensor, a_s0: int,
        b_ptr: tl.tensor, b_s0: int,
        n_ptr: tl.tensor, n_s0: int,
        p_ptr: tl.tensor, p_s0: int,
        o_ptr: tl.tensor, o_s0: int,
        X: int, Y: int, K: int, scale: float, use_count: tl.constexpr,
        BLOCK_X: tl.constexpr,
        BLOCK_Y: tl.constexpr,
        BLOCK_K: tl.constexpr,
):
    x_block = tl.make_block_ptr(
        base=x_ptr,
        shape=(X, K, 2),
        strides=(x_s0, x_s1, x_s2),
        offsets=(BLOCK_X * tl.program_id(0), 0, 0),
        block_shape=(BLOCK_X, BLOCK_K, 2),
        order=(2, 1, 0),
    )
    y_block = tl.make_block_ptr(
        base=y_ptr,
        shape=(Y,),
        strides=(y_s0,),
        offsets=(0,),
        block_shape=(BLOCK_Y,),
        order=(0,),
    )
    a_block = tl.make_block_ptr(
        base=a_ptr,
        shape=(X,),
        strides=(a_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )
    b_block = tl.make_block_ptr(
        base=b_ptr,
        shape=(Y,),
        strides=(b_s0,),
        offsets=(0,),
        block_shape=(BLOCK_Y,),
        order=(0,),
    )
    n_block = tl.make_block_ptr(
        base=n_ptr,
        shape=(X,),
        strides=(n_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )
    p_block = tl.make_block_ptr(
        base=p_ptr,
        shape=(X,),
        strides=(p_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )
    o_block = tl.make_block_ptr(
        base=o_ptr,
        shape=(X,),
        strides=(o_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )
    x = tl.load(x_block, boundary_check=(0, 1, 2), padding_option="zero")  # [x, k, 2]
    x = tl.reshape(x, (BLOCK_X, BLOCK_K * 2)) * scale
    a = tl.load(a_block, boundary_check=(0,))  # [x]
    # Expand each integer code y into a one-hot [k, 2] selector: xor flips the
    # extracted bit in column 0 (and zeroes out padding rows with k >= K), so
    # column 0 holds NOT bit_k(y) and column 1 holds bit_k(y). The dot product
    # x @ y below therefore sums x[i, k, bit_k(y_j)] over k.
    bit = tl.arange(0, BLOCK_K).to(dtype=tl.int32)
    xor = tl.arange(0, 2).to(dtype=tl.int32) ^ 1
    xor = (bit[:, None] < K) & xor[None, :]
    # Initial running maxima; any finite value is mathematically valid since
    # the running sums start at zero (log(l) + m is invariant to the starting
    # m), and -scale * K keeps the first rescaling well-behaved.
    p_max = tl.full((BLOCK_X,), dtype=tl.float32, value=-scale * K)
    n_max = tl.full((BLOCK_X,), dtype=tl.float32, value=-scale * K)
    n = tl.zeros((BLOCK_X,), dtype=tl.float32)
    p = tl.zeros((BLOCK_X,), dtype=tl.float32)
    if use_count:
        c = tl.zeros((BLOCK_X,), dtype=tl.float32)
    xi = tl.arange(0, BLOCK_X)
    yi = tl.arange(0, BLOCK_Y)
    for _ in tl.range(tl.cdiv(Y, BLOCK_Y)):
        y = tl.load(y_block, boundary_check=(0,), padding_option="zero").to(dtype=tl.int32)
        y = ((y[None, None, :] >> bit[:, None, None]) & 1) ^ xor[:, :, None]  # [k, 2, y]
        y = tl.reshape(y, (BLOCK_K * 2, BLOCK_Y)).to(dtype=x.dtype)
        b = tl.load(b_block, boundary_check=(0,))  # [y]
        z = tl.dot(x, y, input_precision="ieee")  # [x, y]
        mask = (xi[:, None] < X) & (yi[None, :] < Y)
        n_mask = mask & (a[:, None] != b[None, :])  # negatives: labels differ
        p_mask = mask & (a[:, None] == b[None, :])  # positives: labels match
        n_max, n = add_lse(n_max, n, n_mask, z)
        p_max, p = add_lse(p_max, p, p_mask, -z)
        if use_count:
            c += tl.sum(p_mask.to(dtype=tl.float32), axis=-1)
        y_block = tl.advance(y_block, (BLOCK_Y,))
        b_block = tl.advance(b_block, (BLOCK_Y,))
        yi += BLOCK_Y
    # Finalize the two online logsumexps.
    n = tl.log(n) + n_max
    p = tl.log(p) + p_max
    tl.store(n_block, n.to(dtype=n_ptr.dtype.element_ty), boundary_check=(0,))
    tl.store(p_block, p.to(dtype=p_ptr.dtype.element_ty), boundary_check=(0,))
    if use_count:
        tl.store(o_block, softplus_fwd(n + p - tl.log(c)).to(dtype=o_ptr.dtype.element_ty), boundary_check=(0,))
    else:
        tl.store(o_block, softplus_fwd(n + p).to(dtype=o_ptr.dtype.element_ty), boundary_check=(0,))


def gather_lse_fwd(x: Tensor, y: Tensor, a: Tensor, b: Tensor, scale: float, use_count: bool):
    # Per-row outputs: n and p hold the two logsumexps, o the final loss.
    n = torch.empty_like(x[:, 0, 0])
    p = torch.empty_like(x[:, 0, 0])
    o = torch.empty_like(x[:, 0, 0])

    def grid(meta):
        return (
            triton.cdiv(x.size(0), meta['BLOCK_X']),
        )

    gather_lse_fwd_kernel[grid](
        x, *x.stride(),
        y, *y.stride(),
        a, *a.stride(),
        b, *b.stride(),
        n, *n.stride(),
        p, *p.stride(),
        o, *o.stride(),
        x.size(0),
        y.size(0),
        x.size(1),
        scale, use_count,
        BLOCK_K=triton.next_power_of_2(x.size(1)),
    )
    return n, p, o
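

# Backward pass. With o = softplus(n + p - log c), the upstream gradient g is
# first scaled by softplus'(.) = 1 - exp(-o) (see softplus_bwd). Then, since
#     dn/dz_j =  exp( z_j - n)  for negatives, and
#     dp/dz_j = -exp(-z_j - p)  for positives,
# and dz_j/dx is scale times the one-hot selector built from y_j, the gradient
# w.r.t. x is
#     dx = g * [ sum_neg exp( z_j - n) - sum_pos exp(-z_j - p) ] * scale * y_j,
# accumulated with two tl.dot calls per y-block below. The count term log(c)
# does not depend on x, so it only shifts the softplus argument.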
@triton.autotune(
    configs=[
        triton.Config(dict(BLOCK_X=BLOCK_X, BLOCK_Y=BLOCK_Y), num_stages=num_stages)
        for BLOCK_X in [16]
        for BLOCK_Y in [128]
        for num_stages in [3]
    ],
    key=[],
)
@triton.jit
def gather_lse_bwd_kernel(
        x_ptr: tl.tensor, x_s0: int, x_s1: int, x_s2: int,
        d_ptr: tl.tensor, d_s0: int, d_s1: int, d_s2: int,
        y_ptr: tl.tensor, y_s0: int,
        a_ptr: tl.tensor, a_s0: int,
        b_ptr: tl.tensor, b_s0: int,
        n_ptr: tl.tensor, n_s0: int,
        p_ptr: tl.tensor, p_s0: int,
        o_ptr: tl.tensor, o_s0: int,
        g_ptr: tl.tensor, g_s0: int,
        X: int, Y: int, K: int, scale: float,
        BLOCK_X: tl.constexpr,
        BLOCK_Y: tl.constexpr,
        BLOCK_K: tl.constexpr,
):
    x_block = tl.make_block_ptr(
        base=x_ptr,
        shape=(X, K, 2),
        strides=(x_s0, x_s1, x_s2),
        offsets=(BLOCK_X * tl.program_id(0), 0, 0),
        block_shape=(BLOCK_X, BLOCK_K, 2),
        order=(2, 1, 0),
    )
    d_block = tl.make_block_ptr(
        base=d_ptr,
        shape=(X, K, 2),
        strides=(d_s0, d_s1, d_s2),
        offsets=(BLOCK_X * tl.program_id(0), 0, 0),
        block_shape=(BLOCK_X, BLOCK_K, 2),
        order=(2, 1, 0),
    )
    y_block = tl.make_block_ptr(
        base=y_ptr,
        shape=(Y,),
        strides=(y_s0,),
        offsets=(0,),
        block_shape=(BLOCK_Y,),
        order=(0,),
    )
    a_block = tl.make_block_ptr(
        base=a_ptr,
        shape=(X,),
        strides=(a_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )
    b_block = tl.make_block_ptr(
        base=b_ptr,
        shape=(Y,),
        strides=(b_s0,),
        offsets=(0,),
        block_shape=(BLOCK_Y,),
        order=(0,),
    )
    n_block = tl.make_block_ptr(
        base=n_ptr,
        shape=(X,),
        strides=(n_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )
    p_block = tl.make_block_ptr(
        base=p_ptr,
        shape=(X,),
        strides=(p_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )
    o_block = tl.make_block_ptr(
        base=o_ptr,
        shape=(X,),
        strides=(o_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )
    g_block = tl.make_block_ptr(
        base=g_ptr,
        shape=(X,),
        strides=(g_s0,),
        offsets=(BLOCK_X * tl.program_id(0),),
        block_shape=(BLOCK_X,),
        order=(0,),
    )
    x = tl.load(x_block, boundary_check=(0, 1, 2), padding_option="zero")  # [x, k, 2]
    x = tl.reshape(x, (BLOCK_X, BLOCK_K * 2))
    a = tl.load(a_block, boundary_check=(0,))  # [x]
    bit = tl.arange(0, BLOCK_K).to(dtype=tl.int32)
    xor = tl.arange(0, 2).to(dtype=tl.int32) ^ 1
    xor = (bit[:, None] < K) & xor[None, :]
    n = tl.load(n_block, boundary_check=(0,), padding_option="zero")  # [x]
    p = tl.load(p_block, boundary_check=(0,), padding_option="zero")  # [x]
    g = tl.load(g_block, boundary_check=(0,), padding_option="zero")  # [x]
    o = tl.load(o_block, boundary_check=(0,), padding_option="zero")  # [x]
    # Chain rule through the final softplus, using the saved output o.
    g = g * softplus_bwd(o)
    d = tl.zeros((BLOCK_X, BLOCK_K * 2), dtype=tl.float32)
    xi = tl.arange(0, BLOCK_X)
    yi = tl.arange(0, BLOCK_Y)
    for _ in tl.range(tl.cdiv(Y, BLOCK_Y)):
        y = tl.load(y_block, boundary_check=(0,), padding_option="zero").to(dtype=tl.int32)
        y = ((y[None, None, :] >> bit[:, None, None]) & 1) ^ xor[:, :, None]  # [k, 2, y]
        # Unlike the forward pass, scale is folded into y here, so both
        # z == x @ y and dz/dx == y come out in one step.
        y = tl.reshape(y, (BLOCK_K * 2, BLOCK_Y)).to(dtype=x.dtype) * scale
        b = tl.load(b_block, boundary_check=(0,))  # [y]
        mask = (xi[:, None] < X) & (yi[None, :] < Y)
        z = tl.dot(x, y, input_precision="ieee")  # [x, y]
        z1 = tl.where(mask & (a[:, None] != b[None, :]), z, -float("inf"))
        z2 = tl.where(mask & (a[:, None] == b[None, :]), -z, -float("inf"))
        # Softmax weights exp(z - lse) for each branch, times the upstream
        # gradient, contracted against y to accumulate d = dloss/dx.
        d = tl.dot(g[:, None] * tl.exp(z1 - n[:, None]), tl.trans(y), acc=d, input_precision="ieee")  # [x, k * 2]
        d = tl.dot(g[:, None] * tl.exp(z2 - p[:, None]), -tl.trans(y), acc=d, input_precision="ieee")  # [x, k * 2]
        y_block = tl.advance(y_block, (BLOCK_Y,))
        b_block = tl.advance(b_block, (BLOCK_Y,))
        yi += BLOCK_Y
    d = tl.reshape(d, (BLOCK_X, BLOCK_K, 2))
    tl.store(d_block, d.to(dtype=d_ptr.dtype.element_ty), boundary_check=(0, 1, 2))


def gather_lse_bwd(
        x: Tensor, y: Tensor, a: Tensor, b: Tensor, scale: float,
        n: Tensor, p: Tensor, o: Tensor, g: Tensor):
    d = torch.empty_like(x)

    def grid(meta):
        return (
            triton.cdiv(x.size(0), meta['BLOCK_X']),
        )

    gather_lse_bwd_kernel[grid](
        x, *x.stride(),
        d, *d.stride(),
        y, *y.stride(),
        a, *a.stride(),
        b, *b.stride(),
        n, *n.stride(),
        p, *p.stride(),
        o, *o.stride(),
        g, *g.stride(),
        x.size(0),
        y.size(0),
        x.size(1),
        scale,
        BLOCK_K=triton.next_power_of_2(x.size(1)),
    )
    return d
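

# Autograd glue: the forward saves the per-row logsumexps (n, p) and the
# output o, so the backward can rebuild the softmax weights without storing
# the [X, Y] score matrix. use_count is fixed to True here.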
class GatherLse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor, a: Tensor, b: Tensor, scale: float) -> Tensor:
        n, p, o = gather_lse_fwd(x, y, a, b, scale, True)
        ctx.save_for_backward(x, y, a, b, n, p, o)
        ctx.scale = scale
        return o

    @staticmethod
    def backward(ctx: Any, g: Tensor):
        x, y, a, b, n, p, o = ctx.saved_tensors
        # Gradients only flow to x; the remaining inputs get None.
        return gather_lse_bwd(x, y, a, b, ctx.scale, n, p, o, g), None, None, None, None


def gather_lse(x: Tensor, y: Tensor, a: Tensor, b: Tensor, scale: float) -> Tensor:
    return GatherLse.apply(x, y, a, b, scale)
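

# Dense PyTorch reference: materializes the full [X, Y] score matrix via
# gather (O(X * Y * K) gathered elements), so it is only suitable for testing.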
def expected_gather_lse(x: Tensor, y: Tensor, a: Tensor, b: Tensor, scale: float) -> Tensor:
    index = torch.arange(x.size(1), device=x.device)  # [k]
    y = (y[:, None] >> index) & 1  # [y, k]
    s = torch.gather(
        input=x[:, None, :, :].expand((-1, y.size(0), -1, -1)),
        index=y[None, :, :, None].expand((x.size(0), -1, -1, -1)),
        dim=-1,
    )
    s = s.sum(dim=[-1, -2]).mul(scale)
    n = torch.masked_fill(s, a[:, None] == b[None, :], -float("inf"))
    p = torch.masked_fill(-s, a[:, None] != b[None, :], -float("inf"))
    c = (a[:, None] == b[None, :]).float().sum(dim=-1)
    return F.softplus(n.logsumexp(dim=-1) + p.logsumexp(dim=-1) - c.log())
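

# Correctness check against the dense reference. The first X entries of b are
# a copy of a, so every row has at least one positive and log(c) stays finite;
# K = 31 keeps the codes within int32 range for the kernel's bit tricks.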
if __name__ == '__main__':
    X = 1000
    Y = 20000
    K = 31
    V = 1024
    scale = 1
    x = torch.rand((X, K, 2), requires_grad=True, device='cuda')
    y = torch.randint(1 << K, (Y,), device='cuda')
    a = torch.randint(V, (X,), device='cuda')
    b = torch.cat([a, torch.randint(V, (Y - X,), device='cuda')], dim=0)
    actual = gather_lse(x, y, a, b, scale)
    expected = expected_gather_lse(x, y, a, b, scale)
    assert_close(actual, expected)
    grad = torch.rand_like(actual)
    actual_grad, = torch.autograd.grad(actual, x, grad)
    expected_grad, = torch.autograd.grad(expected, x, grad)
    assert_close(actual_grad, expected_grad)
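
    # A rough timing sketch of fused vs. dense (assumes triton.testing.do_bench,
    # available in recent Triton releases; illustrative, not part of the
    # correctness test above).
    fused_ms = triton.testing.do_bench(lambda: gather_lse(x, y, a, b, scale))
    dense_ms = triton.testing.do_bench(lambda: expected_gather_lse(x, y, a, b, scale))
    print(f'fused: {fused_ms:.3f} ms, dense reference: {dense_ms:.3f} ms')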