Cargo.toml:
[package]
name = "microgpt"
version = "0.1.0"
edition = "2024"
[dependencies]
tch = "0.13" # ensure your LIBTORCH path is set correctly
anyhow = "1.0"src/main.rs:
//
// Rust impl of Andrej Karpathy's microgpt.py (https://gist.github.com/karpathy/8627fe009c40f57531cb18360106ce95 ) using the tch-rs crate
//
use anyhow::Result;
use tch::{Device, Kind, Tensor, nn, nn::OptimizerConfig};
// --- Configuration ---
struct GPTConfig {
vocab_size: i64,
block_size: i64,
n_layer: i64,
n_head: i64,
n_embd: i64,
dropout: f64,
}
impl GPTConfig {
fn default() -> Self {
Self {
vocab_size: 65, // Character level vocab size for Shakespeare
block_size: 128, // Context length
n_layer: 4,
n_head: 4,
n_embd: 128,
dropout: 0.1,
}
}
}
// --- Modules ---
struct CausalSelfAttention {
c_attn: nn::Linear, // Combined Q, K, V projection
c_proj: nn::Linear, // Output projection
n_head: i64,
n_embd: i64,
dropout: f64,
}
impl CausalSelfAttention {
fn new(vs: &nn::Path, config: &GPTConfig) -> Self {
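// Sanity check: the embedding width must split evenly across heads
assert!(
config.n_embd % config.n_head == 0,
"n_embd must be divisible by n_head"
);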
let c_attn = nn::linear(
vs / "c_attn",
config.n_embd,
3 * config.n_embd, // Output size is 3x for Q, K, V
Default::default(),
);
let c_proj = nn::linear(vs / "c_proj", config.n_embd, config.n_embd, Default::default());
Self {
c_attn,
c_proj,
n_head: config.n_head,
n_embd: config.n_embd,
dropout: config.dropout,
}
}
fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
let (b, t, c) = xs.size3().unwrap(); // Batch, Time, Channels
// Calculate Q, K, V with a single projection, then split on the last dim
let qkv = xs.apply(&self.c_attn); // (B, T, 3 * n_embd)
// tch's split returns a Vec<Tensor>: three chunks of size n_embd along dim 2
let splits = qkv.split(self.n_embd, 2);
let (q, k, v) = (&splits[0], &splits[1], &splits[2]);
// Reshape for heads: (B, T, n_embd) -> (B, T, n_head, head_dim) -> (B, n_head, T, head_dim)
let head_dim = self.n_embd / self.n_head;
let q = q.view((b, t, self.n_head, head_dim)).transpose(1, 2);
let k = k.view((b, t, self.n_head, head_dim)).transpose(1, 2);
let v = v.view((b, t, self.n_head, head_dim)).transpose(1, 2);
// Attention (Q @ K.T / sqrt(d))
let att = q.matmul(&k.transpose(-2, -1)) * (1.0 / (head_dim as f64).sqrt());
// Causal Mask
// Create a lower triangular mask
let mask = Tensor::ones([t, t], (Kind::Float, xs.device())).tril(0);
let mask = mask.view((1, 1, t, t)); // Broadcast to (B, n_head, T, T)
// Apply mask: set upper triangle to -inf
let att = att.masked_fill(&mask.eq(0.), f64::NEG_INFINITY);
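// Note: the mask is rebuilt on every forward pass for simplicity; a
// (block_size, block_size) buffer built once and sliced to t would avoid
// the per-call allocation.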
// Softmax and Dropout
let att = att.softmax(-1, Kind::Float);
let att = att.dropout(self.dropout, train);
// Aggregate values
let y = att.matmul(&v);
// Re-assemble heads: (B, n_head, T, head_dim) -> (B, T, n_embd)
let y = y.transpose(1, 2).contiguous().view((b, t, c));
// Output projection
y.apply(&self.c_proj)
}
}
struct MLP {
c_fc: nn::Linear,
c_proj: nn::Linear,
dropout: f64,
}
impl MLP {
fn new(vs: &nn::Path, config: &GPTConfig) -> Self {
let c_fc = nn::linear(vs / "c_fc", config.n_embd, 4 * config.n_embd, Default::default());
let c_proj = nn::linear(vs / "c_proj", 4 * config.n_embd, config.n_embd, Default::default());
Self {
c_fc,
c_proj,
dropout: config.dropout,
}
}
fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
xs.apply(&self.c_fc)
.gelu("none") // exact GELU, matching PyTorch's nn.GELU default
.apply(&self.c_proj)
.dropout(self.dropout, train)
}
}
struct Block {
ln_1: nn::LayerNorm,
attn: CausalSelfAttention,
ln_2: nn::LayerNorm,
mlp: MLP,
}
impl Block {
fn new(vs: &nn::Path, config: &GPTConfig) -> Self {
Self {
ln_1: nn::layer_norm(vs / "ln_1", vec![config.n_embd], Default::default()),
attn: CausalSelfAttention::new(&(vs / "attn"), config),
ln_2: nn::layer_norm(vs / "ln_2", vec![config.n_embd], Default::default()),
mlp: MLP::new(&(vs / "mlp"), config),
}
}
fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
// Pre-norm architecture
let x = xs + self.attn.forward_t(&xs.apply(&self.ln_1), train);
x + self.mlp.forward_t(&x.apply(&self.ln_2), train)
}
}
struct GPT {
wte: nn::Embedding, // Token embeddings
wpe: nn::Embedding, // Position embeddings
blocks: Vec<Block>,
ln_f: nn::LayerNorm,
// Note: no lm_head field; the output projection reuses wte's weights (weight tying)
config: GPTConfig,
}
impl GPT {
fn new(vs: &nn::Path, config: GPTConfig) -> Self {
let wte = nn::embedding(vs / "wte", config.vocab_size, config.n_embd, Default::default());
let wpe = nn::embedding(vs / "wpe", config.block_size, config.n_embd, Default::default());
// Create transformer blocks
let blocks: Vec<Block> = (0..config.n_layer)
.map(|i| Block::new(&(vs / "blocks" / i), &config))
.collect();
let ln_f = nn::layer_norm(vs / "ln_f", vec![config.n_embd], Default::default());
// Karpathy's Python code ties weights between wte and lm_head. Instead of
// allocating a separate linear layer that would go unused, we reuse wte.ws
// directly in the forward pass, which gives the same weight tying.
Self {
wte,
wpe,
blocks,
ln_f,
config,
}
}
fn forward_t(
&self,
idx: &Tensor,
targets: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (b, t) = idx.size2().unwrap();
assert!(
t <= self.config.block_size,
"Sequence length exceeds block size"
);
// 1. Embeddings
let pos = Tensor::arange(t, (Kind::Int64, idx.device()));
let tok_emb = idx.apply(&self.wte); // (B, T, n_embd)
let pos_emb = pos.apply(&self.wpe); // (T, n_embd) -> broadcasts to (B, T, n_embd)
let mut x = &tok_emb + &pos_emb; // (B, T, n_embd)
// 2. Transformer Blocks
for block in &self.blocks {
x = block.forward_t(&x, train);
}
// 3. Final LayerNorm
let x = x.apply(&self.ln_f);
// 4. Projection to Vocab
// Weight tying: reuse the token embedding matrix as the output projection
// (Karpathy's self.lm_head.weight = self.wte.weight).
// wte.ws has shape (vocab_size, n_embd), so logits = x @ wte.ws.T
let logits = x.matmul(&self.wte.ws.transpose(0, 1));
// 5. Loss
let loss = if let Some(targets) = targets {
// CrossEntropyLoss expects (N, C) for input and (N) for target
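// tch's cross_entropy_for_logits applies log-softmax internally, so raw logits go in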
let (b, t, c) = logits.size3().unwrap();
let logits_flat = logits.view((b * t, c));
let targets_flat = targets.view(b * t);
Some(logits_flat.cross_entropy_for_logits(&targets_flat))
} else {
None
};
(logits, loss)
}
}
// --- Main Training Loop ---
fn main() -> Result<()> {
// Setup device
let device = Device::cuda_if_available();
println!("Running on device: {:?}", device);
// Data setup (Tiny Shakespeare snippet hardcoded for overfit test)
// In a real app, you'd load a file. Here we just use a string to verify the model learns.
let text = "First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\n";
// Create simple character-level mappings
let chars: Vec<char> = text.chars().collect();
let unique_chars: Vec<char> = {
let mut set = std::collections::HashSet::new();
let mut unique = Vec::new();
for c in &chars {
if set.insert(*c) {
unique.push(*c);
}
}
unique
};
let vocab_size = unique_chars.len() as i64;
let stoi: std::collections::HashMap<char, i64> = unique_chars
.iter()
.enumerate()
.map(|(i, &c)| (c, i as i64))
.collect();
// Encode text to tensor
let data: Vec<i64> = chars.iter().map(|c| *stoi.get(c).unwrap()).collect();
let data_tensor = Tensor::from_slice(&data).to(device);
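// Sketch: a real run would also hold out a validation split, e.g.
// let n_train = (data.len() as f64 * 0.9) as i64;
// let train_data = data_tensor.narrow(0, 0, n_train);
// let val_data = data_tensor.narrow(0, n_train, data.len() as i64 - n_train);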
println!("Vocab size: {}", vocab_size);
println!("Data length: {} chars", chars.len());
// Initialize Model
let vs = nn::VarStore::new(device);
let mut config = GPTConfig::default();
config.vocab_size = vocab_size; // Update config with the actual vocab size
let block_size = config.block_size; // Capture before `config` moves into the model
let model = GPT::new(&vs.root(), config);
// Optimizer
let mut opt = nn::AdamW::default().build(&vs, 3e-4)?;
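// Sketch: if training diverges, the plain step below can be swapped for a
// clipped one, e.g. opt.backward_step_clip(&loss, 0.5) (clips gradient values)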
// Training Loop
println!("Starting training...");
for step in 0..2000 {
// Sample a random chunk of data
let ix = Tensor::randint(chars.len() as i64 - block_size, [1], (Kind::Int64, device));
let start = i64::from(ix.get(0));
let x = data_tensor.narrow(0, start, block_size).unsqueeze(0);
let y = data_tensor.narrow(0, start + 1, block_size).unsqueeze(0);
// Forward pass
let (_, loss) = model.forward_t(&x, Some(&y), true);
let loss_val = f64::from(loss.as_ref().unwrap());
// Backward pass
opt.backward_step(&loss.unwrap());
if step % 100 == 0 {
println!("Step {}: Loss = {:.4}", step, loss_val);
}
}
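// Sketch (optional): persist the trained weights via the VarStore, e.g.
// vs.save("microgpt.ot")?;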
// --- Generation Test ---
println!("\n--- Generating Sample ---");
// Eval mode: dropout is disabled by passing train = false to forward_t below
// Start generation from token index 0
let mut idx = Tensor::zeros([1, 1], (Kind::Int64, device));
// Generate 100 tokens
for _ in 0..100 {
let idx_cond = if idx.size()[1] > block_size {
idx.narrow(1, idx.size()[1] - block_size, block_size)
} else {
idx.shallow_clone()
};
let (logits, _) = model.forward_t(&idx_cond, None, false);
// Take the logits at the last time step: (B, T, vocab) -> (B, vocab)
let logits = logits.select(1, -1);
// Softmax over the vocabulary
let probs = logits.softmax(-1, Kind::Float);
// Sample 1 token: shape (B, 1)
let next_ix = probs.multinomial(1, true);
// Append along the time dimension
idx = Tensor::cat(&[&idx, &next_ix], 1);
}
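// Sampling note: dividing logits by a temperature before the softmax above
// (e.g. logits / 0.8) sharpens or flattens the distribution; top-k filtering
// is another common refinement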
// Decode and print
let idx_vec: Vec<i64> = idx.get(0).try_into().unwrap();
let generated: String = idx_vec
.iter()
.map(|&i| {
if (i as usize) < unique_chars.len() {
unique_chars[i as usize]
} else {
'?'
}
})
.collect();
println!("Generated:\n{}", generated);
Ok(())
}
Source: Andrej Karpathy's microgpt gist.