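A patch against torchtitan that moves Llama 3's final output projection (the unembedding) out of the model's forward pass and into an unembed_and_loss helper in the trainer, wrapped in torch.compile so the vocab projection and the loss run as one compiled region. It also disables torchrun's log teeing and drops "loss" from the compile components to match.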
diff --git a/run_train.sh b/run_train.sh
index 87558a78..0a256031 100755
--- a/run_train.sh
+++ b/run_train.sh
@@ -30,6 +30,6 @@ else
     PYTORCH_ALLOC_CONF="expandable_segments:True" \
     TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
     torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
-    --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
+    --local-ranks-filter ${LOG_RANK} --role rank --tee 0 \
     -m ${TRAIN_FILE} --job.config_file ${CONFIG_FILE} "$@"
 fi
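The only change here is torchrun's --tee value: 3 tees both stdout and stderr of every worker to the console, while 0 disables teeing, which quiets the duplicated per-rank output during debugging.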
diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py
index cafd58a5..172cc234 100644
--- a/torchtitan/models/llama3/model/model.py
+++ b/torchtitan/models/llama3/model/model.py
@@ -582,5 +582,6 @@ class Transformer(nn.Module, ModelProtocol):
         # pyrefly: ignore [not-callable]
         h = self.norm(h) if self.norm else h
         # pyrefly: ignore [not-callable]
-        output = self.output(h) if self.output else h
+        # output = self.output(h) if self.output else h
+        output = h
         return output
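With self.output skipped, Transformer.forward now returns the final-norm hidden states of shape [batch, seq, dim] rather than logits over the vocabulary; the projection is applied later, inside the trainer's compiled unembed_and_loss (see the train.py hunk below).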
diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml
index ef86d783..dc76c3a3 100644
--- a/torchtitan/models/llama3/train_configs/llama3_8b.toml
+++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml
@@ -52,7 +52,7 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 
 [compile]
 enable=false
-components = ["model", "loss"]
+components = ["model"]
 
 [activation_checkpoint]
 mode = "selective" # ["none", "selective", "full"]
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 8455e54e..8ccc3950 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -11,6 +11,7 @@ import os
 import time
 from datetime import timedelta
 from typing import Any, Iterable
+import torch.distributed as dist
 
 import torch
 import torch.distributed.checkpoint.stateful
@@ -263,6 +264,13 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
         self.model_parts = [model]
 
+        def unembed_and_loss(pred, loss):
+            # dist.breakpoint()  # TODO
+            return self.loss_fn(model.output(pred), loss)
+
+        self.unembed_and_loss = torch.compile(unembed_and_loss)
+        # self.unembed_and_loss = unembed_and_loss
+
         self.ft_manager.maybe_set_all_reduce_hook(self.model_parts)
 
         # initialize device memory monitor and get peak flops for MFU calculation
@@ -541,7 +549,9 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
         assert len(model_parts) == 1
         with self.maybe_enable_amp:
             pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
-            loss = self.loss_fn(pred, labels)
+            # dist.breakpoint()
+            # loss = self.loss_fn(pred, labels)
+            loss = self.unembed_and_loss(pred, labels)
             # need to free pred before bwd to avoid peaking memory
             del pred
             loss.backward()
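For reference, a minimal standalone sketch of the pattern the train.py hunk implements, assuming a plain nn.Linear unembedding and a cross-entropy-style loss; the toy dimensions and names below are illustrative, not torchtitan's actual code. Compiling the unembedding together with the loss gives Inductor a chance to fuse the final matmul with the cross entropy, rather than materializing the full [batch, seq, vocab] logits tensor in eager mode:

import torch
import torch.nn as nn
import torch.nn.functional as F

dim, vocab_size = 256, 1024  # toy sizes for illustration

# Stands in for model.output, the unembedding that the model.py hunk skips.
output = nn.Linear(dim, vocab_size, bias=False)

def unembed_and_loss(pred, labels):
    # pred holds hidden states [batch, seq, dim]; the projection to
    # [batch, seq, vocab] now happens inside the compiled region.
    logits = output(pred)
    return F.cross_entropy(logits.flatten(0, 1).float(), labels.flatten(0, 1))

compiled_loss = torch.compile(unembed_and_loss)

pred = torch.randn(2, 16, dim, requires_grad=True)  # what the model now returns
labels = torch.randint(0, vocab_size, (2, 16))
loss = compiled_loss(pred, labels)
loss.backward()  # in the trainer, pred is deleted first to cap peak memory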