
@keivalya
Created December 16, 2025 15:56
Text Encoder.
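```python
import torch
import torch.nn as nn


# Encoding text using a gated recurrent unit (GRU)
class TextEncoderTinyGRU(nn.Module):
    def __init__(self, vocab_size, d_word=64, d_model=128):
        super().__init__()
        # Map each token id to a d_word-dimensional vector
        self.embed = nn.Embedding(vocab_size, d_word)
        # Single-layer, unidirectional GRU over the embedded sequence
        self.gru = nn.GRU(d_word, d_model, batch_first=True)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, token_ids):
        # token_ids: (batch, seq_len) LongTensor
        x = self.embed(token_ids)  # (batch, seq_len, d_word)
        _, h_last = self.gru(x)    # h_last: (num_layers=1, batch, d_model)
        x = h_last[0]              # final hidden state: (batch, d_model)
        x = self.ln(x)             # layer-normalized text embedding
        return x
```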
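One limitation worth noting: when a batch is right-padded to a common length, the GRU above also steps over the pad tokens, so `h_last` for a short sequence reflects the padding rather than its last real token. Below is a minimal sketch of one common remedy, packing the batch with `torch.nn.utils.rnn.pack_padded_sequence` so each row stops at its true length. The class name `TextEncoderTinyGRUPacked`, the extra `lengths` argument, and the use of id `0` as the pad token in the example are hypothetical, not part of the original gist.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class TextEncoderTinyGRUPacked(nn.Module):
    """Hypothetical variant of TextEncoderTinyGRU that ignores right-padding."""

    def __init__(self, vocab_size, d_word=64, d_model=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_word)
        self.gru = nn.GRU(d_word, d_model, batch_first=True)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, token_ids, lengths):
        # token_ids: (batch, seq_len); lengths: true (unpadded) length of each row
        x = self.embed(token_ids)
        # Pack so the GRU stops at each row's last real token
        packed = pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        # With enforce_sorted=False, h_last comes back in the original batch order
        _, h_last = self.gru(packed)  # (1, batch, d_model)
        return self.ln(h_last[0])     # (batch, d_model)


torch.manual_seed(0)
enc = TextEncoderTinyGRUPacked(vocab_size=100)
enc.eval()

# Two sequences right-padded with the (assumed) pad id 0
batch = torch.tensor([[5, 7, 9, 0, 0],
                      [3, 4, 8, 2, 6]])
lengths = torch.tensor([3, 5])

out = enc(batch, lengths)
print(out.shape)  # torch.Size([2, 128])
```

Because `enforce_sorted=False` is passed, the batch does not have to be sorted by length beforehand; the trade-off is a small reordering cost inside PyTorch.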