# ============================================================
# External Imports
# ============================================================
from jax import numpy as jnp
from transformers import AutoTokenizer
import abc
import dataclasses
import datetime
import einops
import jax
import jaxtyping as jt
import math
import numpy as np
import optax
import orbax.checkpoint as ocp
import pathlib
import pyarrow.parquet as pq
import typing
# ============================================================
# toylib_projects.tinystories.data - /Users/anuj/Desktop/code/toylib/toylib_projects/tinystories/data.py
# ============================================================
@dataclasses.dataclass
class BatchedTokenizedDataset(abc.ABC):
    dataset_path: str = "karpathy/fineweb-edu-100b-shuffle"
    split: str = "train"
    tokenizer_name: str = "gpt2"
    seq_len: int = 2048
    tokenizer_batch_size: int = 8
    batch_size: int = 128

    @abc.abstractmethod
    def _get_dataset_iterator(self) -> typing.Iterator:
        raise NotImplementedError

    def __post_init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        self.bos_token = self.tokenizer.bos_token_id
        self.token_buffer = []
        self.dataset_iter = self._get_dataset_iterator()

    def __iter__(self):
        return self

    def __next__(self) -> dict[str, jnp.ndarray]:
        token_needed = self.batch_size * self.seq_len + 1
        while len(self.token_buffer) < token_needed:
            input_batch = next(self.dataset_iter)
            texts = input_batch["text"]
            tokenized = self.tokenizer(
                texts,
                return_tensors=None,
                padding=False,
                truncation=False,
                max_length=None,
            )["input_ids"]
            for tokens in tokenized:
                self.token_buffer.append(self.bos_token)
                self.token_buffer.extend(tokens)
        tokens = self.token_buffer[:token_needed]
        self.token_buffer = self.token_buffer[token_needed:]
        inputs = jnp.array(tokens[:-1], dtype=jnp.uint16).reshape(
            self.batch_size, self.seq_len
        )
        targets = jnp.array(tokens[1:], dtype=jnp.uint16).reshape(
            self.batch_size, self.seq_len
        )
        return {"inputs": inputs, "targets": targets}


class BatchedTokenizedDatasetParquet(BatchedTokenizedDataset):
    """Path is constructed as dataset_path/split/*.parquet"""

    def list_files(self):
        base_path = pathlib.Path(self.dataset_path) / self.split
        return list(base_path.glob("*.parquet"))

    def _get_dataset_iterator(self):
        for file_path in self.list_files():
            pf = pq.ParquetFile(file_path)
            for row_group in range(pf.num_row_groups):
                rg = pf.read_row_group(row_group)
                yield {"text": rg.column("text").to_pylist()}
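

# ------------------------------------------------------------
# Illustrative usage sketch (not part of the original gist): how the parquet
# dataset is expected to be consumed. The path and sizes below are hypothetical.
# ------------------------------------------------------------
def _example_iterate_parquet_dataset() -> None:
    ds = BatchedTokenizedDatasetParquet(
        dataset_path="/data/fineweb",  # hypothetical directory containing train/*.parquet
        split="train",
        batch_size=4,
        seq_len=128,
    )
    batch = next(iter(ds))
    # Each batch is a dict of uint16 arrays of shape [batch_size, seq_len];
    # targets are the inputs shifted left by one token.
    assert batch["inputs"].shape == (4, 128)
    assert batch["targets"].shape == (4, 128)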
# ============================================================
# toylib.nn.module - /Users/anuj/Desktop/code/toylib/toylib/nn/module.py
# ============================================================
def _is_array(x: typing.Any) -> bool:
    return isinstance(x, (jax.Array, np.ndarray, np.generic)) or hasattr(
        x, "__jax_array__"
    )


def _is_random_key(x: str) -> bool:
    return x == "key"


def _is_supported_container(x: typing.Any) -> bool:
    return isinstance(x, (list, tuple))


class Module(abc.ABC):
    """
    Defines a base class to use for the neural network modules in toylib.

    Assumes that all jax arrays are leaf nodes that are trainable and
    everything else is a static param. Defines the flatten and unflatten methods
    to make the modules compatible with jax `jit` and `grad` functions.

    Refer to https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees

    Inspired by equinox and the Custom PyTrees and Initialization section in the jax docs.
    """

    def __init_subclass__(cls, **kwargs: typing.Any) -> None:
        super().__init_subclass__(**kwargs)
        cls = dataclasses.dataclass(cls, kw_only=True)
        cls = jax.tree_util.register_pytree_with_keys_class(cls)

    @abc.abstractmethod
    def init(self) -> None:
        """Initialize all the trainable parameters in the module."""
        pass

    @abc.abstractmethod
    def __call__(self, *args, **kwargs) -> typing.Any:
        """Run a forward pass of the module."""
        pass

    def _get_trainable_param_keys(self) -> list[str]:
        """Get the list of attribute names that are trainable parameters."""
        param_keys = []
        for k, v in self.__dict__.items():
            if (
                (_is_array(v) and not _is_random_key(k))
                or isinstance(v, Module)
                or (
                    _is_supported_container(v)
                    and all(isinstance(elem, Module) for elem in v)
                )
            ):
                param_keys.append(k)
        return param_keys

    def __post_init__(self) -> None:
        self.init()
        self._trainable_param_keys = self._get_trainable_param_keys()

    def tree_flatten_with_keys(self) -> tuple:
        params_with_keys = []
        aux_data = dict()
        for k, v in self.__dict__.items():
            if k not in self._trainable_param_keys:
                aux_data[k] = v
        for k in self._trainable_param_keys:
            v = self.__dict__[k]
            params_with_keys.append((jax.tree_util.GetAttrKey(k), v))
        return (params_with_keys, aux_data)

    @classmethod
    def tree_unflatten(cls, static, dynamic) -> "Module":
        obj = object.__new__(cls)
        param_keys = static["_trainable_param_keys"]
        for k, v in zip(param_keys, dynamic):
            obj.__setattr__(k, v)
        for k, v in static.items():
            obj.__setattr__(k, v)
        return obj
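

# ------------------------------------------------------------
# Illustrative sketch (not part of the original gist): a minimal Module subclass
# showing that jax arrays become pytree leaves while ints, the PRNG `key`, and
# bookkeeping attributes stay in the static aux data.
# ------------------------------------------------------------
class _ExampleScale(Module):
    dim: int
    key: jt.PRNGKeyArray
    scale: jt.Float[jt.Array, " dim"] | None = None

    def init(self) -> None:
        self.scale = jnp.ones((self.dim,))

    def __call__(
        self, x: jt.Float[jt.Array, "... dim"]
    ) -> jt.Float[jt.Array, "... dim"]:
        return x * self.scale


def _example_module_as_pytree() -> None:
    layer = _ExampleScale(dim=4, key=jax.random.PRNGKey(0))
    leaves = jax.tree_util.tree_leaves(layer)
    # Only `scale` is a leaf, so `jit`/`grad` treat it as the trainable state.
    assert len(leaves) == 1
    assert leaves[0].shape == (4,)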
# ============================================================
# toylib.nn.layers - /Users/anuj/Desktop/code/toylib/toylib/nn/layers.py
# ============================================================
class Linear(Module):
    """Defines a simple feedforward layer, i.e. a linear transformation."""

    in_features: int
    out_features: int
    use_bias: bool = False
    key: jt.PRNGKeyArray
    weights: jt.Float[jt.Array, "in_features out_features"] | None = None
    bias: jt.Float[jt.Array, " out_features"] | None = None

    def init(self) -> None:
        w_key = self.key
        in_features = self.in_features
        out_features = self.out_features
        std = min(1.0, math.sqrt(out_features / in_features)) / math.sqrt(in_features)
        self.weights = jax.random.normal(w_key, (in_features, out_features)) * std
        self.bias = jax.numpy.zeros((out_features,)) if self.use_bias else None

    def __call__(
        self, x: jt.Float[jt.Array, "... in_features"]
    ) -> jt.Float[jt.Array, "... out_features"]:
        x = jax.numpy.dot(x, self.weights)
        if self.use_bias:
            x = x + self.bias
        return x


class Embedding(Module):
    """Defines an embedding layer that stores an embedding matrix for discrete tokens."""

    vocab_size: int
    embedding_dim: int
    key: jt.PRNGKeyArray
    weights: jt.Float[jt.Array, "vocab_size embedding_dim"] | None = None

    def init(self) -> None:
        self.weights = jax.random.normal(
            self.key, (self.vocab_size, self.embedding_dim)
        )

    def __call__(
        self, tokens: jt.Int[jt.Array, "... seq_len"]
    ) -> jt.Float[jt.Array, "... seq_len embedding_dim"]:
        return jax.numpy.take(self.weights, tokens, axis=0)


def rms_norm(x: jt.Float[jt.Array, "... dim"]) -> jt.Float[jt.Array, "... dim"]:
    """Applies RMS Normalization over the last dimension of the input tensor.

    Args:
        x: Input tensor
    Returns:
        The RMS normalized tensor of the same shape as input x.
    """
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + 1e-09)
    return x / rms
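

# ------------------------------------------------------------
# Illustrative sketch (not part of the original gist): a worked rms_norm example.
# ------------------------------------------------------------
def _example_rms_norm() -> None:
    x = jnp.array([[3.0, 4.0]])
    # RMS of [3, 4] is sqrt((9 + 16) / 2) ~= 3.5355, so the output is x / 3.5355.
    y = rms_norm(x)
    assert jnp.allclose(y, jnp.array([[0.8485, 1.1314]]), atol=1e-3)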
# ============================================================
# toylib.nn.attention - /Users/anuj/Desktop/code/toylib/toylib/nn/attention.py
# ============================================================
class RotaryPositionalEmbedding(Module):
    """Implements Rotary Positional Embeddings (RoPE) as described in https://arxiv.org/abs/2104.09864."""

    seq_len: int = 1024
    qkv_dim: int = 128
    base: int = 100000

    def init(self) -> None:
        positions = jnp.arange(0, self.seq_len)
        freqs = self.base ** (jnp.arange(0, self.qkv_dim, 2) / self.qkv_dim)
        self.gamma = einops.einsum(positions, 1.0 / freqs, "t, d -> t d")
        self.cos = jnp.cos(self.gamma)
        self.sin = jnp.sin(self.gamma)

    def __call__(
        self, x: jt.Float[jt.Array, "... seq_len qkv_dim"]
    ) -> jt.Float[jt.Array, "... seq_len qkv_dim"]:
        d = x.shape[-1]
        x1, x2 = (x[..., : d // 2], x[..., d // 2 :])
        es_shape = "... t d, t d -> ... t d"
        y1 = einops.einsum(x1, self.cos, es_shape) + einops.einsum(
            x2, self.sin, es_shape
        )
        y2 = -einops.einsum(x1, self.sin, es_shape) + einops.einsum(
            x2, self.cos, es_shape
        )
        return jnp.concatenate([y1, y2], axis=-1)
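

# ------------------------------------------------------------
# Illustrative sketch (not part of the original gist): RoPE rotates each
# (x1_i, x2_i) pair by a position-dependent angle, so per-position norms
# are preserved.
# ------------------------------------------------------------
def _example_rope_preserves_norm() -> None:
    rope = RotaryPositionalEmbedding(seq_len=8, qkv_dim=16)
    x = jax.random.normal(jax.random.PRNGKey(1), (8, 16))
    y = rope(x)
    assert y.shape == x.shape
    assert jnp.allclose(
        jnp.linalg.norm(x, axis=-1), jnp.linalg.norm(y, axis=-1), atol=1e-4
    )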
def scaled_dot_product_attention(
    q: jt.Float[jt.Array, "... seq_len qkv_dim"],
    k: jt.Float[jt.Array, "... seq_len qkv_dim"],
    v: jt.Float[jt.Array, "... seq_len qkv_dim"],
    mask: typing.Optional[jt.Float[jt.Array, "... seq_len seq_len"]],
) -> tuple[
    jt.Float[jt.Array, "... seq_len qkv_dim"], jt.Float[jt.Array, "... seq_len seq_len"]
]:
    """Compute scaled dot product attention.

    Given query (`q`), key (`k`), and value (`v`) tensors, this function first computes the
    attention weights as the softmax of the dot product of `q` and `k`, scaled by the square
    root of the dimension of the keys. If a mask is provided, it is applied to the attention
    logits before the softmax is computed.

    Finally, the attention weights are used to compute the weighted average of the given values.

    NOTE: the batch dimension is not explicitly handled in this function.

    Args:
        q: query tensor
        k: keys tensor
        v: values tensor
        mask: optional boolean mask to apply to the attention logits
    Returns:
        tuple of final values and attention weights
    """
    d_k = q.shape[-1]
    assert q.shape[-1] == k.shape[-1], "q and k must have the same feature dimension"
    attention_logits = jnp.matmul(q, k.swapaxes(-1, -2)) / jnp.sqrt(d_k)
    if mask is not None:
        attention_logits = jnp.where(mask, attention_logits, -1000000000.0)
    attention_weights = jax.nn.softmax(attention_logits, axis=-1)
    values = jnp.matmul(attention_weights, v)
    return (values, attention_weights)
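

# ------------------------------------------------------------
# Illustrative sketch (not part of the original gist): scaled dot product
# attention with a lower-triangular (causal) mask.
# ------------------------------------------------------------
def _example_causal_attention() -> None:
    key = jax.random.PRNGKey(2)
    q = k = v = jax.random.normal(key, (4, 8))  # [seq_len, qkv_dim]
    causal_mask = jnp.tril(jnp.ones((4, 4)))
    values, weights = scaled_dot_product_attention(q=q, k=k, v=v, mask=causal_mask)
    assert values.shape == (4, 8)
    # Each row of the attention weights is a distribution over positions <= i.
    assert jnp.allclose(weights.sum(axis=-1), 1.0, atol=1e-5)
    assert weights[0, 1] < 1e-6  # position 0 cannot attend to position 1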
class MultiHeadAttention(Module):
    """
    The MultiHeadAttention defines `num_heads` attention heads. For the given input `Q`, `K`, `V`
    tensors, `num_heads` linear projections of dim `qkv_dim / num_heads` are produced.

    Attention weights are then computed using the scaled dot product attention method. The
    weighted averages of the values from the various heads are concatenated to produce a
    single output value vector, and a final linear projection is applied on top of this.
    """

    qkv_dim: int
    num_heads: int
    use_qk_norm: bool = True
    key: jt.PRNGKeyArray

    def init(self) -> None:
        qkv_dim = self.qkv_dim
        keys = jax.random.split(self.key, 4)
        self.q_projection = Linear(
            in_features=qkv_dim, out_features=qkv_dim, use_bias=False, key=keys[0]
        )
        self.k_projection = Linear(
            in_features=qkv_dim, out_features=qkv_dim, use_bias=False, key=keys[1]
        )
        self.v_projection = Linear(
            in_features=qkv_dim, out_features=qkv_dim, use_bias=False, key=keys[2]
        )
        self.linear = Linear(in_features=qkv_dim, out_features=qkv_dim, key=keys[3])

    def __call__(
        self,
        Q: jt.Float[jt.Array, "... seq_len qkv_dim"],
        K: jt.Float[jt.Array, "... seq_len qkv_dim"],
        V: jt.Float[jt.Array, "... seq_len qkv_dim"],
        mask: typing.Optional[jt.Float[jt.Array, "... seq_len seq_len"]] = None,
        *,
        rope: typing.Optional[RotaryPositionalEmbedding] = None,
        return_attention_weights: bool = False,
    ) -> typing.Union[
        tuple[
            jt.Float[jt.Array, "... seq_len qkv_dim"],
            jt.Float[jt.Array, "... seq_len seq_len"],
        ],
        jt.Float[jt.Array, "... seq_len qkv_dim"],
    ]:
        Q = self.q_projection(Q)
        K = self.k_projection(K)
        V = self.v_projection(V)
        Q = einops.rearrange(
            Q,
            "... seq_len (num_heads head_dim) -> ... num_heads seq_len head_dim",
            num_heads=self.num_heads,
        )
        K = einops.rearrange(
            K,
            "... seq_len (num_heads head_dim) -> ... num_heads seq_len head_dim",
            num_heads=self.num_heads,
        )
        V = einops.rearrange(
            V,
            "... seq_len (num_heads head_dim) -> ... num_heads seq_len head_dim",
            num_heads=self.num_heads,
        )
        if mask is not None:
            mask = einops.rearrange(
                mask, "... seq_len1 seq_len2 -> ... 1 seq_len1 seq_len2"
            )
        if rope is not None:
            Q = rope(Q)
            K = rope(K)
        if self.use_qk_norm:
            Q = rms_norm(Q)
            K = rms_norm(K)
        values, attention_weights = scaled_dot_product_attention(
            q=Q, k=K, v=V, mask=mask
        )
        values = einops.rearrange(
            values, "... num_heads seq_len d -> ... seq_len (num_heads d)"
        )
        values = self.linear(values)
        if return_attention_weights:
            return (values, attention_weights)
        return values
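

# ------------------------------------------------------------
# Illustrative sketch (not part of the original gist): self-attention with
# MultiHeadAttention keeps the [batch, seq_len, qkv_dim] shape.
# ------------------------------------------------------------
def _example_multi_head_self_attention() -> None:
    mha = MultiHeadAttention(qkv_dim=32, num_heads=4, key=jax.random.PRNGKey(3))
    x = jax.random.normal(jax.random.PRNGKey(4), (2, 16, 32))  # [batch, seq, dim]
    out = mha(Q=x, K=x, V=x)
    assert out.shape == (2, 16, 32)
    out, weights = mha(Q=x, K=x, V=x, return_attention_weights=True)
    # One attention map per head: [batch, num_heads, seq_len, seq_len].
    assert weights.shape == (2, 4, 16, 16)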
# ============================================================
# toylib_projects.tinystories.decoder_only_model - /Users/anuj/Desktop/code/toylib/toylib_projects/tinystories/decoder_only_model.py
# ============================================================
@dataclasses.dataclass
class ModelConfig:
    """Configuration for the DecoderOnlyTransformer model."""

    num_layers: int = 2
    num_heads: int = 8
    qkv_dim: int = 256
    vocab_size: int = 50257
    seq_len: int = 512
    logit_softcap: float = 15.0


class MLP(Module):
    """A simple feedforward MLP with one hidden layer."""

    qkv_dim: int
    key: jt.PRNGKeyArray

    def init(self) -> None:
        qkv_dim = self.qkv_dim
        keys = jax.random.split(self.key, 2)
        self.fc1 = Linear(in_features=qkv_dim, out_features=4 * qkv_dim, key=keys[0])
        self.fc2 = Linear(in_features=4 * qkv_dim, out_features=qkv_dim, key=keys[1])
        self.fc2.weights = jnp.zeros_like(self.fc2.weights)

    def __call__(
        self, x: jt.Float[jt.Array, "... qkv_dim"]
    ) -> jt.Float[jt.Array, "... qkv_dim"]:
        x = self.fc1(x)
        x = jax.nn.gelu(x)
        x = self.fc2(x)
        return x


class CausalSelfAttention(Module):
    """Causal Self-Attention layer with Rotary Positional Embeddings (RoPE)."""

    qkv_dim: int
    num_heads: int
    seq_len: int
    key: jt.PRNGKeyArray

    def init(self) -> None:
        self.mha = MultiHeadAttention(
            qkv_dim=self.qkv_dim,
            num_heads=self.num_heads,
            key=self.key,
            use_qk_norm=True,
        )
        self.mha.linear.weights = jnp.zeros_like(self.mha.linear.weights)
        self.rope = RotaryPositionalEmbedding(
            qkv_dim=self.qkv_dim // self.num_heads, seq_len=self.seq_len
        )

    def _make_causal_mask(self, seq_len: int) -> jt.Float[jt.Array, "seq_len seq_len"]:
        return jnp.tril(jnp.ones((seq_len, seq_len)))

    def __call__(
        self, x: jt.Float[jt.Array, "... seq_len qkv_dim"]
    ) -> jt.Float[jt.Array, "... seq_len qkv_dim"]:
        x = self.mha(
            Q=x, K=x, V=x, mask=self._make_causal_mask(x.shape[-2]), rope=self.rope
        )
        return x


class DecoderBlock(Module):
    qkv_dim: int
    num_heads: int
    seq_len: int
    key: jt.PRNGKeyArray

    def init(self) -> None:
        keys = jax.random.split(self.key, 2)
        self.causal_attn = CausalSelfAttention(
            qkv_dim=self.qkv_dim,
            num_heads=self.num_heads,
            seq_len=self.seq_len,
            key=keys[0],
        )
        self.mlp = MLP(qkv_dim=self.qkv_dim, key=keys[1])

    def __call__(
        self, x: jt.Float[jt.Array, "... seq_len qkv_dim"]
    ) -> jt.Float[jt.Array, "... seq_len qkv_dim"]:
        x = x + self.causal_attn(rms_norm(x))
        x = x + self.mlp(rms_norm(x))
        return x


class DecoderOnlyTransformer(Module):
    """A simple decoder-only transformer model.

    Takes in a sequence of tokens, embeds them using a learned embedding layer,
    applies causal self-attention with RoPE embeddings, and outputs the logits
    for next-token prediction.
    """

    key: jt.PRNGKeyArray
    config: ModelConfig

    def init(self) -> None:
        config = self.config
        keys = jax.random.split(self.key, config.num_layers + 2)
        self.embedding_layer = Embedding(
            vocab_size=config.vocab_size, embedding_dim=config.qkv_dim, key=keys[0]
        )
        self.blocks = []
        for ix in range(config.num_layers):
            self.blocks.append(
                DecoderBlock(
                    qkv_dim=config.qkv_dim,
                    num_heads=config.num_heads,
                    seq_len=config.seq_len,
                    key=keys[ix + 1],
                )
            )
        self.output_layer = Linear(
            in_features=config.qkv_dim, out_features=config.vocab_size, key=keys[-1]
        )

    def __call__(
        self, x: jt.Float[jt.Array, "batch_size seq_len"]
    ) -> jt.Float[jt.Array, "batch_size seq_len vocab_size"]:
        """Forward pass for the decoder-only transformer model.

        Args:
            x: Input token ids of shape [batch_size, seq_len]. Note that the sequence
                length should match the configuration `seq_len` as this is used for
                positional embeddings.
        Returns:
            Unnormalized logits over the vocabulary of shape [batch_size, seq_len, vocab_size]
        """
        x = self.embedding_layer(x)
        x = rms_norm(x)
        for block in self.blocks:
            x = block(x)
        x = rms_norm(x)
        x = self.output_layer(x)
        x = self.config.logit_softcap * jnp.tanh(x / self.config.logit_softcap)
        return x
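

# ------------------------------------------------------------
# Illustrative sketch (not part of the original gist): a forward pass through a
# tiny, hypothetical configuration.
# ------------------------------------------------------------
def _example_tiny_forward_pass() -> None:
    config = ModelConfig(
        num_layers=1, num_heads=2, qkv_dim=32, vocab_size=100, seq_len=16
    )
    model = DecoderOnlyTransformer(config=config, key=jax.random.PRNGKey(5))
    tokens = jnp.zeros((2, 16), dtype=jnp.int32)  # [batch_size, seq_len]
    logits = model(tokens)
    assert logits.shape == (2, 16, 100)
    # tanh soft-capping bounds the logits to [-logit_softcap, logit_softcap].
    assert jnp.all(jnp.abs(logits) <= config.logit_softcap)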
def loss_fn(
    logits: jt.Float[jt.Array, "batch_size seq_len vocab_size"],
    targets: jt.Int[jt.Array, "batch_size seq_len"],
    mask: jt.Int[jt.Array, "batch_size seq_len"],
) -> tuple[jt.Float[jt.Array, ""], jt.Float[jt.Array, "batch_size seq_len"]]:
    """Computes the cross-entropy loss between logits and targets.

    Args:
        logits: Logits of shape [batch_size, seq_len, vocab_size].
        targets: Target token ids of shape [batch_size, seq_len].
        mask: Per-token mask of shape [batch_size, seq_len]; masked-out tokens do
            not contribute to the total loss.
    Returns:
        Tuple of the scalar masked-mean loss and the per-token loss of shape
        [batch_size, seq_len].
    """
    targets_one_hot = jax.nn.one_hot(targets, num_classes=logits.shape[-1])
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    per_token_loss = -jnp.sum(targets_one_hot * log_probs, axis=-1)
    masked_loss = mask * per_token_loss
    total_loss = jnp.sum(masked_loss) / jnp.sum(mask)
    return (total_loss, per_token_loss)
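

# ------------------------------------------------------------
# Illustrative sketch (not part of the original gist): a worked loss_fn example
# on hand-crafted logits.
# ------------------------------------------------------------
def _example_loss_fn() -> None:
    # One sequence of two positions over a vocab of 3; the logits put almost all
    # probability mass on token 0 at both positions.
    logits = jnp.array([[[10.0, -10.0, -10.0], [10.0, -10.0, -10.0]]])
    targets = jnp.array([[0, 0]])
    mask = jnp.array([[1.0, 0.0]])  # only the first position contributes
    total_loss, per_token_loss = loss_fn(logits, targets, mask)
    assert per_token_loss.shape == (1, 2)
    assert total_loss < 1e-3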
def train_step(
    model: DecoderOnlyTransformer,
    tokens: jt.Int[jt.Array, "batch_size seq_len"],
    mask: jt.Int[jt.Array, "batch_size seq_len"],
    targets: jt.Int[jt.Array, "batch_size seq_len"],
) -> tuple[jt.Float[jt.Array, ""], dict]:
    """A single training forward pass: computes the loss for one batch.

    Args:
        model: The DecoderOnlyTransformer model.
        tokens: Input token ids of shape [batch_size, seq_len].
        mask: Per-token loss mask of shape [batch_size, seq_len].
        targets: Target token ids of shape [batch_size, seq_len].
    Returns:
        Tuple of the scalar loss for the batch and an aux dict containing the
        logits and the per-token loss.
    """
    logits = model(tokens)
    total_loss, per_token_loss = loss_fn(logits, targets, mask)
    return (total_loss, {"logits": logits, "per_token_loss": per_token_loss})
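

# ------------------------------------------------------------
# Illustrative sketch (not part of the original gist): `train_step` is the grad
# target; `has_aux=True` separates the scalar loss from the aux dict, mirroring
# how `Experiment` wires it up below. Sizes are hypothetical.
# ------------------------------------------------------------
def _example_value_and_grad() -> None:
    config = ModelConfig(
        num_layers=1, num_heads=2, qkv_dim=32, vocab_size=100, seq_len=16
    )
    model = DecoderOnlyTransformer(config=config, key=jax.random.PRNGKey(6))
    tokens = jnp.zeros((2, 16), dtype=jnp.int32)
    targets = jnp.ones((2, 16), dtype=jnp.int32)
    mask = jnp.ones((2, 16))
    (loss, aux), grads = jax.value_and_grad(train_step, has_aux=True)(
        model, tokens, mask, targets
    )
    assert loss.shape == ()
    assert aux["logits"].shape == (2, 16, 100)
    # The gradient pytree mirrors the model pytree leaf-for-leaf, so optax can
    # consume it directly.
    grad_leaves = jax.tree_util.tree_leaves(grads)
    model_leaves = jax.tree_util.tree_leaves(model)
    assert len(grad_leaves) == len(model_leaves)
    assert all(g.shape == p.shape for g, p in zip(grad_leaves, model_leaves))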
# ============================================================
# toylib_projects.tinystories.logger - /Users/anuj/Desktop/code/toylib/toylib_projects/tinystories/logger.py
# ============================================================
class Logger(abc.ABC):
    """Interface for logging training metrics."""

    def __init__(self, config_dict: dict, *args, **kwargs) -> None:
        self.config_dict = config_dict

    @abc.abstractmethod
    def log(self, step: int, metrics: dict) -> None:
        """Log the given metrics at the specified step."""
        pass

    @abc.abstractmethod
    def close(self) -> None:
        """Close any resources held by the logger."""
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()


class FileLogger(Logger):
    """Logger implementation that logs metrics to a local file."""

    def __init__(self, config_dict: dict, output_path: str, *args, **kwargs) -> None:
        self.config_dict = config_dict
        self.file_ptr = open(output_path, "w")
        self.file_ptr.write("\n")

    def log(self, step: int, metrics: dict) -> None:
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.file_ptr.write(f"[{timestamp}] Step {step}: {metrics}\n")

    def close(self) -> None:
        self.file_ptr.close()


class StdoutLogger(Logger):
    """Logger implementation that logs metrics to standard output."""

    def __init__(self, config_dict: dict, *args, **kwargs) -> None:
        self.config_dict = config_dict

    def log(self, step: int, metrics: dict) -> None:
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f"[{timestamp}] Step {step}: {metrics}")

    def close(self) -> None:
        pass
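

# ------------------------------------------------------------
# Illustrative sketch (not part of the original gist): loggers can be used as
# context managers, so `close()` runs even if training raises.
# ------------------------------------------------------------
def _example_logger_usage() -> None:
    with StdoutLogger(config_dict={"experiment": "demo"}) as logger:
        logger.log(step=0, metrics={"train/loss": 3.14})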
# ============================================================
# toylib_projects.tinystories.experiment - /Users/anuj/Desktop/code/toylib/toylib_projects/tinystories/experiment.py
# ============================================================
"""Basic types for the training loop and configurations."""


@dataclasses.dataclass
class CheckpointConfig:
    save_interval_steps: int = 5000
    max_to_keep: typing.Optional[int] = 10
    checkpoint_dir: str = "/tmp/checkpoints"


@dataclasses.dataclass
class TrainingConfig:
    learning_rate: float = 0.001
    max_steps: int = 100000


@dataclasses.dataclass
class Task:
    name: str
    dataset: BatchedTokenizedDataset


@dataclasses.dataclass(kw_only=True)
class LoggerConfig:
    logger_cls: type[Logger] = FileLogger
    log_dir: str = "/tmp/train_logs.txt"


def _serialize_dataclass_config(config: typing.Any) -> dict:
    result = dataclasses.asdict(config)
    for k, v in result.items():
        if dataclasses.is_dataclass(v):
            result[k] = _serialize_dataclass_config(v)
    return result
@dataclasses.dataclass
class Experiment:
    """Base Experiment class."""

    train_task: Task
    eval_tasks: list[Task] = dataclasses.field(default_factory=list)
    model_config: ModelConfig = dataclasses.field(default_factory=ModelConfig)
    training_config: TrainingConfig = dataclasses.field(default_factory=TrainingConfig)
    checkpoint_config: CheckpointConfig = dataclasses.field(
        default_factory=CheckpointConfig
    )
    logger_config: LoggerConfig = dataclasses.field(default_factory=LoggerConfig)
    forward_fn: typing.Callable = dataclasses.field(default_factory=lambda: train_step)
    jit_train_fn: bool = True

    def __post_init__(self):
        self.logger_obj = self.logger_config.logger_cls(
            config_dict=_serialize_dataclass_config(self),
            output_path=self.logger_config.log_dir,
        )
        self.optimizer = optax.adam(learning_rate=self.training_config.learning_rate)
        self.opt_state = None
        self.model = None
        self.ckpt_manager = ocp.CheckpointManager(
            self.checkpoint_config.checkpoint_dir,
            checkpointers={
                "model": ocp.StandardCheckpointer(),
                "opt_state": ocp.StandardCheckpointer(),
            },
            options=ocp.CheckpointManagerOptions(
                max_to_keep=self.checkpoint_config.max_to_keep
            ),
        )

        def train_step(model, opt_state, batch):
            inputs, targets = (batch["inputs"], batch["targets"])
            mask = jax.numpy.ones_like(inputs)
            with jax.profiler.TraceAnnotation("value_and_grad"):
                (loss_val, _), grads = jax.value_and_grad(
                    self.forward_fn, has_aux=True
                )(model, inputs, mask, targets)
            with jax.profiler.TraceAnnotation("optimizer_update"):
                updates, opt_state = self.optimizer.update(grads, opt_state)
                model = optax.apply_updates(model, updates)
            return (model, opt_state, loss_val)

        if self.jit_train_fn:
            self.train_step_fn = jax.jit(train_step)
        else:
            self.train_step_fn = train_step

    def init_state(self):
        self.model = DecoderOnlyTransformer(
            config=self.model_config, key=jax.random.PRNGKey(0)
        )
        self.opt_state = self.optimizer.init(self.model)
        self.step = 0

    def _assert_initialized(self) -> None:
        initialized = self.model is not None and self.opt_state is not None
        assert initialized, "Experiment state not initialized. Call init_state() first."

    def save_checkpoint(self):
        self._assert_initialized()
        self.ckpt_manager.save(
            self.step,
            args=ocp.args.Composite(
                model=ocp.args.StandardSave(self.model),
                opt_state=ocp.args.StandardSave(self.opt_state),
            ),
        )
        self.ckpt_manager.wait_until_finished()

    def restore_checkpoint(self, step: int):
        self._assert_initialized()
        restored = self.ckpt_manager.restore(
            step,
            args=ocp.args.Composite(
                model=ocp.args.StandardRestore(self.model),
                opt_state=ocp.args.StandardRestore(self.opt_state),
            ),
        )
        self.model = restored["model"]
        self.opt_state = restored["opt_state"]
        self.step = step

    def log_metrics(self, step: int, loss_val: float):
        metrics = {
            "train/loss": float(loss_val),
            "train/learning_rate": self.training_config.learning_rate,
        }
        self.logger_obj.log(step=step, metrics=metrics)

    def inner_loop(self, batch: dict):
        self._assert_initialized()
        self.model, self.opt_state, loss_val = self.train_step_fn(
            self.model, self.opt_state, batch
        )
        loss_val.block_until_ready()
        self.log_metrics(self.step, loss_val)
        self.step += 1

    def outer_loop(self):
        finished = self.step >= self.training_config.max_steps
        while True:
            epoch_start_step = self.step
            for batch in self.train_task.dataset:
                with jax.profiler.StepTraceAnnotation("inner_loop", step_num=self.step):
                    self.inner_loop(batch)
                if self.step % self.checkpoint_config.save_interval_steps == 0:
                    self.save_checkpoint()
                if self.step >= self.training_config.max_steps:
                    finished = True
                    break
            if finished:
                break
            if self.step == epoch_start_step:
                raise ValueError(f"Dataset for task {self.train_task.name} is empty.")

    def cleanup(self):
        self.logger_obj.close()
        self.ckpt_manager.close()
# ============================================================
# None - ../tinystories/train.py
# ============================================================
def get_model_config(
    depth: int, seq_len: int = 1024, vocab_size: int = 50257
) -> ModelConfig:
    num_layers = depth
    model_dim = depth * 64
    num_heads = max(1, (model_dim + 127) // 128)
    num_kv_heads = num_heads
    print(f"num_layers: {num_layers}")
    print(f"model_dim: {model_dim}")
    print(f"num_heads: {num_heads}")
    print(f"num_kv_heads: {num_kv_heads}")
    return ModelConfig(
        num_layers=depth,
        num_heads=num_heads,
        qkv_dim=model_dim,
        vocab_size=vocab_size,
        seq_len=seq_len,
    )


def main():
    batch_size = 8
    seq_len = 1024
    depth = 12
    vocab_size = 50257
    dataset = BatchedTokenizedDatasetParquet(
        dataset_path="/tmp/",
        split="train",
        batch_size=batch_size,
        seq_len=seq_len,
        tokenizer_batch_size=8,
    )
    train_task = Task(name="train", dataset=dataset)
    exp = Experiment(
        model_config=get_model_config(
            depth=depth, seq_len=seq_len, vocab_size=vocab_size
        ),
        training_config=TrainingConfig(learning_rate=0.001, max_steps=100000),
        checkpoint_config=CheckpointConfig(
            save_interval_steps=2500, max_to_keep=10, checkpoint_dir="/tmp/checkpoints"
        ),
        train_task=train_task,
    )
    exp.init_state()
    exp.outer_loop()


if __name__ == "__main__":
    main()