@HDCharles
Created February 9, 2026 20:31
from contextlib import contextmanager
import torch
from compressed_tensors.offload import offload_model
from compressed_tensors.offload.dispatch import remove_dispatch
from loguru import logger
import torch.distributed as dist
import inspect
import os
#### THIS STUFF WILL GO IN CT (compressed_tensors)
def is_ddp() -> bool:
    return torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1
def init_dist():
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        rank=rank,
        world_size=world_size,
        device_id=device,
    )
    dist.barrier()
def convert_to_ct_offload(model, case, local_rank):
    if case == "cuda":
        # model is already fully on the GPU; just drop accelerate's dispatch hooks
        remove_dispatch(model)
    if case == "cpu":
        if any(hasattr(m, '_hf_hook') for m in model.modules()):
            raise NotImplementedError(
                "Detected that the model didn't fit entirely in RAM and accelerate used "
                "disk offloading; CT distributed disk offloading is needed to support this"
            )
        remove_dispatch(model)
        onload_device = torch.device(f"cuda:{local_rank}")
        offload_device = torch.device("cpu")
        offload_model(model, onload_device, offload_device)
    return model
def patch_from_pretrained(cls, local_rank):
    cls.from_pretrained_orig = cls.from_pretrained

    def patched_from_pretrained(*args, **kwargs):
        ### OVERWRITE DEVICE MAP TO HANDLE DISTRIBUTED CASE
        device_map = kwargs.get("device_map")
        case = None
        if device_map == "cuda":
            kwargs["device_map"] = local_rank
            case = "cuda"
        elif device_map == "cpu":
            # we only want to load into cpu once
            kwargs["device_map"] = "cpu" if local_rank == 0 else "meta"
            case = "cpu"
        elif device_map is None:
            logger.warning("No device_map given to from_pretrained, defaulting to cpu")
            kwargs["device_map"] = "cpu" if local_rank == 0 else "meta"
            case = "cpu"
        elif device_map == "disk":
            raise NotImplementedError(f"device_map == {device_map} is not implemented, use cpu or cuda")
        elif device_map == "auto":
            raise NotImplementedError(f"device_map == {device_map} is not implemented, use cpu or cuda")
        else:
            raise NotImplementedError(f"device_map == {device_map} is not implemented, use cpu or cuda")

        ### LOAD WITH ACCELERATE + CORRECTED DEVICE MAP
        model = cls.from_pretrained_orig(*args, **kwargs)

        ### CONVERT FROM ACCELERATE TO OUR OFFLOADING TOOL
        model = convert_to_ct_offload(model, case, local_rank)

        ### PATCH SAVE_PRETRAINED SO IT WILL WORK WITH CT OFFLOAD
        # model = patch_save_pretrained(model)
        return model

    cls.from_pretrained = patched_from_pretrained
    return cls
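# Hedged sketch (not part of the original gist): one possible shape for the
# `patch_save_pretrained` helper referenced above. It is an assumption that
# rank-0-only serialization plus a barrier is all that is needed once the model
# is CPU-offloaded with CT; the real helper may need to onload weights first or
# use a dedicated CT saving utility.
def patch_save_pretrained(model):
    save_pretrained_orig = model.save_pretrained

    def patched_save_pretrained(*args, **kwargs):
        # only rank 0 writes to disk; the other ranks wait so nobody races ahead
        if dist.get_rank() == 0:
            save_pretrained_orig(*args, **kwargs)
        dist.barrier()

    model.save_pretrained = patched_save_pretrained
    return model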
@contextmanager
def ct_offload():
    if not is_ddp():
        init_dist()

    ### Find the correct caller frame with the imports to patch
    frame = inspect.currentframe()
    while frame:
        # Skip frames from the contextlib module
        if 'contextlib' not in frame.f_code.co_filename:
            caller_globals = frame.f_globals
            break
        frame = frame.f_back
    else:
        raise RuntimeError("Could not find caller frame")

    local_rank = dist.get_rank()

    # wrap from_pretrained to swap accelerate offloading for CT offloading
    # wrap save_pretrained to work with CT offloading
    patched = []
    for _, load_cls in caller_globals.items():
        if (
            hasattr(load_cls, 'from_pretrained') and
            hasattr(load_cls, '__module__') and
            'transformers' in load_cls.__module__
        ):
            patched.append(load_cls)
            patch_from_pretrained(load_cls, local_rank)

    try:
        yield
    finally:
        ### CLEANUP: restore the original from_pretrained on exit #####
        for load_cls in patched:
            load_cls.from_pretrained = load_cls.from_pretrained_orig
            del load_cls.from_pretrained_orig
### START OF TEST
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.utils import dispatch_for_generation
from llmcompressor.datasets import get_rank_partition
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
with ct_offload():  # <- context manager to wrap from_pretrained
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype="auto", device_map="cpu")
# tokenizer is loaded outside the patch; the wrapper only handles nn.Module models
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512
#### CHANGE
ds = load_dataset(
    DATASET_ID, split=get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
)
# ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
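# Hedged sketch (illustration only, not the llmcompressor implementation): the
# rank partition above could in principle be expressed with the HF datasets
# slicing syntax, giving each rank a disjoint contiguous slice of the
# calibration samples. `example_rank_partition` is a hypothetical helper.
def example_rank_partition(split, num_samples, rank=None, world_size=None):
    rank = dist.get_rank() if rank is None else rank
    world_size = dist.get_world_size() if world_size is None else world_size
    per_rank = num_samples // world_size
    start = rank * per_rank
    return f"{split}[{start}:{start + per_rank}]"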
ds = ds.shuffle(seed=42)
def preprocess(example):
    return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}

ds = ds.map(preprocess)
# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])
recipe = None  # NOTE: the GPTQ recipe is overridden with None for this timing run
import time
start = time.time()
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    pipeline="sequential",
)
print(f"\nPipeline took {time.time()-start} seconds, rank={dist.get_rank()}")
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
print(f"Peak GPU Memory: {peak_memory_gb:.2f} GB, rank={dist.get_rank()}\n")
dist.destroy_process_group()
# torchrun --nproc_per_node=1 test_ddp.py 2>&1 | tee run_1.log
# # Pipeline took 55.957462310791016 seconds, rank=0
# # Peak GPU Memory: 1.45 GB, rank=0
# torchrun --nproc_per_node=2 test_ddp.py 2>&1 | tee run_2.log
# # Pipeline took 39.69138813018799 seconds, rank=0
# # Peak GPU Memory: 1.45 GB, rank=0
# # Pipeline took 40.09099316596985 seconds, rank=1
# # Peak GPU Memory: 1.45 GB, rank=1
# torchrun --nproc_per_node=3 test_ddp.py 2>&1 | tee run_3.log
# # Pipeline took 32.77818274497986 seconds, rank=0
# # Peak GPU Memory: 1.45 GB, rank=0
# # Pipeline took 33.59840416908264 seconds, rank=1
# # Peak GPU Memory: 1.45 GB, rank=1
# # Pipeline took 33.93803858757019 seconds, rank=2
# # Peak GPU Memory: 1.45 GB, rank=2
# torchrun --nproc_per_node=4 test_ddp.py 2>&1 | tee run_4.log
# # Pipeline took 29.62357449531555 seconds, rank=0
# # Peak GPU Memory: 1.45 GB, rank=0
# # Pipeline took 30.166877269744873 seconds, rank=3
# # Peak GPU Memory: 1.45 GB, rank=3
# # Pipeline took 30.52944254875183 seconds, rank=1
# # Peak GPU Memory: 1.45 GB, rank=1
# # Pipeline took 30.486549377441406 seconds, rank=2
# # Peak GPU Memory: 1.45 GB, rank=2
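# Summary of the runs above: going from 1 to 4 ranks cuts wall-clock time from
# ~56 s to ~30 s, while peak GPU memory stays at 1.45 GB on every rank.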