@shunting314
Created December 26, 2025 23:46
diff --git a/run_train.sh b/run_train.sh
index 87558a78..0a256031 100755
--- a/run_train.sh
+++ b/run_train.sh
@@ -30,6 +30,6 @@ else
PYTORCH_ALLOC_CONF="expandable_segments:True" \
TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
- --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
+ --local-ranks-filter ${LOG_RANK} --role rank --tee 0 \
-m ${TRAIN_FILE} --job.config_file ${CONFIG_FILE} "$@"
fi
diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py
index cafd58a5..172cc234 100644
--- a/torchtitan/models/llama3/model/model.py
+++ b/torchtitan/models/llama3/model/model.py
@@ -582,5 +582,6 @@ class Transformer(nn.Module, ModelProtocol):
# pyrefly: ignore [not-callable]
h = self.norm(h) if self.norm else h
# pyrefly: ignore [not-callable]
- output = self.output(h) if self.output else h
+ # output = self.output(h) if self.output else h
+ output = h
return output
diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml
index ef86d783..dc76c3a3 100644
--- a/torchtitan/models/llama3/train_configs/llama3_8b.toml
+++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml
@@ -52,7 +52,7 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[compile]
enable=false
-components = ["model", "loss"]
+components = ["model"]

[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 8455e54e..8ccc3950 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -11,6 +11,7 @@ import os
import time
from datetime import timedelta
from typing import Any, Iterable
+import torch.distributed as dist

import torch
import torch.distributed.checkpoint.stateful
@@ -263,6 +264,13 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
self.model_parts = [model]
+ def unembed_and_loss(pred, loss):
+ # dist.breakpoint() # TODO
+ return self.loss_fn(model.output(pred), loss)
+
+ self.unembed_and_loss = torch.compile(unembed_and_loss)
+ # self.unembed_and_loss = unembed_and_loss
+
self.ft_manager.maybe_set_all_reduce_hook(self.model_parts)
# initialize device memory monitor and get peak flops for MFU calculation
@@ -541,7 +549,9 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
assert len(model_parts) == 1
with self.maybe_enable_amp:
pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
- loss = self.loss_fn(pred, labels)
+ # dist.breakpoint()
+ # loss = self.loss_fn(pred, labels)
+ loss = self.unembed_and_loss(pred, labels)
# need to free pred before bwd to avoid peaking memory
del pred
loss.backward()