Implementation of Mamba in PyTorch

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

def selective_scan(u, dt, A, B, C, D):
    # Discretize the continuous parameters: dA = Δ * A, dB_u = Δ * B * u.
    dA = torch.einsum('bld,dn->bldn', dt, A)
    dB_u = torch.einsum('bld,bld,bln->bldn', dt, u, B)
    # Clamp very negative dA so exp(cumsum) below does not underflow to zero,
    # which would blow up the division.
    dA = dA.clamp(min=-20)
    # Shift dA right by one step along the sequence dimension, then take exp(cumsum),
    # so that dA_cumsum[:, t] = exp(dA[:, 1] + ... + dA[:, t]).
    padding = (0, 0, 0, 0, 1, 0)
    dA_cumsum = F.pad(dA[:, 1:], padding).cumsum(1).exp()
    # Cumulative-sum form of the recurrence h_t = exp(dA_t) * h_{t-1} + dB_u_t.
    x = dB_u / (dA_cumsum + 1e-12)
    x = x.cumsum(1) * dA_cumsum
    # Read out y_t = C_t · h_t and add the skip connection D * u.
    y = torch.einsum('bldn,bln->bld', x, C)
    return y + u * D
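
# --- Not part of the original gist: a sanity-check sketch. ---
# selective_scan above is a parallel (cumsum-based) evaluation of the discretized SSM
# recurrence h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * u_t, y_t = C_t · h_t + D * u_t.
# The reference loop below recomputes it step by step; the two should agree up to the
# numerical-stability tweaks (the clamp and the 1e-12 term), hence the loose tolerances.
# The function and variable names here are mine, not from the gist.
def selective_scan_ref(u, dt, A, B, C, D):
    b, l, d = u.shape
    n = A.shape[1]
    h = torch.zeros(b, d, n)
    ys = []
    for t in range(l):
        dA = torch.exp(torch.einsum('bd,dn->bdn', dt[:, t], A))            # state transition
        dB_u = torch.einsum('bd,bd,bn->bdn', dt[:, t], u[:, t], B[:, t])   # input injection
        h = dA * h + dB_u
        ys.append(torch.einsum('bdn,bn->bd', h, C[:, t]))
    return torch.stack(ys, dim=1) + u * D

_u, _dt = torch.rand(2, 16, 8), torch.rand(2, 16, 8) * 0.1
_A, _B, _C, _D = -torch.rand(8, 4) - 0.5, torch.rand(2, 16, 4), torch.rand(2, 16, 4), torch.ones(8)
assert torch.allclose(selective_scan(_u, _dt, _A, _B, _C, _D),
                      selective_scan_ref(_u, _dt, _A, _B, _C, _D), atol=1e-4, rtol=1e-3)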

class MambaBlock(nn.Module):
    def __init__(self, embed_dim, inner_dim, state_dim, delta_rank):
        """A single Mamba block, as described in Figure 3 in Section 3.4 of the Mamba paper [1]."""
        super().__init__()
        self.embed_dim = embed_dim
        self.inner_dim = inner_dim
        self.state_dim = state_dim
        self.delta_rank = delta_rank

        self.in_proj = nn.Linear(embed_dim, inner_dim * 2, bias=False)
        # Depthwise convolution over the sequence dimension (made causal in forward()).
        self.conv1d = nn.Conv1d(
            in_channels=inner_dim,
            out_channels=inner_dim,
            kernel_size=4,
            groups=inner_dim,
            padding=3,
        )
        # x_proj takes in `x` and outputs the input-specific Δ, B, C.
        self.x_proj = nn.Linear(inner_dim, delta_rank + state_dim * 2, bias=False)
        # dt_proj projects Δ from delta_rank up to inner_dim.
        self.dt_proj = nn.Linear(delta_rank, inner_dim, bias=True)

        # Initialize the dt projection to preserve variance at initialization.
        dt_init_std = self.delta_rank**-0.5
        nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)

        # Initialize the dt bias so that F.softplus(dt_bias) lies between 0.001 and 0.1.
        dt = torch.exp(
            torch.rand(inner_dim) * (math.log(0.1) - math.log(0.001))
            + math.log(0.001)
        ).clamp(min=1e-4)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        self.dt_proj.bias._no_reinit = True

        # A is stored as A_log = log([1..state_dim]) per channel and negated in ssm().
        A = torch.arange(1, state_dim + 1, dtype=torch.float32).unsqueeze(0).repeat(inner_dim, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(inner_dim))
        self.out_proj = nn.Linear(inner_dim, embed_dim, bias=False)

    def forward(self, x):
        (b, l, d) = x.shape
        x_and_res = self.in_proj(x)  # shape (b, l, 2 * inner_dim)
        (x, res) = x_and_res.split(split_size=[self.inner_dim, self.inner_dim], dim=-1)
        # conv1d pads both ends by 3; keeping only the first l outputs makes it causal.
        x = x.transpose(-1, -2)
        x = self.conv1d(x)[:, :, :l]
        x = x.transpose(-1, -2)
        x = F.silu(x)
        y = self.ssm(x)
        # Gate the SSM output with the residual branch.
        y = y * F.silu(res)
        return self.out_proj(y)

    def ssm(self, x):
        (d_in, n) = self.A_log.shape
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()
        x_dbl = self.x_proj(x)  # (b, l, delta_rank + 2*n)
        (delta, B, C) = x_dbl.split(split_size=[self.delta_rank, n, n], dim=-1)  # delta: (b, l, delta_rank); B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
        return selective_scan(x, delta, A, B, C, D)  # similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
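
# --- Not part of the original gist: a quick shape check of MambaBlock. ---
# The block maps (batch, length, embed_dim) to the same shape; the sizes below are arbitrary.
_block = MambaBlock(embed_dim=8, inner_dim=16, state_dim=4, delta_rank=1)
_xb = torch.randn(2, 10, 8)
assert _block(_xb).shape == (2, 10, 8)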

class Model(nn.Module):
    def __init__(self, embed_dim, inner_dim, state_dim, n_layers):
        super().__init__()
        self.mambas = nn.ModuleList([])
        # Each pixel value is embedded individually, so an image becomes a sequence of pixels.
        self.embeds = nn.Linear(1, embed_dim)
        self.outs = nn.Linear(embed_dim, 10)
        for i in range(n_layers):
            self.mambas.append(
                nn.Sequential(
                    nn.RMSNorm(embed_dim),
                    MambaBlock(embed_dim, inner_dim, state_dim, 1),
                )
            )
        self.final_norm = nn.RMSNorm(embed_dim)

    def forward(self, x):
        # (b, 1, 10, 10) image -> (b, 100, 1) pixel sequence -> (b, 100, embed_dim).
        x = x.flatten(1).unsqueeze(-1)
        x = self.embeds(x)
        # Pre-norm residual Mamba blocks.
        for mamba in self.mambas:
            x = x + mamba(x)
        # Classify from the representation of the last pixel in the sequence.
        x = x[:, -1, :]
        x = self.final_norm(x)
        x = self.outs(x)
        return x
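
# --- Not part of the original gist: a quick CPU-side check of Model. ---
# A (batch, 1, 10, 10) image batch should come out as (batch, 10) class logits;
# the sizes below are arbitrary and smaller than the training configuration.
_m = Model(embed_dim=8, inner_dim=32, state_dim=128, n_layers=2)
assert _m(torch.randn(4, 1, 10, 10)).shape == (4, 10)
print(sum(p.numel() for p in _m.parameters()), "parameters in the test model")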

model = Model(embed_dim=8, state_dim=128, inner_dim=32, n_layers=4).cuda()

# Downsample MNIST to 10x10 so each image becomes a sequence of only 100 pixels.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((10, 10)),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1024, shuffle=False)
print("ready")

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-3)

for epoch in range(10):
    for j, (inp, trg) in enumerate(train_loader):
        optimizer.zero_grad()
        outs = model(inp.cuda())
        loss = criterion(outs, trg.cuda())
        if j % 10 == 0:
            # Print the running loss and the accuracy on the current training batch.
            print(loss.item())
            acc = (outs.cpu().argmax(dim=-1) == trg).sum() / len(trg)
            print(acc)
        loss.backward()
        optimizer.step()

# Evaluate accuracy on the test set after training.
total = 0
correct = 0
with torch.no_grad():
    for inp, trg in test_loader:
        outs = model(inp.cuda())
        correct += (outs.cpu().argmax(dim=-1) == trg).sum()
        total += len(trg)
print(correct / total)
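
# --- Not part of the original gist: a minimal usage sketch. ---
# Save the trained weights and classify one held-out test image; the file name is arbitrary.
torch.save(model.state_dict(), "mamba_mnist.pt")
model.load_state_dict(torch.load("mamba_mnist.pt"))
with torch.no_grad():
    img, label = test_dataset[0]          # (1, 10, 10) tensor and its integer label
    pred = model(img.unsqueeze(0).cuda()).argmax(dim=-1).item()
print(f"predicted {pred}, ground truth {label}")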