@keivalya
Created December 16, 2025 23:59
The fusion module, which fuses image, text, and state tokens.
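Written out, for one image token $x_{\text{img}}$, one text token $x_{\text{txt}}$ and one state token $x_{\text{state}}$, each in $\mathbb{R}^{d}$ with $d$ = `d_model`, the module below computes the following. This is a summary read off the code, where $[\,\cdot\,;\,\cdot\,;\,\cdot\,]$ denotes concatenation along the feature dimension:

$$
z = \mathrm{LayerNorm}\Big(W_2\,\mathrm{ReLU}\big(W_1\,[x_{\text{img}};\,x_{\text{txt}};\,x_{\text{state}}] + b_1\big) + b_2\Big),
\qquad W_1 \in \mathbb{R}^{d \times 3d},\; W_2 \in \mathbb{R}^{d \times d},\; z \in \mathbb{R}^{d}
$$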
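"""Fusion module to combine image, text, and state tokens."""
import torch
import torch.nn as nn


class FusionMLP(nn.Module):
    """Fuse one image, one text, and one state token into a single d_model token."""

    def __init__(self, d_model=128):
        super().__init__()
        # Two-layer MLP: project the concatenated tokens (3 * d_model) back to d_model.
        self.net = nn.Sequential(
            nn.Linear(3 * d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
        )
        # Normalize the fused token over its feature dimension.
        self.ln = nn.LayerNorm(d_model)

    def forward(self, img_token, txt_token, state_token):
        # Each input has shape (..., d_model); concatenate along the last dimension.
        x = torch.cat([img_token, txt_token, state_token], dim=-1)  # (..., 3 * d_model)
        x = self.net(x)  # (..., d_model)
        x = self.ln(x)  # (..., d_model)
        return x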
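To make the computation explicit, here is a minimal functional sketch of the same fusion built from `torch.nn.functional` calls. It is not part of the gist: the name `fuse_tokens`, the random weights (scaled by a hypothetical 0.05), and the batch size of 4 are assumptions for illustration. Given the same parameters, it should match `FusionMLP.forward`, because `nn.Linear` and `nn.LayerNorm` (default `eps=1e-5`) wrap `F.linear` and `F.layer_norm`.

```python
import torch
import torch.nn.functional as F


def fuse_tokens(img, txt, state, w1, b1, w2, b2, ln_weight, ln_bias, eps=1e-5):
    """Hypothetical functional form: LayerNorm(W2 @ ReLU(W1 @ [img; txt; state] + b1) + b2).

    img, txt, state: (..., d); w1: (d, 3d); w2: (d, d); b1, b2, ln_weight, ln_bias: (d,).
    """
    x = torch.cat([img, txt, state], dim=-1)  # (..., 3d)
    h = F.relu(F.linear(x, w1, b1))  # (..., d); F.linear expects weight as (out, in)
    y = F.linear(h, w2, b2)  # (..., d)
    return F.layer_norm(y, (y.shape[-1],), ln_weight, ln_bias, eps)


# Usage with hypothetical random weights (d = 128 as in the gist's default, batch of 4).
torch.manual_seed(0)
d, batch = 128, 4
img, txt, state = (torch.randn(batch, d) for _ in range(3))
w1, b1 = torch.randn(d, 3 * d) * 0.05, torch.zeros(d)
w2, b2 = torch.randn(d, d) * 0.05, torch.zeros(d)
ln_weight, ln_bias = torch.ones(d), torch.zeros(d)

fused = fuse_tokens(img, txt, state, w1, b1, w2, b2, ln_weight, ln_bias)
print(fused.shape)  # torch.Size([4, 128])
```

Because the fusion is concatenation followed by an MLP, the three inputs must share the same last dimension (`d_model`), while any leading batch dimensions pass through unchanged.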