import torch.distributed as dist
from datasets import load_dataset
from transformers import AutoModelForCausalLM

'Current API'
init_dist()
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
with ct_offload():  # <- context manager to wrap from_pretrained
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype="auto", device_map="cpu")
# <load data and quantize model>
with ct_offload():  # <- context manager to wrap save_pretrained
    SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-awq-asym"
    model.save_pretrained(SAVE_DIR, save_compressed=True)
dist.destroy_process_group()
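# `ct_offload` and `init_dist` are not defined in this gist. A minimal sketch of
# what such a context manager could look like, assuming it only needs to
# temporarily wrap AutoModelForCausalLM.from_pretrained so each rank can
# intercept the call (names and behavior here are hypothetical, not the real API):
from contextlib import contextmanager
from unittest.mock import patch

@contextmanager
def ct_offload_sketch():
    original = AutoModelForCausalLM.from_pretrained.__func__

    def wrapped(cls, *args, **kwargs):
        # rank-aware offload/sharding logic would go here
        return original(cls, *args, **kwargs)

    with patch.object(AutoModelForCausalLM, "from_pretrained", classmethod(wrapped)):
        yield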
'''
Question: What is the optimal API for data sampling?
Problem: How can we shard the data without requiring the user to add code in
yet more places, and without first loading n copies of the whole dataset on CPU?
'''
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512
'Option 1 - explicit'
split = get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
ds = load_dataset(DATASET_ID, split=split)
# Pro: clear
# Con: users will miss this and make it our problem
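# `get_rank_partition` is not defined in the gist. A minimal sketch, assuming it
# maps (split, total sample count) to a per-rank slice string in the datasets
# slicing syntax, so each rank loads only its own shard:
def get_rank_partition_sketch(split: str, num_samples: int) -> str:
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    per_rank = num_samples // world_size
    start = rank * per_rank
    # the last rank picks up any remainder
    end = num_samples if rank == world_size - 1 else start + per_rank
    return f"{split}[{start}:{end}]"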
'Option 2 - patch load_dataset'
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
'How do we patch it?'
'Option 2a - context manager to patch load_dataset'
with ct_offload():
    ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
# Pro: it's what we're already doing for model save/load
# Con: the user has to put the same/different context managers in 3 places; also hacky
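# A sketch of what the patch could do, assuming the context manager rewrites the
# user's `split=` argument into a rank-local slice before calling the real
# datasets.load_dataset (illustrative, not an existing API):
import re

import datasets

@contextmanager
def patched_load_dataset():
    original = datasets.load_dataset

    def rank_aware(path, *args, split=None, **kwargs):
        if isinstance(split, str):
            match = re.fullmatch(r"(\w+)\[:(\d+)\]", split)
            if match:
                split = get_rank_partition_sketch(match.group(1), int(match.group(2)))
        return original(path, *args, split=split, **kwargs)

    # Note: this only redirects calls spelled `datasets.load_dataset(...)`;
    # names already bound via `from datasets import load_dataset` are untouched,
    # which is part of why this approach is brittle.
    with patch.object(datasets, "load_dataset", rank_aware):
        yield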
'Option 2b - we write our own load_dataset that wraps the normal one'
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
# Pro: the call looks the same for distributed and non-distributed runs
# Con: users will miss this too (they have to call our wrapper instead of the datasets one)
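# A sketch of the wrapper, assuming it keeps the datasets.load_dataset signature
# but shards the requested split whenever a process group is active
# (illustrative; not an existing llmcompressor function):
def load_dataset_wrapped(path, *args, split=None, **kwargs):
    if isinstance(split, str) and dist.is_available() and dist.is_initialized():
        match = re.fullmatch(r"(\w+)\[:(\d+)\]", split)
        if match:
            split = get_rank_partition_sketch(match.group(1), int(match.group(2)))
    return datasets.load_dataset(path, *args, split=split, **kwargs)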
'Option 2c - single context manager for load/save/load_dataset'
# Pro: a little cleaner
# Con: still hacky and maybe brittle
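# A sketch of the combined manager, assuming it simply stacks the model-offload
# patch and the dataset patch from the sketches above (hypothetical name):
from contextlib import ExitStack

@contextmanager
def ct_everything():
    with ExitStack() as stack:
        stack.enter_context(ct_offload_sketch())
        stack.enter_context(patched_load_dataset())
        yield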
'Option 3 - n loads and handle it in oneshot'
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
# Pro: clean
# Con: OOM/slow (every rank loads the full calibration set on CPU before oneshot shards it)
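# A sketch of the rank-local selection oneshot could apply internally, assuming
# every rank arrives with the full calibration set (illustrative; not the
# documented llmcompressor oneshot behavior):
def shard_for_rank(ds):
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # keep a contiguous block of rows for this rank
    return ds.shard(num_shards=world_size, index=rank, contiguous=True)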