@fern89
Created May 7, 2025 11:14
Implementation of Mamba in PyTorch

import torch.optim as optim
import torch.nn as nn
from torchvision import datasets, transforms
import torch, math
import torch.nn.functional as F

def selective_scan(u, dt, A, B, C, D):
    # Parallel selective scan via a cumulative-sum trick: discretize (dA = dt*A,
    # dB_u = dt*B*u), then evaluate the recurrence h_t = exp(dA_t)*h_{t-1} + dB_u_t
    # with cumsum/exp instead of a sequential loop over the sequence.
    dA = torch.einsum('bld,dn->bldn', dt, A)
    dB_u = torch.einsum('bld,bld,bln->bldn', dt, u, B)
    dA = dA.clamp(min=-20)  # clamp for numerical stability of the exp below
    padding = (0, 0, 0, 0, 1, 0)  # shift dA one step along the sequence (dim 1)
    dA_cumsum = F.pad(dA[:, 1:], padding).cumsum(1).exp()
    x = dB_u / (dA_cumsum + 1e-12)  # small epsilon avoids division by zero
    x = x.cumsum(1) * dA_cumsum
    y = torch.einsum('bldn,bln->bld', x, C)
    return y + u * D  # skip connection scaled by D
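
# A minimal shape sanity check (a sketch with hypothetical sizes batch=2, length=5,
# d_in=3, state=4, not taken from the gist): u, dt of shape (b, l, d) and B, C of
# shape (b, l, n) should map to an output of shape (b, l, d).
_u, _dt = torch.randn(2, 5, 3), torch.rand(2, 5, 3)
_A, _D = -torch.rand(3, 4), torch.ones(3)
_B, _C = torch.randn(2, 5, 4), torch.randn(2, 5, 4)
assert selective_scan(_u, _dt, _A, _B, _C, _D).shape == (2, 5, 3)
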
class MambaBlock(nn.Module):
    def __init__(self, embed_dim, inner_dim, state_dim, delta_rank):
        """A single Mamba block, as described in Figure 3 in Section 3.4 of the Mamba paper [1]."""
        super().__init__()
        self.embed_dim = embed_dim
        self.inner_dim = inner_dim
        self.state_dim = state_dim
        self.delta_rank = delta_rank
        self.in_proj = nn.Linear(embed_dim, inner_dim * 2, bias=False)
        self.conv1d = nn.Conv1d(
            in_channels=inner_dim,
            out_channels=inner_dim,
            kernel_size=4,
            groups=inner_dim,
            padding=3,
        )
        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(inner_dim, delta_rank + state_dim * 2, bias=False)
        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(delta_rank, inner_dim, bias=True)
        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.delta_rank**-0.5
        nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(inner_dim) * (math.log(0.1) - math.log(0.001))
            + math.log(0.001)
        ).clamp(min=1e-4)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        self.dt_proj.bias._no_reinit = True
        # A is stored as log(A) for stability; rows are 1..state_dim repeated per channel
        A = torch.arange(1, state_dim + 1, dtype=torch.float32).unsqueeze(0).repeat(inner_dim, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(inner_dim))
        self.out_proj = nn.Linear(inner_dim, embed_dim, bias=False)

    def forward(self, x):
        (b, l, d) = x.shape
        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[self.inner_dim, self.inner_dim], dim=-1)
        x = x.transpose(-1, -2)
        x = self.conv1d(x)[:, :, :l]  # causal depthwise conv; trim back to length l
        x = x.transpose(-1, -2)
        x = F.silu(x)
        y = self.ssm(x)
        y = y * F.silu(res)  # multiplicative SiLU gate from the second in_proj branch
        return self.out_proj(y)

    def ssm(self, x):
        (d_in, n) = self.A_log.shape
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()
        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)
        (delta, B, C) = x_dbl.split(split_size=[self.delta_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
        return selective_scan(x, delta, A, B, C, D)  # similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
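
# Quick smoke test of a single block (a sketch with hypothetical, deliberately tiny
# sizes so it runs instantly on CPU): the output should keep the
# (batch, length, embed_dim) shape of the input.
_blk = MambaBlock(embed_dim=8, inner_dim=16, state_dim=4, delta_rank=1)
assert _blk(torch.randn(2, 10, 8)).shape == (2, 10, 8)
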
class Model(nn.Module):
    def __init__(self, embed_dim, inner_dim, state_dim, n_layers):
        super().__init__()
        self.mambas = nn.ModuleList([])
        self.embeds = nn.Linear(1, embed_dim)
        self.outs = nn.Linear(embed_dim, 10)
        for i in range(n_layers):
            self.mambas.append(
                nn.Sequential(
                    nn.RMSNorm(embed_dim),
                    MambaBlock(embed_dim, inner_dim, state_dim, 1)
                )
            )
        self.final_norm = nn.RMSNorm(embed_dim)

    def forward(self, x):
        x = x.flatten(1).unsqueeze(-1)  # (b, 1, h, w) -> (b, h*w, 1): one scalar pixel per step
        x = self.embeds(x)
        for mamba in self.mambas:
            x = x + mamba(x)  # pre-norm residual Mamba blocks
        x = x[:, -1, :]  # classify from the final sequence position
        x = self.final_norm(x)
        x = self.outs(x)
        return x
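
# Shape sketch for the pixel-sequence formulation (hypothetical batch and smaller
# dims than the training run below, CPU-only): each resized 10x10 MNIST image becomes
# a length-100 sequence of scalars, and the head returns 10 logits per image.
_m = Model(embed_dim=8, inner_dim=16, state_dim=8, n_layers=2)
assert _m(torch.randn(2, 1, 10, 10)).shape == (2, 10)
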
model = Model(embed_dim = 8, state_dim = 128, inner_dim = 32, n_layers = 4).cuda()
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((10, 10)),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1024, shuffle=False)
print("ready")
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = 3e-3)

# Train for 10 epochs, logging the loss and training-batch accuracy every 10 steps
for i in range(10):
    for j, (inp, trg) in enumerate(train_loader):
        optimizer.zero_grad()
        outs = model(inp.cuda())
        loss = criterion(outs, trg.cuda())
        if j % 10 == 0:
            print(loss.item())
            acc = (outs.cpu().argmax(dim = -1) == trg).sum() / len(trg)
            print(acc)
        loss.backward()
        optimizer.step()

# Evaluate accuracy on the held-out test set
total = 0
correct = 0
with torch.no_grad():
    for _, (inp, trg) in enumerate(test_loader):
        outs = model(inp.cuda())
        correct += (outs.cpu().argmax(dim = -1) == trg).sum()
        total += len(trg)
print(correct / total)