{
"metadata": {
"bento_stylesheets": {
"bento/extensions/flow/main.css": true,
"bento/extensions/kernel_selector/main.css": true,
"bento/extensions/kernel_ui/main.css": true,
"bento/extensions/new_kernel/main.css": true,
"bento/extensions/system_usage/main.css": true,
"bento/extensions/theme/main.css": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"last_server_session_id": "6c60c0db-2032-4b20-8801-9a457fba4232",
"last_kernel_id": "6fb908c1-dfac-434f-bfd3-620af45527bb",
"last_base_url": "https://bento.edge.x2p.facebook.net/",
"last_msg_id": "f2ae425f-627cb2a601d37971fbc777b7_3319",
"colab": {
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 0,
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/basilwong/d9485c53a3778ff495472292acd3d1c5/mosaic-demo_-understanding-memory-differences-with-activation-checkpointing.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"bentoCellName": {
"name": "Introduction",
"origin": "ai"
},
"showInput": false,
"originalKey": "8b7f98ef-d9d6-45e7-ae5a-d755773fc0d4",
"customInput": null,
"bentoAICellStatus": {
"status": "pending_user_action",
"type": "CELL_ADDITION"
},
"language": "markdown",
"id": "qpCIRBY6g-bf"
},
"source": [
"# Mosaic Demo: Understanding Memory Differences with Activation Checkpointing\n",
"\n",
"This notebook demonstrates how to use [Mosaic](https://github.com/facebookresearch/mosaic) to analyze and compare GPU memory usage between different model configurations.\n",
"\n",
"**What we'll do:**\n",
"1. Train GPT-2 and capture a memory snapshot (baseline)\n",
"2. Enable activation checkpointing and train again (modified)\n",
"3. Use Mosaic to identify exactly where memory savings occur"
]
},
{
"cell_type": "code",
"metadata": {
"bentoCellName": {
"origin": "ai",
"name": "Install Dependencies"
},
"originalKey": "ea931663-4c71-48a8-9439-d3f63635b301",
"showInput": true,
"customInput": null,
"bentoAICellStatus": {
"status": "pending_user_action",
"type": "CELL_ADDITION"
},
"language": "python",
"executionStartTime": 1767773113172,
"executionStopTime": 1767773114772,
"serverExecutionDuration": 530.33932700055,
"requestMsgId": "bc8cd5d1-1bab-4ef1-8e27-017d5f90c89f",
"outputsInitialized": true,
"output": {
"id": "1434577914764349",
"output_revision_id": "737613109387922"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aSLqIX_Rg-bg",
"outputId": "fdbca7df-4012-4a82-eae3-89b258666ab3"
},
"source": [
"!pip install -q transformers torch\n",
"!pip install -q git+https://github.com/facebookresearch/mosaic.git\n",
"\n",
"import torch\n",
"print(f\"PyTorch version: {torch.__version__}\")\n",
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
"if torch.cuda.is_available():\n",
" print(f\"GPU: {torch.cuda.get_device_name(0)}\")"
],
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"PyTorch version: 2.9.0+cu126\n",
"CUDA available: True\n",
"GPU: Tesla T4\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"bentoCellName": {
"name": "Helper Functions",
"origin": "ai"
},
"originalKey": "cf8fe7be-1ab3-4591-a14b-597c3655018f",
"showInput": true,
"customInput": null,
"bentoAICellStatus": {
"status": "pending_user_action",
"type": "CELL_ADDITION"
},
"language": "python",
"executionStartTime": 1767773274369,
"executionStopTime": 1767773275875,
"serverExecutionDuration": 964.08266800063,
"requestMsgId": "38217f9b-5c10-4c5a-bd2f-f1accac70513",
"outputsInitialized": true,
"id": "FTgyVocLg-bh"
},
"source": [
"import torch\n",
"import pickle\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from transformers import GPT2LMHeadModel, GPT2Tokenizer\n",
"from contextlib import contextmanager\n",
"\n",
"class RandomTokenDataset(Dataset):\n",
" \"\"\"Generates random token sequences for training.\"\"\"\n",
" def __init__(self, vocab_size, seq_length=512, num_samples=100):\n",
" self.vocab_size = vocab_size\n",
" self.seq_length = seq_length\n",
" self.num_samples = num_samples\n",
"\n",
" def __len__(self):\n",
" return self.num_samples\n",
"\n",
" def __getitem__(self, idx):\n",
" input_ids = torch.randint(0, self.vocab_size, (self.seq_length,))\n",
" return {\"input_ids\": input_ids, \"labels\": input_ids.clone()}\n",
"\n",
"@contextmanager\n",
"def capture_memory_snapshot(output_path):\n",
" \"\"\"Context manager to capture and save PyTorch memory snapshot.\"\"\"\n",
" torch.cuda.memory._record_memory_history(max_entries=100000)\n",
" try:\n",
" yield\n",
" finally:\n",
" snapshot = torch.cuda.memory._snapshot()\n",
" torch.cuda.memory._record_memory_history(enabled=None)\n",
" with open(output_path, \"wb\") as f:\n",
" pickle.dump(snapshot, f)\n",
" print(f\"✓ Memory snapshot saved to {output_path}\")\n",
"\n",
"def run_training(\n",
" activation_checkpointing: bool,\n",
" snapshot_path: str,\n",
" batch_size: int = 4,\n",
" seq_length: int = 512,\n",
" num_steps: int = 5\n",
"):\n",
" \"\"\"Run training loop and capture memory snapshot.\"\"\"\n",
"\n",
" # Clear any previous memory\n",
" torch.cuda.empty_cache()\n",
" torch.cuda.reset_peak_memory_stats()\n",
"\n",
" device = torch.device(\"cuda\")\n",
"\n",
" # Load model\n",
" print(f\"Loading GPT-2 (activation_checkpointing={activation_checkpointing})...\")\n",
" model = GPT2LMHeadModel.from_pretrained(\"gpt2\")\n",
"\n",
" if activation_checkpointing:\n",
" model.gradient_checkpointing_enable()\n",
" print(\"✓ Activation checkpointing ENABLED\")\n",
" else:\n",
" print(\"✗ Activation checkpointing DISABLED\")\n",
"\n",
" model = model.to(device)\n",
" model.train()\n",
"\n",
" # Create dataset and dataloader\n",
" tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
" dataset = RandomTokenDataset(\n",
" vocab_size=tokenizer.vocab_size,\n",
" seq_length=seq_length,\n",
" num_samples=100\n",
" )\n",
" dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
"\n",
" # Setup optimizer\n",
" optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)\n",
"\n",
" # Training loop with memory capture\n",
" print(f\"Running {num_steps} training steps...\")\n",
"\n",
" with capture_memory_snapshot(snapshot_path):\n",
" for step, batch in enumerate(dataloader):\n",
" if step >= num_steps:\n",
" break\n",
"\n",
" batch = {k: v.to(device) for k, v in batch.items()}\n",
"\n",
" optimizer.zero_grad()\n",
" outputs = model(input_ids=batch[\"input_ids\"], labels=batch[\"labels\"])\n",
" loss = outputs.loss\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" print(f\" Step {step + 1}/{num_steps}, Loss: {loss.item():.4f}\")\n",
"\n",
" peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)\n",
" print(f\"✓ Peak GPU memory: {peak_memory_gb:.2f} GB\")\n",
"\n",
" # Cleanup\n",
" del model, optimizer\n",
" torch.cuda.empty_cache()\n",
"\n",
" return peak_memory_gb"
],
"execution_count": 19,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"bentoCellName": {
"origin": "ai",
"name": "Run Baseline"
},
"originalKey": "a7c60a32-55aa-46bd-aff4-8b011865805a",
"showInput": true,
"customInput": null,
"bentoAICellStatus": {
"status": "pending_user_action",
"type": "CELL_ADDITION"
},
"language": "python",
"executionStartTime": 1767773286636,
"executionStopTime": 1767773288645,
"serverExecutionDuration": 1762.8840569996,
"collapsed": true,
"requestMsgId": "fcf866eb-10a7-4991-b347-e1ca4dc77871",
"output": {
"id": "1584824005867811",
"output_revision_id": "1359148835530512"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "szQMlaAKg-bh",
"outputId": "b17b35f3-a8ee-449f-adac-a7084995e155"
},
"source": [
"print(\"=\" * 60)\n",
"print(\"BASELINE: Training WITHOUT Activation Checkpointing\")\n",
"print(\"=\" * 60)\n",
"\n",
"baseline_memory = run_training(\n",
" activation_checkpointing=False,\n",
" snapshot_path=\"snapshot_baseline.pickle\",\n",
" batch_size=4,\n",
" seq_length=512,\n",
" num_steps=5\n",
")"
],
"execution_count": 20,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"============================================================\n",
"BASELINE: Training WITHOUT Activation Checkpointing\n",
"============================================================\n",
"Loading GPT-2 (activation_checkpointing=False)...\n",
"✗ Activation checkpointing DISABLED\n",
"Running 5 training steps...\n",
" Step 1/5, Loss: 12.2235\n",
" Step 2/5, Loss: 12.2053\n",
" Step 3/5, Loss: 11.9615\n",
" Step 4/5, Loss: 11.8397\n",
" Step 5/5, Loss: 11.7291\n",
"✓ Memory snapshot saved to snapshot_baseline.pickle\n",
"✓ Peak GPU memory: 5.13 GB\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"bentoCellName": {
"name": "Run With GC",
"origin": "ai"
},
"originalKey": "f535b3b4-def3-45b2-8996-d71e58558429",
"showInput": true,
"customInput": null,
"bentoAICellStatus": {
"status": "pending_user_action",
"type": "CELL_ADDITION"
},
"language": "python",
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1x_A-eeEg-bh",
"outputId": "3ef3be5c-ecf0-419d-bce8-3783d1512e17"
},
"source": [
"print(\"=\" * 60)\n",
"print(\"MODIFIED: Training WITH Activation Checkpointing\")\n",
"print(\"=\" * 60)\n",
"\n",
"ac_memory = run_training(\n",
" activation_checkpointing=True,\n",
" snapshot_path=\"snapshot_with_ac.pickle\",\n",
" batch_size=4,\n",
" seq_length=512,\n",
" num_steps=5\n",
")\n",
"\n",
"# Summary\n",
"print(\"\\n\" + \"=\" * 60)\n",
"print(\"MEMORY COMPARISON SUMMARY\")\n",
"print(\"=\" * 60)\n",
"print(f\"Baseline (no GC): {baseline_memory:.2f} GB\")\n",
"print(f\"With GC: {ac_memory:.2f} GB\")\n",
"print(f\"Memory Saved: {baseline_memory - ac_memory:.2f} GB ({100 * (baseline_memory - ac_memory) / baseline_memory:.1f}%)\")"
],
"execution_count": 21,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"============================================================\n",
"MODIFIED: Training WITH Activation Checkpointing\n",
"============================================================\n",
"Loading GPT-2 (activation_checkpointing=True)...\n",
"✓ Activation checkpointing ENABLED\n",
"Running 5 training steps...\n",
" Step 1/5, Loss: 12.3188\n",
" Step 2/5, Loss: 12.0237\n",
" Step 3/5, Loss: 11.9146\n",
" Step 4/5, Loss: 11.7995\n",
" Step 5/5, Loss: 11.7340\n",
"✓ Memory snapshot saved to snapshot_with_ac.pickle\n",
"✓ Peak GPU memory: 3.04 GB\n",
"\n",
"============================================================\n",
"MEMORY COMPARISON SUMMARY\n",
"============================================================\n",
"Baseline (no GC): 5.13 GB\n",
"With GC: 3.04 GB\n",
"Memory Saved: 2.08 GB (40.6%)\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Categorical Comparison"
],
"metadata": {
"id": "AVo9cl5hnPxz"
}
},
{
"cell_type": "code",
"metadata": {
"bentoCellName": {
"origin": "ai",
"name": "Mosaic Profiling"
},
"originalKey": "8520d5db-8e21-4298-a557-ab73c0ca2930",
"showInput": true,
"customInput": null,
"bentoAICellStatus": {
"status": "pending_user_action",
"type": "CELL_ADDITION"
},
"language": "python",
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "yGtpnhrHg-bh",
"outputId": "462598f5-fcc1-424d-8f6f-9637e513093d"
},
"source": [
"print(\"=\" * 60)\n",
"print(\"MOSAIC: Categorical Memory Profiling\")\n",
"print(\"=\" * 60)\n",
"\n",
"# Generate HTML profiles\n",
"!mosaic_get_memory_profile --snapshot snapshot_baseline.pickle --out-path profile_baseline.html --profile categories --preserve-allocation-order --plotter_sampling_rate 20\n",
"\n",
"print(\"\")\n",
"\n",
"!mosaic_get_memory_profile --snapshot snapshot_with_ac.pickle --out-path profile_with_gc.html --profile categories --preserve-allocation-order --plotter_sampling_rate 20\n",
"\n",
"print(\"\\n✓ Generated profile_baseline.html\")\n",
"print(\"✓ Generated profile_with_gc.html\")\n",
"print(\"\\nDownload these files to view the interactive memory profiles.\")"
],
"execution_count": 22,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"============================================================\n",
"MOSAIC: Categorical Memory Profiling\n",
"============================================================\n",
"INFO:root:Loading snapshot snapshot_baseline.pickle using io read\n",
"INFO:root:Loading snapshot snapshot_baseline.pickle, size 3.20MB ...\n",
"INFO:root:Snapshot loaded successfully.\n",
"Memory Usage At Peak:\n",
"Total Allocated: 4.62GiB\n",
"Category Profile:\n",
"AllocationType.UNKNOWN: 32.0KB\n",
"AllocationType.ACTIVATION: 2.93GiB\n",
"AllocationType.BACKWARD: 785.27MB\n",
"AllocationType.OPTIMIZER: 949.4MB\n",
"Annotation Profile:\n",
"Compile Context Profile:\n",
"Custom Profile:\n",
"INFO:root:Profiling function took: 1.4613242149353027 seconds to run\n",
"\n",
"INFO:root:Loading snapshot snapshot_with_ac.pickle using io read\n",
"INFO:root:Loading snapshot snapshot_with_ac.pickle, size 5.27MB ...\n",
"INFO:root:Snapshot loaded successfully.\n",
"Memory Usage At Peak:\n",
"Total Allocated: 2.55GiB\n",
"Category Profile:\n",
"AllocationType.UNKNOWN: 32.0KB\n",
"AllocationType.ACTIVATION: 871.79MB\n",
"AllocationType.BACKWARD: 785.27MB\n",
"AllocationType.OPTIMIZER: 949.4MB\n",
"Annotation Profile:\n",
"Compile Context Profile:\n",
"Custom Profile:\n",
"INFO:root:Profiling function took: 1.958526372909546 seconds to run\n",
"\n",
"✓ Generated profile_baseline.html\n",
"✓ Generated profile_with_gc.html\n",
"\n",
"Download these files to view the interactive memory profiles.\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"bentoCellName": {
"name": "Download Files",
"origin": "ai"
},
"originalKey": "4d107217-82d1-49c5-b36d-8f03db30a4e8",
"showInput": true,
"customInput": null,
"bentoAICellStatus": {
"status": "pending_user_action",
"type": "CELL_ADDITION"
},
"language": "python",
"id": "gA-eUPrZg-bh"
},
"source": [
"# Download the generated files (for Google Colab)\n",
"from google.colab import files\n",
"\n",
"print(\"Downloading memory snapshots and profiles...\")\n",
"files.download('snapshot_baseline.pickle')\n",
"files.download('snapshot_with_gc.pickle')\n",
"files.download('profile_baseline.html')\n",
"files.download('profile_with_gc.html')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"bentoCellName": {
"origin": "ai",
"name": "Interpretation"
},
"showInput": false,
"originalKey": "55cf1c6b-6908-4968-aaaf-08c898bfa9f4",
"customInput": null,
"bentoAICellStatus": {
"status": "pending_user_action",
"type": "CELL_ADDITION"
},
"language": "markdown",
"id": "qO2BsJvvg-bi"
},
"source": [
"## Results Interpretation\n",
"\n",
"### What We Observed\n",
"\n",
"Based on the Mosaic categorical profiling results:\n",
"\n",
"| Metric | Baseline | With Activation Checkpointing | Difference |\n",
"|--------|----------|----------------------------|------------|\n",
"| **Total Peak Memory** | **4.62 GB** | **2.55 GB** | **2.07 GB (45% reduction)** |\n",
"| Activation Memory | 2.93 GB | 872.79 MB | **2.08 GB saved (71% reduction)** |\n",
"| Backward/Gradient Memory | 793.39 MB | 785.27 MB | 8 MB (minimal change) |\n",
"| Optimizer State | 949.4 MB | 949.4 MB | No change |\n",
"| Unknown | 32 KB | 32 KB | No change |\n",
"\n",
"### Key Insights\n",
"**Primary Finding:** Activation memory dropped from **2.93 GB → 872 MB** (71% reduction), which accounts for nearly all the total memory savings.\n",
"\n",
"### Why Does This Happen?\n",
"\n",
"**Activation checkpointing** is a memory optimization technique that:\n",
"\n",
"1. **Without GC (Baseline):** All intermediate activations from the forward pass are stored in memory for use during backpropagation\n",
" - GPT-2 has 12 transformer layers\n",
" - Each layer stores multiple activations (attention outputs, MLP outputs, etc.)\n",
" - For batch_size=4, seq_length=512, this adds up quickly\n",
"\n",
"2. **With GC (Optimized):** Only activations at checkpoint boundaries are stored; intermediate activations are recomputed during the backward pass\n",
" - Dramatically reduces activation memory (71% in our case)\n",
" - Other memory categories remain unchanged\n",
"\n",
"### How Mosaic Helped\n",
"\n",
"Mosaic's categorical profiling immediately identified:\n",
"- Activation memory is the category with the largest difference (2.08 GB saved)\n",
"- Backward/Gradient memory stayed nearly constant (793 MB → 785 MB)\n",
"- Optimizer state remained unchanged (949 MB)\n",
" - expected since model parameters don't change\n",
"\n",
"**Without Mosaic:** You would need to manually instrument your code, track allocations, and categorize them yourself.\n",
"\n",
"**With Mosaic:** You get instant categorical breakdowns with exact numbers, making it trivial to identify/quantify memory optimizations."
]
}
]
}