'Current API'
from transformers import AutoModelForCausalLM
import torch.distributed as dist

init_dist()  # proposed helper to set up the process group
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
with ct_offload():  # <- context manager to wrap from_pretrained
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype="auto", device_map="cpu")

# <load data and quantize model>

with ct_offload():
    SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-awq-asym"
    model.save_pretrained(SAVE_DIR, save_compressed=True)
dist.destroy_process_group()
'''
Question: What is the optimal API for data sampling?
Problem: how can we shard the data without requiring the user to add code in yet
more places, and without first loading N copies of the whole dataset on CPU?
'''
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512
'Option 1 - explicit'
split = get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
ds = load_dataset(DATASET_ID, split=split)
'Pro: clear'
'Con: users will miss this and make it our problem'
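'''
A minimal sketch of what get_rank_partition could look like; the name and
signature come from the snippet above, the body is an assumption. It builds a
per-rank HF datasets slice string so each rank only loads its own shard.
'''
import torch.distributed as dist

def get_rank_partition(split: str, num_samples: int) -> str:
    """Return an HF slice string (e.g. "train_sft[64:128]") covering this
    rank's contiguous shard of the first num_samples examples."""
    rank = dist.get_rank() if dist.is_initialized() else 0
    world = dist.get_world_size() if dist.is_initialized() else 1
    per_rank = num_samples // world
    start = rank * per_rank
    # last rank absorbs the remainder so all num_samples are covered
    end = num_samples if rank == world - 1 else start + per_rank
    return f"{split}[{start}:{end}]"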
'Option 2 - patch load_dataset'
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
'How do we patch it?'
'Option 2a - context manager to patch load_dataset'
with ct_offload():
    ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
"Pro: it's what we're doing for model save/load"
'Con: the user has to put the same (or different) context managers in 3 places; also hacky'
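'''
A sketch of the 2a patch, assuming the manager rewrites the split string per
rank; the manager and helper names are hypothetical. It also shows one source
of the brittleness: patching datasets.load_dataset does not affect a name the
user already bound via `from datasets import load_dataset`.
'''
import contextlib
import re
import datasets
import torch.distributed as dist

def _reslice_for_rank(split: str) -> str:
    """Rewrite e.g. "train_sft[:256]" into this rank's sub-slice."""
    m = re.fullmatch(r"(\w+)\[:(\d+)\]", split)
    if m is None:
        return split  # leave unrecognized split strings alone
    name, total = m.group(1), int(m.group(2))
    rank, world = dist.get_rank(), dist.get_world_size()
    per_rank = total // world
    start = rank * per_rank
    end = total if rank == world - 1 else start + per_rank
    return f"{name}[{start}:{end}]"

@contextlib.contextmanager
def patched_load_dataset():
    """Temporarily replace datasets.load_dataset with a rank-sharding wrapper."""
    original = datasets.load_dataset

    def wrapper(path, *args, split=None, **kwargs):
        if split is not None and dist.is_initialized():
            split = _reslice_for_rank(split)
        return original(path, *args, split=split, **kwargs)

    datasets.load_dataset = wrapper
    try:
        yield
    finally:
        datasets.load_dataset = original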
'Option 2b - we write our own load_dataset that wraps the normal one'
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")  # our wrapper, not datasets.load_dataset
'Pro: works the same for distributed and non-distributed runs'
'Con: users will miss this too'
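'''
A sketch of the 2b wrapper, reusing the hypothetical _reslice_for_rank helper
from the 2a sketch. It keeps the same call signature as datasets.load_dataset,
so user scripts look identical with or without distributed; the failure mode
is users importing HF's load_dataset directly instead of ours.
'''
import torch.distributed as dist
from datasets import load_dataset as hf_load_dataset

def load_dataset(path, *args, split=None, **kwargs):
    """Drop-in replacement: shard the split per rank when dist is initialized,
    otherwise behave exactly like datasets.load_dataset."""
    if split is not None and dist.is_initialized():
        split = _reslice_for_rank(split)  # helper from the 2a sketch
    return hf_load_dataset(path, *args, split=split, **kwargs)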
'Option 2c - single context manager for load/save/load_dataset'
'Pro: a little cleaner'
'Con: still hacky and maybe brittle'
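'''
2c just stacks the pieces: one manager (name hypothetical) that applies the
model load/save patch and the load_dataset patch together, so the user wraps
the script body once instead of annotating three call sites.
'''
import contextlib

@contextlib.contextmanager
def ct_distributed():
    """Combine the proposed ct_offload patch with the 2a load_dataset patch."""
    with ct_offload(), patched_load_dataset():
        yield

# usage: a single wrapper around the whole flow
# with ct_distributed():
#     model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype="auto", device_map="cpu")
#     ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")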
'Option 3 - n loads and handle it in oneshot'
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
'Pro: clean'
'Con: each rank first loads the full calibration set on CPU, so OOM/slow'
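'''
A sketch of what "handle it in oneshot" could mean, using datasets' built-in
shard(); the function name and placement are assumptions. The con above is
exactly this pattern: every rank materializes all NUM_CALIBRATION_SAMPLES rows
before discarding most of them.
'''
import torch.distributed as dist

def _shard_calibration_ds(ds):
    """Inside oneshot: keep only this rank's contiguous shard of the
    calibration set that every rank has already fully loaded."""
    if not dist.is_initialized():
        return ds
    return ds.shard(num_shards=dist.get_world_size(), index=dist.get_rank(), contiguous=True)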