@abatilo
Last active December 29, 2025 01:28
DeepSpeed + Hugging Face Transformers Training Script for Spin Tutorial
#!/usr/bin/env python3
"""
Hugging Face Transformers Distributed Training Example for Spin
Trains a GPT-2 model on WikiText-2 using distributed data parallel.
Designed to run on multiple nodes via Spin's SyncSet orchestration.
"""
import os

from datasets import load_dataset
from transformers import (
    AutoConfig,
    GPT2LMHeadModel,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

block_size = 128

def main():
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_rank = int(os.environ.get("RANK", 0))

    if world_rank == 0:
        print("Loading dataset...")

    # Load and prepare data
    raw_datasets = load_dataset("wikitext", "wikitext-2-raw-v1")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.model_max_length = block_size
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    # Remove empty strings and tokenize
    raw_datasets = raw_datasets.filter(lambda x: len(x["text"]) > 0)

    def tokenize_function(examples):
        tokenized = tokenizer(examples["text"], padding=True, truncation=True)
        # For causal LM training the labels are the input ids; the model
        # shifts them internally when computing the loss.
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=raw_datasets["train"].column_names,
    )
    train_dataset = tokenized_datasets["train"]
    eval_dataset = tokenized_datasets["validation"]

    if world_rank == 0:
        print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")

    # Randomly initialized GPT-2 with a context length of block_size
    model_config = AutoConfig.from_pretrained("gpt2")
    model_config.n_positions = block_size
    model = GPT2LMHeadModel(model_config)

    if world_rank == 0:
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Parameters: {total_params/10**6:.2f}M")

    training_args = TrainingArguments(
        output_dir="/tmp/model",
        max_steps=1000,
        fp16=True,
        logging_steps=10,
        save_total_limit=3,
        # "auto" values are filled in by the Trainer's DeepSpeed integration
        # from the TrainingArguments above.
        deepspeed={
            "train_batch_size": "auto",
            "train_micro_batch_size_per_gpu": "auto",
            "steps_per_print": "auto",
            "fp16": {
                "enabled": "auto",
                "loss_scale": 0,
                "loss_scale_window": 1000,
                "initial_scale_power": 16,
                "hysteresis": 2,
                "min_loss_scale": 1,
            },
            "optimizer": {
                "type": "AdamW",
                "params": {
                    "lr": "auto",
                    "betas": "auto",
                    "eps": "auto",
                    "weight_decay": "auto",
                },
            },
            "scheduler": {
                "type": "WarmupLR",
                "params": {
                    "warmup_min_lr": "auto",
                    "warmup_max_lr": "auto",
                    "warmup_num_steps": "auto",
                },
            },
            "zero_optimization": {
                # ZeRO stage 3 with optimizer state offloaded to CPU
                "stage": 3,
                "offload_optimizer": {"device": "cpu"},
            },
        },
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
    )
    trainer.train()

    if world_rank == 0:
        print("Training complete!")

if __name__ == "__main__":
    main()
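
For a quick sanity check after training, the checkpoint can be reloaded and sampled from. This is a minimal sketch, assuming the final weights were consolidated into a regular Hugging Face checkpoint under the script's output_dir (for example by calling trainer.save_model() on rank 0 with DeepSpeed's stage3_gather_16bit_weights_on_model_save option enabled); the path below is just the output_dir used above.

# Sanity-check sketch: reload the (assumed consolidated) checkpoint and sample.
from transformers import AutoTokenizer, GPT2LMHeadModel

checkpoint_dir = "/tmp/model"  # output_dir from the training script; adjust as needed
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained(checkpoint_dir)
model.eval()

prompt = "The history of the"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=64,
    do_sample=True,
    top_p=0.95,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

If only the sharded ZeRO-3 checkpoints exist under /tmp/model/checkpoint-*, DeepSpeed normally writes a zero_to_fp32.py helper alongside them that can consolidate the shards into a single state dict first.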