# ============================================================
# External Imports
# ============================================================
from jax import numpy as jnp
from transformers import AutoTokenizer
import abc
import dataclasses
import datetime
import einops
import jax
import jaxtyping as jt
import math
import numpy as np
import optax
import orbax.checkpoint as ocp
import pathlib
import pyarrow.parquet as pq
import typing
# ============================================================
# toylib_projects.tinystories.data - /Users/anuj/Desktop/code/toylib/toylib_projects/tinystories/data.py
# ============================================================
@dataclasses.dataclass
class BatchedTokenizedDataset(abc.ABC):
dataset_path: str = "karpathy/fineweb-edu-100b-shuffle"
split: str = "train"
tokenizer_name: str = "gpt2"
seq_len: int = 2048
tokenizer_batch_size: int = 8
batch_size: int = 128
@abc.abstractmethod
def _get_dataset_iterator(self) -> typing.Iterator:
raise NotImplementedError
def __post_init__(self):
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
self.bos_token = self.tokenizer.bos_token_id
self.token_buffer = []
self.dataset_iter = self._get_dataset_iterator()
def __iter__(self):
return self
def __next__(self) -> dict[str, jnp.ndarray]:
token_needed = self.batch_size * self.seq_len + 1
while len(self.token_buffer) < token_needed:
input_batch = next(self.dataset_iter)
texts = input_batch["text"]
tokenized = self.tokenizer(
texts,
return_tensors=None,
padding=False,
truncation=False,
max_length=None,
)["input_ids"]
for tokens in tokenized:
self.token_buffer.append(self.bos_token)
self.token_buffer.extend(tokens)
tokens = self.token_buffer[:token_needed]
self.token_buffer = self.token_buffer[token_needed:]
inputs = jnp.array(tokens[:-1], dtype=jnp.uint16).reshape(
self.batch_size, self.seq_len
)
targets = jnp.array(tokens[1:], dtype=jnp.uint16).reshape(
self.batch_size, self.seq_len
)
return {"inputs": inputs, "targets": targets}
class BatchedTokenizedDatasetParquet(BatchedTokenizedDataset):
"""Path is constructed as dataset_path/split/*.parquet"""
def list_files(self):
base_path = pathlib.Path(self.dataset_path) / self.split
return list(base_path.glob("*.parquet"))
def _get_dataset_iterator(self):
for file_path in self.list_files():
pf = pq.ParquetFile(file_path)
for row_group in range(pf.num_row_groups):
rg = pf.read_row_group(row_group)
yield {"text": rg.column("text").to_pylist()}
# ============================================================
# toylib.nn.module - /Users/anuj/Desktop/code/toylib/toylib/nn/module.py
# ============================================================
def _is_array(x: typing.Any) -> bool:
return isinstance(x, (jax.Array, np.ndarray, np.generic)) or hasattr(
x, "__jax_array__"
)
def _is_random_key(x: str) -> bool:
return x == "key"
def _is_supported_container(x: typing.Any) -> bool:
return isinstance(x, (list, tuple))
class Module(abc.ABC):
"""
Defines a base class to use for the neural network modules in toylib.
Assumes that all jax arrays are leaf nodes that are trainable and
everything else is a static param. Defines the flatten and unflatten methods
to make the modules compatible with jax `jit` and `grad` functions.
Refer to https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees
Inspired by equinox and the "Custom PyTrees and Initialization" section of the JAX docs.
"""
def __init_subclass__(cls, **kwargs: typing.Any) -> None:
super().__init_subclass__(**kwargs)
cls = dataclasses.dataclass(cls, kw_only=True)
cls = jax.tree_util.register_pytree_with_keys_class(cls)
@abc.abstractmethod
def init(self) -> None:
"""Initialize all the trainable parameters in the module."""
pass
@abc.abstractmethod
def __call__(self, *args, **kwargs) -> typing.Any:
"""Run a forward pass of the module."""
pass
def _get_trainable_param_keys(self) -> list[str]:
"""Get the list of attribute names that are trainable parameters."""
param_keys = []
for k, v in self.__dict__.items():
if (
(_is_array(v) and not _is_random_key(k))
or isinstance(v, Module)
or (
_is_supported_container(v)
and all(isinstance(elem, Module) for elem in v)
)
):
param_keys.append(k)
return param_keys
def __post_init__(self) -> None:
self.init()
self._trainable_param_keys = self._get_trainable_param_keys()
def tree_flatten_with_keys(self) -> tuple:
params_with_keys = []
aux_data = dict()
for k, v in self.__dict__.items():
if k not in self._trainable_param_keys:
aux_data[k] = v
for k in self._trainable_param_keys:
v = self.__dict__[k]
params_with_keys.append((jax.tree_util.GetAttrKey(k), v))
return (params_with_keys, aux_data)
@classmethod
def tree_unflatten(cls, static, dynamic) -> "Module":
obj = object.__new__(cls)
param_keys = static["_trainable_param_keys"]
for k, v in zip(param_keys, dynamic):
obj.__setattr__(k, v)
for k, v in static.items():
obj.__setattr__(k, v)
return obj
# ============================================================
# toylib.nn.layers - /Users/anuj/Desktop/code/toylib/toylib/nn/layers.py
# ============================================================
class Linear(Module):
"""Defines a simple feedforward layer: which is a linear transformation."""
in_features: int
out_features: int
use_bias: bool = False
key: jt.PRNGKeyArray
weights: jt.Float[jt.Array, "in_features out_features"] | None = None
bias: jt.Float[jt.Array, " out_features"] | None = None
def init(self) -> None:
w_key = self.key
in_features = self.in_features
out_features = self.out_features
std = min(1.0, math.sqrt(out_features / in_features)) / math.sqrt(in_features)
self.weights = jax.random.normal(w_key, (in_features, out_features)) * std
self.bias = jax.numpy.zeros((out_features,)) if self.use_bias else None
def __call__(
self, x: jt.Float[jt.Array, "... in_features"]
) -> jt.Float[jt.Array, "... out_features"]:
x = jax.numpy.dot(x, self.weights)
if self.use_bias:
x = x + self.bias
return x
class Embedding(Module):
"""Defines an embedding layer that stores an embedding matrix for discrete tokens."""
vocab_size: int
embedding_dim: int
key: jt.PRNGKeyArray
weights: jt.Float[jt.Array, "vocab_size embedding_dim"] | None = None
def init(self) -> None:
self.weights = jax.random.normal(
self.key, (self.vocab_size, self.embedding_dim)
)
def __call__(
self, tokens: jt.Int[jt.Array, "... seq_len"]
) -> jt.Float[jt.Array, "... seq_len embedding_dim"]:
return jax.numpy.take(self.weights, tokens, axis=0)
def rms_norm(x: jt.Float[jt.Array, "... dim"]) -> jt.Float[jt.Array, "... dim"]:
"""Applies RMS Normalization over the last dimension of the input tensor.
Args:
x: Input tensor
Returns:
The RMS normalized tensor of the same shape as input x.
"""
rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + 1e-09)
return x / rms
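# Illustrative sketch (not part of the original gist): because every Module subclass is
# registered as a pytree, JAX tree utilities treat its arrays as leaves and keep the
# remaining attributes as static aux data, which is what makes `jit`/`grad` over whole
# modules possible.
def _example_module_as_pytree():
    layer = Linear(in_features=4, out_features=2, key=jax.random.PRNGKey(0))
    leaves = jax.tree_util.tree_leaves(layer)
    # With use_bias=False (the default), the only leaf is the (4, 2) weights matrix;
    # in_features, out_features, and the PRNG key stay in the static aux data.
    return leaves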
# ============================================================
# toylib.nn.attention - /Users/anuj/Desktop/code/toylib/toylib/nn/attention.py
# ============================================================
class RotaryPositionalEmbedding(Module):
"""Implements Rotary Positional Embeddings (RoPE) as described in https://arxiv.org/abs/2104.09864."""
seq_len: int = 1024
qkv_dim: int = 128
base: int = 100000
def init(self) -> None:
positions = jnp.arange(0, self.seq_len)
freqs = self.base ** (jnp.arange(0, self.qkv_dim, 2) / self.qkv_dim)
self.gamma = einops.einsum(positions, 1.0 / freqs, "t, d -> t d")
self.cos = jnp.cos(self.gamma)
self.sin = jnp.sin(self.gamma)
def __call__(
self, x: jt.Float[jt.Array, "... seq_len qkv_dim"]
) -> jt.Float[jt.Array, "... seq_len qkv_dim"]:
d = x.shape[-1]
x1, x2 = (x[..., : d // 2], x[..., d // 2 :])
es_shape = "... t d, t d -> ... t d"
y1 = einops.einsum(x1, self.cos, es_shape) + einops.einsum(
x2, self.sin, es_shape
)
y2 = -einops.einsum(x1, self.sin, es_shape) + einops.einsum(
x2, self.cos, es_shape
)
return jnp.concatenate([y1, y2], axis=-1)
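# Illustrative sketch (not part of the original gist): RoPE rotates each
# (x[..., i], x[..., i + d/2]) pair by a position-dependent angle, so it preserves the
# norm of every position vector. A quick numerical check with small, arbitrary shapes:
def _example_rope_preserves_norm():
    rope = RotaryPositionalEmbedding(seq_len=16, qkv_dim=8)
    x = jax.random.normal(jax.random.PRNGKey(0), (16, 8))
    y = rope(x)
    return jnp.allclose(
        jnp.linalg.norm(x, axis=-1), jnp.linalg.norm(y, axis=-1), atol=1e-5
    )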
def scaled_dot_product_attention(
q: jt.Float[jt.Array, "... seq_len qkv_dim"],
k: jt.Float[jt.Array, "... seq_len qkv_dim"],
v: jt.Float[jt.Array, "... seq_len qkv_dim"],
mask: typing.Optional[jt.Float[jt.Array, "... seq_len seq_len"]],
) -> tuple[
jt.Float[jt.Array, "... seq_len qkv_dim"], jt.Float[jt.Array, "... seq_len seq_len"]
]:
"""Compute scaled dot product attention.
Given query (`q`), key (`k`), and value (`v`) tensors, this function first computes the
attention weights as the softmax of the dot product of `q` and `k`, scaled by the square
root of the dimension of the keys. If a mask is provided, it is applied to the attention
logits before the softmax is computed.
Finally, the attention weights are used to compute the weighted average of the given values.
NOTE: the batch dimension is not explicitly handled in this function.
Args:
q: query tensor
k: keys tensor
v: values tensor
mask: optional boolean mask to apply to the attention logits
Returns:
tuple of final values and attention weights
"""
d_k = q.shape[-1]
assert q.shape[-1] == k.shape[-1], "q and k must have the same feature dimension"
attention_logits = jnp.matmul(q, k.swapaxes(-1, -2)) / jnp.sqrt(d_k)
if mask is not None:
attention_logits = jnp.where(mask, attention_logits, -1000000000.0)
attention_weights = jax.nn.softmax(attention_logits, axis=-1)
values = jnp.matmul(attention_weights, v)
return (values, attention_weights)
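# Illustrative sketch (not part of the original gist): scaled dot product attention with
# a causal (lower-triangular) mask, so position t only attends to positions <= t. Shapes
# are arbitrary; note the batch dimension is omitted, as the docstring above points out.
def _example_causal_attention():
    key = jax.random.PRNGKey(0)
    q = k = v = jax.random.normal(key, (8, 16))  # (seq_len, qkv_dim)
    mask = jnp.tril(jnp.ones((8, 8)))  # nonzero where attention is allowed
    values, weights = scaled_dot_product_attention(q=q, k=k, v=v, mask=mask)
    # `weights` is (8, 8) with (approximately) zero weight above the diagonal.
    return values, weights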
class MultiHeadAttention(Module):
"""
The MultiHeadAttention module defines `num_heads` attention heads. For the given input `Q`, `K`, `V`
tensors, `num_heads` linear projections of dim `qkv_dim / num_heads` are produced.
Attention weights are then computed using scaled dot product attention. The
weighted averages of the values from the various heads are concatenated to produce a
single output value vector, and a final linear projection is applied on top.
"""
qkv_dim: int
num_heads: int
use_qk_norm: bool = True
key: jt.PRNGKeyArray
def init(self) -> None:
qkv_dim = self.qkv_dim
keys = jax.random.split(self.key, 4)
self.q_projection = Linear(
in_features=qkv_dim, out_features=qkv_dim, use_bias=False, key=keys[0]
)
self.k_projection = Linear(
in_features=qkv_dim, out_features=qkv_dim, use_bias=False, key=keys[1]
)
self.v_projection = Linear(
in_features=qkv_dim, out_features=qkv_dim, use_bias=False, key=keys[2]
)
self.linear = Linear(in_features=qkv_dim, out_features=qkv_dim, key=keys[3])
def __call__(
self,
Q: jt.Float[jt.Array, "... seq_len qkv_dim"],
K: jt.Float[jt.Array, "... seq_len qkv_dim"],
V: jt.Float[jt.Array, "... seq_len qkv_dim"],
mask: typing.Optional[jt.Float[jt.Array, "... seq_len seq_len"]] = None,
*,
rope: typing.Optional[RotaryPositionalEmbedding] = None,
return_attention_weights: bool = False,
) -> typing.Union[
tuple[
jt.Float[jt.Array, "... seq_len qkv_dim"],
jt.Float[jt.Array, "... seq_len seq_len"],
],
jt.Float[jt.Array, "... seq_len qkv_dim"],
]:
Q = self.q_projection(Q)
K = self.k_projection(K)
V = self.v_projection(V)
Q = einops.rearrange(
Q,
"... seq_len (num_heads head_dim) -> ... num_heads seq_len head_dim",
num_heads=self.num_heads,
)
K = einops.rearrange(
K,
"... seq_len (num_heads head_dim) -> ... num_heads seq_len head_dim",
num_heads=self.num_heads,
)
V = einops.rearrange(
V,
"... seq_len (num_heads head_dim) -> ... num_heads seq_len head_dim",
num_heads=self.num_heads,
)
if mask is not None:
mask = einops.rearrange(
mask, "... seq_len1 seq_len2 -> ... 1 seq_len1 seq_len2"
)
if rope is not None:
Q = rope(Q)
K = rope(K)
if self.use_qk_norm:
Q = rms_norm(Q)
K = rms_norm(K)
values, attention_weights = scaled_dot_product_attention(
q=Q, k=K, v=V, mask=mask
)
values = einops.rearrange(
values, "... num_heads seq_len d -> ... seq_len (num_heads d)"
)
values = self.linear(values)
if return_attention_weights:
return (values, attention_weights)
return values
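# Illustrative sketch (not part of the original gist): causal self-attention with 4 heads
# over a batch of embeddings, applying RoPE per head (head_dim = qkv_dim // num_heads).
# All shapes here are arbitrary example values.
def _example_multi_head_self_attention():
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (2, 32, 64))  # (batch, seq_len, qkv_dim)
    mha = MultiHeadAttention(qkv_dim=64, num_heads=4, key=key)
    rope = RotaryPositionalEmbedding(qkv_dim=64 // 4, seq_len=32)
    mask = jnp.tril(jnp.ones((32, 32)))
    return mha(Q=x, K=x, V=x, mask=mask, rope=rope)  # (2, 32, 64)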
# ============================================================
# toylib_projects.tinystories.decoder_only_model - /Users/anuj/Desktop/code/toylib/toylib_projects/tinystories/decoder_only_model.py
# ============================================================
@dataclasses.dataclass
class ModelConfig:
"""Configuration for the DecoderOnlyTransformer model."""
num_layers: int = 2
num_heads: int = 8
qkv_dim: int = 256
vocab_size: int = 50257
seq_len: int = 512
logit_softcap: float = 15.0
class MLP(Module):
"""A simple feedforward MLP with one hidden layer."""
qkv_dim: int
key: jt.PRNGKeyArray
def init(self) -> None:
qkv_dim = self.qkv_dim
keys = jax.random.split(self.key, 2)
self.fc1 = Linear(in_features=qkv_dim, out_features=4 * qkv_dim, key=keys[0])
self.fc2 = Linear(in_features=4 * qkv_dim, out_features=qkv_dim, key=keys[1])
self.fc2.weights = jnp.zeros_like(self.fc2.weights)
def __call__(
self, x: jt.Float[jt.Array, "... qkv_dim"]
) -> jt.Float[jt.Array, "... qkv_dim"]:
x = self.fc1(x)
x = jax.nn.gelu(x)
x = self.fc2(x)
return x
class CausalSelfAttention(Module):
"""Causal Self-Attention layer with Rotary Positional Embeddings (RoPE)."""
qkv_dim: int
num_heads: int
seq_len: int
key: jt.PRNGKeyArray
def init(self) -> None:
self.mha = MultiHeadAttention(
qkv_dim=self.qkv_dim,
num_heads=self.num_heads,
key=self.key,
use_qk_norm=True,
)
self.mha.linear.weights = jnp.zeros_like(self.mha.linear.weights)
self.rope = RotaryPositionalEmbedding(
qkv_dim=self.qkv_dim // self.num_heads, seq_len=self.seq_len
)
def _make_causal_mask(self, seq_len: int) -> jt.Float[jt.Array, "seq_len seq_len"]:
return jnp.tril(jnp.ones((seq_len, seq_len)))
def __call__(
self, x: jt.Float[jt.Array, "... seq_len qkv_dim"]
) -> jt.Float[jt.Array, "... seq_len qkv_dim"]:
x = self.mha(
Q=x, K=x, V=x, mask=self._make_causal_mask(x.shape[-2]), rope=self.rope
)
return x
class DecoderBlock(Module):
qkv_dim: int
num_heads: int
seq_len: int
key: jt.PRNGKeyArray
def init(self) -> None:
keys = jax.random.split(self.key, 2)
self.causal_attn = CausalSelfAttention(
qkv_dim=self.qkv_dim,
num_heads=self.num_heads,
seq_len=self.seq_len,
key=keys[0],
)
self.mlp = MLP(qkv_dim=self.qkv_dim, key=keys[1])
def __call__(
self, x: jt.Float[jt.Array, "... seq_len qkv_dim"]
) -> jt.Float[jt.Array, "... seq_len qkv_dim"]:
x = x + self.causal_attn(rms_norm(x))
x = x + self.mlp(rms_norm(x))
return x
class DecoderOnlyTransformer(Module):
"""A simple decoder-only transformer model.
Takes in a sequence of tokens, embeds them using a learned embedding layer,
applies causal self-attention with RoPE positional embeddings, and outputs the
logits for next-token prediction.
"""
key: jt.PRNGKeyArray
config: ModelConfig
def init(self) -> None:
config = self.config
keys = jax.random.split(self.key, config.num_layers + 2)
self.embedding_layer = Embedding(
vocab_size=config.vocab_size, embedding_dim=config.qkv_dim, key=keys[0]
)
self.blocks = []
for ix in range(config.num_layers):
self.blocks.append(
DecoderBlock(
qkv_dim=config.qkv_dim,
num_heads=config.num_heads,
seq_len=config.seq_len,
key=keys[ix + 1],
)
)
self.output_layer = Linear(
in_features=config.qkv_dim, out_features=config.vocab_size, key=keys[-1]
)
def __call__(
self, x: jt.Float[jt.Array, "batch_size seq_len"]
) -> jt.Float[jt.Array, "batch_size seq_len vocab_size"]:
"""Forward pass for the decoder-only transformer model.
Args:
x: Input token ids of shape [batch_size, seq_len]. Note that the sequence
length should match the configuration `seq_len` as this is used for
positional embeddings.
Returns:
Unnormalized logits over the vocabulary of shape [batch_size, seq_len, vocab_size]
"""
x = self.embedding_layer(x)
x = rms_norm(x)
for block in self.blocks:
x = block(x)
x = rms_norm(x)
x = self.output_layer(x)
x = self.config.logit_softcap * jnp.tanh(x / self.config.logit_softcap)
return x
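# Illustrative sketch (not part of the original gist): a forward pass of a tiny model on
# random token ids. The config values are arbitrary; in training, tokens come from
# BatchedTokenizedDataset instead.
def _example_model_forward():
    config = ModelConfig(
        num_layers=2, num_heads=4, qkv_dim=128, vocab_size=1000, seq_len=64
    )
    model = DecoderOnlyTransformer(config=config, key=jax.random.PRNGKey(0))
    tokens = jax.random.randint(jax.random.PRNGKey(1), (2, 64), 0, config.vocab_size)
    logits = model(tokens)  # (2, 64, 1000), softcapped to (-15, 15)
    return logits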
def loss_fn(
logits: jt.Float[jt.Array, "batch_size seq_len vocab_size"],
targets: jt.Int[jt.Array, "batch_size seq_len"],
mask: jt.Int[jt.Array, "batch_size seq_len"],
) -> tuple[jt.Float[jt.Array, ""], jt.Float[jt.Array, "batch_size seq_len"]]:
"""Computes the masked cross-entropy loss between logits and targets.
Args:
logits: Logits of shape [batch_size, seq_len, vocab_size].
targets: Target token ids of shape [batch_size, seq_len].
mask: Loss mask of shape [batch_size, seq_len]; positions with 0 are ignored.
Returns:
Tuple of the scalar mean loss over masked positions and the per-token loss of shape
[batch_size, seq_len].
"""
targets_one_hot = jax.nn.one_hot(targets, num_classes=logits.shape[-1])
log_probs = jax.nn.log_softmax(logits, axis=-1)
per_token_loss = -jnp.sum(targets_one_hot * log_probs, axis=-1)
masked_loss = mask * per_token_loss
total_loss = jnp.sum(masked_loss) / jnp.sum(mask)
return (total_loss, per_token_loss)
def train_step(
model: DecoderOnlyTransformer,
tokens: jt.Int[jt.Array, "batch_size seq_len"],
mask: jt.Int[jt.Array, "batch_size seq_len"],
targets: jt.Int[jt.Array, "batch_size seq_len"],
) -> tuple[jt.Float[jt.Array, ""], dict]:
"""A single training step (forward pass and loss) for the model.
Args:
model: The DecoderOnlyTransformer model.
tokens: Input token ids of shape [batch_size, seq_len].
mask: Loss mask of shape [batch_size, seq_len].
targets: Target token ids of shape [batch_size, seq_len].
Returns:
Tuple of the scalar loss for the batch and a dict with the logits and per-token loss.
"""
logits = model(tokens)
total_loss, per_token_loss = loss_fn(logits, targets, mask)
return (total_loss, {"logits": logits, "per_token_loss": per_token_loss})
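# Illustrative sketch (not part of the original gist): because the model is a pytree,
# jax.value_and_grad differentiates the loss directly with respect to the model object;
# the returned grads mirror the model's structure (one array per weight). The tiny
# config and constant tokens below are arbitrary placeholders.
def _example_grad_step():
    config = ModelConfig(
        num_layers=1, num_heads=2, qkv_dim=64, vocab_size=100, seq_len=16
    )
    model = DecoderOnlyTransformer(config=config, key=jax.random.PRNGKey(0))
    tokens = jnp.zeros((1, 16), dtype=jnp.int32)
    targets = jnp.ones((1, 16), dtype=jnp.int32)
    mask = jnp.ones((1, 16))
    (loss, aux), grads = jax.value_and_grad(train_step, has_aux=True)(
        model, tokens, mask, targets
    )
    return loss, grads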
# ============================================================
# toylib_projects.tinystories.logger - /Users/anuj/Desktop/code/toylib/toylib_projects/tinystories/logger.py
# ============================================================
class Logger(abc.ABC):
"""Interface for logging training metrics."""
def __init__(self, config_dict: dict, *args, **kwargs) -> None:
self.config_dict = config_dict
@abc.abstractmethod
def log(self, step: int, metrics: dict) -> None:
"""Log the given metrics at the specified step."""
pass
@abc.abstractmethod
def close(self) -> None:
"""Close any resources held by the logger."""
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
class FileLogger(Logger):
"""Logger implementation that logs metrics to a local file."""
def __init__(self, config_dict: dict, output_path: str, *args, **kwargs) -> None:
self.config_dict = config_dict
self.file_ptr = open(output_path, "w")
self.file_ptr.write("\n")
def log(self, step: int, metrics: dict) -> None:
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.file_ptr.write(f"[{timestamp}] Step {step}: {metrics}\n")
def close(self) -> None:
self.file_ptr.close()
class StdoutLogger(Logger):
"""Logger implementation that logs metrics to standard output."""
def __init__(self, config_dict: dict, *args, **kwargs) -> None:
self.config_dict = config_dict
def log(self, step: int, metrics: dict) -> None:
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(f"[{timestamp}] Step {step}: {metrics}")
def close(self) -> None:
pass
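# Illustrative sketch (not part of the original gist): loggers define __enter__/__exit__,
# so they can be used as context managers and close() runs even if training raises.
def _example_stdout_logger():
    with StdoutLogger(config_dict={"learning_rate": 1e-3}) as logger:
        logger.log(step=0, metrics={"train/loss": 3.21})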
# ============================================================
# toylib_projects.tinystories.experiment - /Users/anuj/Desktop/code/toylib/toylib_projects/tinystories/experiment.py
# ============================================================
"""Basic types for the training loop and configurations."""
@dataclasses.dataclass
class CheckpointConfig:
save_interval_steps: int = 5000
max_to_keep: typing.Optional[int] = 10
checkpoint_dir: str = "/tmp/checkpoints"
@dataclasses.dataclass
class TrainingConfig:
learning_rate: float = 0.001
max_steps: int = 100000
@dataclasses.dataclass
class Task:
name: str
dataset: BatchedTokenizedDataset
@dataclasses.dataclass(kw_only=True)
class LoggerConfig:
logger_cls: type[Logger] = FileLogger
log_dir: str = "/tmp/train_logs.txt"
def _serialize_dataclass_config(config: typing.Any) -> dict:
result = dataclasses.asdict(config)
for k, v in result.items():
if dataclasses.is_dataclass(v):
result[k] = _serialize_dataclass_config(v)
return result
@dataclasses.dataclass
class Experiment:
"""Base Experiment class."""
train_task: Task
eval_tasks: list[Task] = dataclasses.field(default_factory=list)
model_config: ModelConfig = dataclasses.field(default_factory=ModelConfig)
training_config: TrainingConfig = dataclasses.field(default_factory=TrainingConfig)
checkpoint_config: CheckpointConfig = dataclasses.field(
default_factory=CheckpointConfig
)
logger_config: LoggerConfig = dataclasses.field(default_factory=LoggerConfig)
forward_fn: typing.Callable = dataclasses.field(default_factory=lambda: train_step)
jit_train_fn: bool = True
def __post_init__(self):
self.logger_obj = self.logger_config.logger_cls(
config_dict=_serialize_dataclass_config(self),
output_path=self.logger_config.log_dir,
)
self.optimizer = optax.adam(learning_rate=self.training_config.learning_rate)
self.opt_state = None
self.model = None
self.ckpt_manager = ocp.CheckpointManager(
self.checkpoint_config.checkpoint_dir,
checkpointers={
"model": ocp.StandardCheckpointer(),
"opt_state": ocp.StandardCheckpointer(),
},
options=ocp.CheckpointManagerOptions(
max_to_keep=self.checkpoint_config.max_to_keep
),
)
def train_step(model, opt_state, batch):
inputs, targets = (batch["inputs"], batch["targets"])
mask = jax.numpy.ones_like(inputs)
with jax.profiler.TraceAnnotation("value_and_grad"):
(loss_val, _), grads = jax.value_and_grad(
self.forward_fn, has_aux=True
)(model, inputs, mask, targets)
with jax.profiler.TraceAnnotation("optimizer_update"):
updates, opt_state = self.optimizer.update(grads, opt_state)
model = optax.apply_updates(model, updates)
return (model, opt_state, loss_val)
if self.jit_train_fn:
self.train_step_fn = jax.jit(train_step)
else:
self.train_step_fn = train_step
def init_state(self):
self.model = DecoderOnlyTransformer(
config=self.model_config, key=jax.random.PRNGKey(0)
)
self.opt_state = self.optimizer.init(self.model)
self.step = 0
def _assert_initialized(self) -> None:
initialized = self.model is not None and self.opt_state is not None
assert initialized, "Experiment state not initialized. Call init_state() first."
def save_checkpoint(self):
self._assert_initialized()
self.ckpt_manager.save(
self.step,
args=ocp.args.Composite(
model=ocp.args.StandardSave(self.model),
opt_state=ocp.args.StandardSave(self.opt_state),
),
)
self.ckpt_manager.wait_until_finished()
def restore_checkpoint(self, step: int):
self._assert_initialized()
restored = self.ckpt_manager.restore(
step,
args=ocp.args.Composite(
model=ocp.args.StandardRestore(self.model),
opt_state=ocp.args.StandardRestore(self.opt_state),
),
)
self.model = restored["model"]
self.opt_state = restored["opt_state"]
self.step = step
def log_metrics(self, step: int, loss_val: float):
metrics = {
"train/loss": float(loss_val),
"train/learning_rate": self.training_config.learning_rate,
}
self.logger_obj.log(step=step, metrics=metrics)
def inner_loop(self, batch: dict):
self._assert_initialized()
self.model, self.opt_state, loss_val = self.train_step_fn(
self.model, self.opt_state, batch
)
loss_val.block_until_ready()
self.log_metrics(self.step, loss_val)
self.step += 1
def outer_loop(self):
finished = self.step >= self.training_config.max_steps
while True:
epoch_start_step = self.step
for batch in self.train_task.dataset:
with jax.profiler.StepTraceAnnotation("inner_loop", step_num=self.step):
self.inner_loop(batch)
if self.step % self.checkpoint_config.save_interval_steps == 0:
self.save_checkpoint()
if self.step >= self.training_config.max_steps:
finished = True
break
if finished:
break
if self.step == epoch_start_step:
raise ValueError(f"Dataset for task {self.train_task.name} is empty.")
def cleanup(self):
self.logger_obj.close()
self.ckpt_manager.close()
# ============================================================
# toylib_projects.tinystories.train - ../tinystories/train.py
# ============================================================
def get_model_config(
depth: int, seq_len: int = 1024, vocab_size: int = 50257
) -> ModelConfig:
num_layers = depth
model_dim = depth * 64
num_heads = max(1, (model_dim + 127) // 128)
num_kv_heads = num_heads
print(f"num_layers: {num_layers}")
print(f"model_dim: {model_dim}")
print(f"num_heads: {num_heads}")
print(f"num_kv_heads: {num_kv_heads}")
return ModelConfig(
num_layers=depth,
num_heads=num_heads,
qkv_dim=model_dim,
vocab_size=vocab_size,
seq_len=seq_len,
)
def main():
batch_size = 8
seq_len = 1024
depth = 12
vocab_size = 50257
dataset = BatchedTokenizedDatasetParquet(
dataset_path="/tmp/",
split="train",
batch_size=batch_size,
seq_len=seq_len,
tokenizer_batch_size=8,
)
train_task = Task(name="train", dataset=dataset)
exp = Experiment(
model_config=get_model_config(
depth=depth, seq_len=seq_len, vocab_size=vocab_size
),
training_config=TrainingConfig(learning_rate=0.001, max_steps=100000),
checkpoint_config=CheckpointConfig(
save_interval_steps=2500, max_to_keep=10, checkpoint_dir="/tmp/checkpoints"
),
train_task=train_task,
)
exp.init_state()
exp.outer_loop()
if __name__ == "__main__":
main()