| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# CWM: An Open-Weights LLM for Research on Code Generation with World Models\n", | |
| "\n", | |
| "**Paper Authors:** FAIR CodeGen team\n", | |
| "\n", | |
| "This notebook provides an educational demonstration of the key computational workflows described in the CWM paper. CWM is a 32B parameter language model designed for code generation with world modeling capabilities - the ability to predict how code will execute before running it.\n", | |
| "\n", | |
| "## Overview\n", | |
| "\n", | |
| "The paper presents several key innovations:\n", | |
| "\n", | |
| "1. **Python Execution Traces**: Training data that captures how code executes line-by-line\n", | |
| "2. **ForagerAgent**: Agentic trajectories of software engineering tasks\n", | |
| "3. **Multi-task RL Training**: Reinforcement learning across coding, math, and SWE tasks\n", | |
| "4. **Executable Repository Images**: Docker-based environments for testing code\n", | |
| "\n", | |
| "**Note on Resource Constraints:**\n", | |
| "This notebook provides educational demonstrations using small-scale examples. Full-scale training of CWM requires:\n", | |
| "- Pre-training: 8T tokens on large GPU clusters\n", | |
| "- Mid-training: 5T tokens with 131k context length\n", | |
| "- RL training: 172B tokens across multiple environments\n", | |
| "\n", | |
| "Our demonstrations focus on understanding the methodology, not replicating the full experiments.\n", | |
| "\n", | |
| "## Notebook Structure\n", | |
| "\n", | |
| "1. Setup and Dependencies\n", | |
| "2. Python Execution Trace Generation\n", | |
| "3. Competitive Programming Environment\n", | |
| "4. Agentic Coding Workflow\n", | |
| "5. Code Understanding Evaluation (CruxEval-style)\n", | |
| "6. Scaling Guidance and Conclusions" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": "## 1. Setup and Dependencies\n\n**Installation Notes:**\n- This notebook works with Python 3.6+ standard library\n- Optional dependencies for visualizations: `numpy`, `pandas`, `matplotlib`\n- If you see import warnings, install with: `pip install numpy pandas matplotlib`\n- All core functionality works without these dependencies" | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Install only essential dependencies (using standard library where possible)\n", | |
| "# Uncomment the line below if running in a new environment without these packages\n", | |
| "# !pip install numpy pandas matplotlib\n", | |
| "\n", | |
| "print(\"Note: This notebook uses minimal dependencies.\")\n", | |
| "print(\"Required: numpy, pandas, matplotlib (usually pre-installed in most environments)\")\n", | |
| "print(\"If you see import errors in the next cell, run: pip install numpy pandas matplotlib\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Import required libraries\n", | |
| "import ast\n", | |
| "import sys\n", | |
| "import json\n", | |
| "import random\n", | |
| "import inspect\n", | |
| "import traceback\n", | |
| "from typing import Dict, List, Any, Tuple\n", | |
| "from dataclasses import dataclass\n", | |
| "from io import StringIO\n", | |
| "from collections import defaultdict\n", | |
| "\n", | |
| "# Try to import optional dependencies\n", | |
| "try:\n", | |
| " import numpy as np\n", | |
| " import pandas as pd\n", | |
| " import matplotlib.pyplot as plt\n", | |
| " HAS_SCIENTIFIC = True\n", | |
| "except ImportError:\n", | |
| " print(\"Warning: numpy/pandas/matplotlib not available. Some visualizations will be skipped.\")\n", | |
| " print(\"Install with: pip install numpy pandas matplotlib\")\n", | |
| " HAS_SCIENTIFIC = False\n", | |
| " # Create dummy objects for compatibility\n", | |
| " class DummyNumpy:\n", | |
| " def random(self): return self\n", | |
| " def seed(self, x): pass\n", | |
| " def mean(self, x): return sum(x) / len(x) if x else 0\n", | |
| " def array(self, x): return x\n", | |
| " np = DummyNumpy()\n", | |
| " pd = None\n", | |
| " plt = None\n", | |
| "\n", | |
| "# Set random seeds for reproducibility\n", | |
| "random.seed(42)\n", | |
| "if HAS_SCIENTIFIC:\n", | |
| " np.random.seed(42)\n", | |
| "\n", | |
| "print(\"✓ All imports successful\")\n", | |
| "print(f\"Python version: {sys.version}\")\n", | |
| "print(f\"Scientific libraries available: {HAS_SCIENTIFIC}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## 2. Python Execution Trace Generation\n", | |
| "\n", | |
| "**Reference:** Section 2.2 of the paper\n", | |
| "\n", | |
| "CWM is trained on Python execution traces that capture how code executes line-by-line, including the values of local variables at each step. This is a key innovation that enables world modeling - the ability to predict program execution.\n", | |
| "\n", | |
| "### How it works:\n", | |
| "- Instrument Python functions to track execution\n", | |
| "- Capture local variable states after each line\n", | |
| "- Format as observation-action pairs\n", | |
| "- Train the model to predict these traces\n", | |
| "\n", | |
| "The paper collected:\n", | |
| "- **120M+ function-level traces** from online Python code\n", | |
| "- **350k CodeContests solution traces** with correct/incorrect balance\n", | |
| "- **Repository-level traces** from 21k+ repositories" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class PythonExecutionTracer:\n", | |
| " \"\"\"Simple execution tracer for Python functions.\n", | |
| " \n", | |
| " This demonstrates the concept of capturing execution traces\n", | |
| " similar to CWM's training data generation (Workflow 9).\n", | |
| " \"\"\"\n", | |
| " \n", | |
| " def __init__(self):\n", | |
| " self.trace = []\n", | |
| " self.current_line = None\n", | |
| " \n", | |
| " def trace_function(self, frame, event, arg):\n", | |
| " \"\"\"Trace function called by sys.settrace.\"\"\"\n", | |
| " if event == 'line':\n", | |
| " # Capture local variables at this line\n", | |
| " local_vars = {k: self._serialize_value(v) \n", | |
| " for k, v in frame.f_locals.items()\n", | |
| " if not k.startswith('_')}\n", | |
| " \n", | |
| " self.trace.append({\n", | |
| " 'line': frame.f_lineno,\n", | |
| " 'locals': local_vars,\n", | |
| " 'code': frame.f_code.co_name\n", | |
| " })\n", | |
| " elif event == 'return':\n", | |
| " # Capture return value\n", | |
| " self.trace.append({\n", | |
| " 'event': 'return',\n", | |
| " 'value': self._serialize_value(arg)\n", | |
| " })\n", | |
| " return self.trace_function\n", | |
| " \n", | |
| " def _serialize_value(self, value):\n", | |
| " \"\"\"Serialize values for display.\"\"\"\n", | |
| " if isinstance(value, (int, float, str, bool, type(None))):\n", | |
| " return value\n", | |
| " elif isinstance(value, (list, tuple)):\n", | |
| " return [self._serialize_value(v) for v in value[:5]] # Limit length\n", | |
| " elif isinstance(value, dict):\n", | |
| " return {k: self._serialize_value(v) for k, v in list(value.items())[:5]}\n", | |
| " else:\n", | |
| " return str(type(value).__name__)\n", | |
| " \n", | |
| " def trace_execution(self, func, *args, **kwargs):\n", | |
| " \"\"\"Execute function with tracing enabled.\"\"\"\n", | |
| " self.trace = []\n", | |
| " sys.settrace(self.trace_function)\n", | |
| " try:\n", | |
| " result = func(*args, **kwargs)\n", | |
| " finally:\n", | |
| " sys.settrace(None)\n", | |
| " return result, self.trace\n", | |
| "\n", | |
| "# Example function to trace\n", | |
| "def fibonacci(n):\n", | |
| " \"\"\"Compute nth Fibonacci number.\"\"\"\n", | |
| " if n <= 1:\n", | |
| " return n\n", | |
| " a, b = 0, 1\n", | |
| " for i in range(2, n + 1):\n", | |
| " a, b = b, a + b\n", | |
| " return b\n", | |
| "\n", | |
| "# Trace execution\n", | |
| "tracer = PythonExecutionTracer()\n", | |
| "result, trace = tracer.trace_execution(fibonacci, 7)\n", | |
| "\n", | |
| "print(f\"Result: {result}\")\n", | |
| "print(f\"\\nExecution trace ({len(trace)} steps):\")\n", | |
| "for i, step in enumerate(trace[:10]): # Show first 10 steps\n", | |
| " print(f\"Step {i}: {json.dumps(step, indent=2)}\")\n", | |
| "if len(trace) > 10:\n", | |
| " print(f\"... ({len(trace) - 10} more steps)\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Generate multiple execution traces with different inputs\n", | |
| "print(\"Generating execution traces for training data...\\n\")\n", | |
| "\n", | |
| "def generate_trace_dataset(func, inputs, max_traces=5):\n", | |
| " \"\"\"Generate multiple execution traces for a function.\n", | |
| " \n", | |
| " In CWM, this is done at scale:\n", | |
| " - 120M+ functions traced\n", | |
| " - Inputs generated via fuzzing and LLM prompting\n", | |
| " - Traces formatted as observation-action pairs\n", | |
| " \"\"\"\n", | |
| " tracer = PythonExecutionTracer()\n", | |
| " traces = []\n", | |
| " \n", | |
| " for inp in inputs[:max_traces]:\n", | |
| " result, trace = tracer.trace_execution(func, inp)\n", | |
| " traces.append({\n", | |
| " 'input': inp,\n", | |
| " 'output': result,\n", | |
| " 'trace_length': len(trace),\n", | |
| " 'trace': trace\n", | |
| " })\n", | |
| " \n", | |
| " return traces\n", | |
| "\n", | |
| "# Example: trace multiple inputs\n", | |
| "test_inputs = [3, 5, 7, 10, 12]\n", | |
| "traces = generate_trace_dataset(fibonacci, test_inputs)\n", | |
| "\n", | |
| "# Display trace statistics\n", | |
| "trace_data = [{'Input': t['input'], 'Output': t['output'], 'Trace Steps': t['trace_length']} \n", | |
| " for t in traces]\n", | |
| "\n", | |
| "if HAS_SCIENTIFIC and pd:\n", | |
| " df = pd.DataFrame(trace_data)\n", | |
| " print(\"Execution Trace Dataset:\")\n", | |
| " print(df)\n", | |
| " print(f\"\\nAverage trace length: {df['Trace Steps'].mean():.1f} steps\")\n", | |
| "else:\n", | |
| " print(\"Execution Trace Dataset:\")\n", | |
| " for row in trace_data:\n", | |
| " print(f\" Input: {row['Input']}, Output: {row['Output']}, Trace Steps: {row['Trace Steps']}\")\n", | |
| " avg_steps = sum(t['trace_length'] for t in traces) / len(traces)\n", | |
| " print(f\"\\nAverage trace length: {avg_steps:.1f} steps\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### CWM Trace Format\n", | |
| "\n", | |
| "The paper uses special tokens to format traces:\n", | |
| "- `<|observation_sep|>`: Separates observations (variable states)\n", | |
| "- `<|action_sep|>`: Separates actions (code lines)\n", | |
| "- `<|return_sep|>`: Marks return values\n", | |
| "\n", | |
| "Let's format our traces in a similar style:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def format_trace_cwm_style(trace_data):\n", | |
| " \"\"\"Format trace in CWM-style with special tokens.\n", | |
| " \n", | |
| " This demonstrates the observation-action format used in CWM training.\n", | |
| " \"\"\"\n", | |
| " formatted = []\n", | |
| " formatted.append(f\"# Input: {trace_data['input']}\")\n", | |
| " \n", | |
| " for step in trace_data['trace']:\n", | |
| " if 'locals' in step:\n", | |
| " # Observation: current variable states\n", | |
| " formatted.append(\"<|observation_sep|>\")\n", | |
| " formatted.append(f\"Line {step['line']}: {json.dumps(step['locals'])}\")\n", | |
| " formatted.append(\"<|action_sep|>\")\n", | |
| " elif step.get('event') == 'return':\n", | |
| " # Return value\n", | |
| " formatted.append(\"<|return_sep|>\")\n", | |
| " formatted.append(f\"Return: {step['value']}\")\n", | |
| " \n", | |
| " return '\\n'.join(formatted)\n", | |
| "\n", | |
| "# Example formatted trace\n", | |
| "example_trace = traces[2] # n=7\n", | |
| "formatted = format_trace_cwm_style(example_trace)\n", | |
| "print(\"CWM-style Formatted Trace:\")\n", | |
| "print(formatted[:500] + \"\\n...\" if len(formatted) > 500 else formatted)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Scaling to Production\n", | |
| "\n", | |
| "**How CWM scales this up (from Section 2.2):**\n", | |
| "\n", | |
| "1. **Function-level tracing (120M+ functions)**:\n", | |
| " - Collect Python functions from online sources\n", | |
| " - Generate inputs via fuzzing + Llama-3-70B-Instruct prompting\n", | |
| " - Trace execution with full variable states\n", | |
| "\n", | |
| "2. **Repository-level tracing (21k+ repos)**:\n", | |
| " - Build executable Docker images with RepoAgent/Activ\n", | |
| " - Trace unit test execution\n", | |
| " - Extract function-level traces with configurable stack depth\n", | |
| "\n", | |
| "3. **CodeContests solution tracing (350k)**:\n", | |
| " - Generate solutions with Llama-3.1-70B-Instruct \n", | |
| " - Balance correct/incorrect submissions\n", | |
| " - Trace execution on test cases\n", | |
| "\n", | |
| "Total mid-training data: **5T tokens** including traces and ForagerAgent trajectories" | |
| ] | |
| }, | |
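| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The paper pairs fuzzing with LLM prompting to generate trace inputs. The cell below sketches only the fuzzing half on our toy `fibonacci` function, reusing `generate_trace_dataset` from above. The input range and count are arbitrary demo choices, not values from the paper." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Minimal input-fuzzing sketch (demo only; the paper combines fuzzing with\n", | |
| "# LLM-generated inputs and runs at far larger scale, with sandboxing).\n", | |
| "def fuzz_inputs(num_inputs=8, low=0, high=15):\n", | |
| " \"\"\"Draw random integer inputs; the range is an arbitrary demo choice.\"\"\"\n", | |
| " return [random.randint(low, high) for _ in range(num_inputs)]\n", | |
| "\n", | |
| "fuzzed = fuzz_inputs()\n", | |
| "fuzz_traces = generate_trace_dataset(fibonacci, fuzzed, max_traces=8)\n", | |
| "for t in fuzz_traces:\n", | |
| " print(f\"Input: {t['input']:2d} -> Output: {t['output']:5d} ({t['trace_length']} trace steps)\")" | |
| ] | |
| }, | |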
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## 3. Competitive Programming RL Environment\n", | |
| "\n", | |
| "**Reference:** Section 5.3.2 (Workflow 4)\n", | |
| "\n", | |
| "CWM uses reinforcement learning on competitive programming tasks with verifiable rewards from unit tests. This section demonstrates the workflow:\n", | |
| "\n", | |
| "1. Load problem with test cases\n", | |
| "2. Generate code solution\n", | |
| "3. Execute and verify against tests\n", | |
| "4. Compute reward signal\n", | |
| "5. Update policy via GRPO (Group Relative Policy Optimization)\n", | |
| "\n", | |
| "We'll implement a simplified version that shows the key concepts." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "@dataclass\n", | |
| "class CompetitiveProgrammingProblem:\n", | |
| " \"\"\"Represents a competitive programming problem.\"\"\"\n", | |
| " name: str\n", | |
| " description: str\n", | |
| " test_cases: List[Tuple[Any, Any]] # (input, expected_output)\n", | |
| " \n", | |
| "# Create example problems\n", | |
| "problems = [\n", | |
| " CompetitiveProgrammingProblem(\n", | |
| " name=\"Sum of Two Numbers\",\n", | |
| " description=\"Given two integers a and b, return their sum.\",\n", | |
| " test_cases=[\n", | |
| " ((2, 3), 5),\n", | |
| " ((0, 0), 0),\n", | |
| " ((-1, 1), 0),\n", | |
| " ((100, 200), 300),\n", | |
| " ]\n", | |
| " ),\n", | |
| " CompetitiveProgrammingProblem(\n", | |
| " name=\"Reverse String\",\n", | |
| " description=\"Given a string s, return it reversed.\",\n", | |
| " test_cases=[\n", | |
| " ((\"hello\",), \"olleh\"),\n", | |
| " ((\"a\",), \"a\"),\n", | |
| " ((\"\",), \"\"),\n", | |
| " ((\"12345\",), \"54321\"),\n", | |
| " ]\n", | |
| " ),\n", | |
| " CompetitiveProgrammingProblem(\n", | |
| " name=\"Find Maximum\",\n", | |
| " description=\"Given a list of integers, return the maximum value.\",\n", | |
| " test_cases=[\n", | |
| " (([1, 2, 3],), 3),\n", | |
| " (([5],), 5),\n", | |
| " (([-1, -5, -3],), -1),\n", | |
| " (([0, 0, 0],), 0),\n", | |
| " ]\n", | |
| " ),\n", | |
| "]\n", | |
| "\n", | |
| "print(f\"Loaded {len(problems)} problems for RL environment\")\n", | |
| "for p in problems:\n", | |
| " print(f\" - {p.name}: {len(p.test_cases)} test cases\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class CodeExecutionEnvironment:\n", | |
| " \"\"\"Simulated environment for executing and verifying code solutions.\n", | |
| " \n", | |
| " In CWM's actual implementation (Section 5.3.2):\n", | |
| " - Uses Docker containers for isolation\n", | |
| " - Supports Python, C++, Rust, Go, Java, JavaScript\n", | |
| " - Compiles and executes with timeout protection\n", | |
| " - Returns detailed execution feedback\n", | |
| " \"\"\"\n", | |
| " \n", | |
| " def __init__(self, timeout=2):\n", | |
| " self.timeout = timeout\n", | |
| " \n", | |
| " def execute_solution(self, code: str, test_cases: List[Tuple], func_name: str = \"solution\"):\n", | |
| " \"\"\"Execute code solution against test cases.\n", | |
| " \n", | |
| " Returns:\n", | |
| " results: List of (passed, actual_output, expected_output) for each test\n", | |
| " \"\"\"\n", | |
| " results = []\n", | |
| " \n", | |
| " try:\n", | |
| " # Create local namespace and execute code\n", | |
| " namespace = {}\n", | |
| " exec(code, namespace)\n", | |
| " \n", | |
| " if func_name not in namespace:\n", | |
| " return [(False, f\"Function '{func_name}' not found\", None) for _ in test_cases]\n", | |
| " \n", | |
| " func = namespace[func_name]\n", | |
| " \n", | |
| " # Run each test case\n", | |
| " for test_input, expected in test_cases:\n", | |
| " try:\n", | |
| " # Execute with timeout protection\n", | |
| " actual = func(*test_input)\n", | |
| " passed = actual == expected\n", | |
| " results.append((passed, actual, expected))\n", | |
| " except Exception as e:\n", | |
| " results.append((False, f\"Error: {str(e)}\", expected))\n", | |
| " \n", | |
| " except SyntaxError as e:\n", | |
| " results = [(False, f\"Syntax Error: {str(e)}\", None) for _ in test_cases]\n", | |
| " except Exception as e:\n", | |
| " results = [(False, f\"Execution Error: {str(e)}\", None) for _ in test_cases]\n", | |
| " \n", | |
| " return results\n", | |
| " \n", | |
| " def compute_reward(self, results: List[Tuple[bool, Any, Any]]) -> float:\n", | |
| " \"\"\"Compute RL reward from test results.\n", | |
| " \n", | |
| " CWM uses binary rewards (0 or 1) based on all tests passing.\n", | |
| " We'll use a graded reward based on fraction of tests passed.\n", | |
| " \"\"\"\n", | |
| " num_passed = sum(1 for passed, _, _ in results if passed)\n", | |
| " return num_passed / len(results)\n", | |
| "\n", | |
| "# Test the environment\n", | |
| "env = CodeExecutionEnvironment()\n", | |
| "\n", | |
| "# Example solution (correct)\n", | |
| "correct_solution = '''\n", | |
| "def solution(a, b):\n", | |
| " return a + b\n", | |
| "'''\n", | |
| "\n", | |
| "# Example solution (incorrect)\n", | |
| "incorrect_solution = '''\n", | |
| "def solution(a, b):\n", | |
| " return a - b # Wrong operation!\n", | |
| "'''\n", | |
| "\n", | |
| "print(\"Testing Code Execution Environment:\")\n", | |
| "print(\"\\nCorrect solution:\")\n", | |
| "results = env.execute_solution(correct_solution, problems[0].test_cases)\n", | |
| "reward = env.compute_reward(results)\n", | |
| "print(f\" Passed: {sum(r[0] for r in results)}/{len(results)} tests\")\n", | |
| "print(f\" Reward: {reward:.2f}\")\n", | |
| "\n", | |
| "print(\"\\nIncorrect solution:\")\n", | |
| "results = env.execute_solution(incorrect_solution, problems[0].test_cases)\n", | |
| "reward = env.compute_reward(results)\n", | |
| "print(f\" Passed: {sum(r[0] for r in results)}/{len(results)} tests\")\n", | |
| "print(f\" Reward: {reward:.2f}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Simulating the RL Training Loop\n", | |
| "\n", | |
| "CWM uses GRPO (Group Relative Policy Optimization) for RL training:\n", | |
| "- Similar to PPO but without a separate value model\n", | |
| "- Uses Monte Carlo value estimation\n", | |
| "- Trained on 172B tokens across multiple environments\n", | |
| "\n", | |
| "We'll simulate a simplified version to show the concept:" | |
| ] | |
| }, | |
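| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Toy illustration of GRPO's core idea (a sketch, not the paper's implementation):\n", | |
| "# sample a *group* of completions per prompt, score each with the verifiable\n", | |
| "# reward, and baseline each reward against the group instead of a value model.\n", | |
| "def group_relative_advantages(rewards):\n", | |
| " \"\"\"Normalize rewards within a group: A_i = (r_i - mean) / std.\"\"\"\n", | |
| " mean = sum(rewards) / len(rewards)\n", | |
| " std = (sum((r - mean) ** 2 for r in rewards) / len(rewards)) ** 0.5\n", | |
| " if std == 0:\n", | |
| " return [0.0 for _ in rewards] # identical rewards carry no learning signal\n", | |
| " return [(r - mean) / std for r in rewards]\n", | |
| "\n", | |
| "# Example: four sampled solutions for one problem, scored by fraction of tests passed\n", | |
| "group_rewards = [1.0, 0.25, 0.0, 0.75]\n", | |
| "for r, a in zip(group_rewards, group_relative_advantages(group_rewards)):\n", | |
| " print(f\"reward={r:.2f} -> advantage={a:+.2f}\")" | |
| ] | |
| }, | |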
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class SimpleRLAgent:\n", | |
| " \"\"\"Simplified RL agent for demonstration.\n", | |
| " \n", | |
| " Note: This is a toy example. Real CWM uses:\n", | |
| " - 32B parameter Transformer model\n", | |
| " - GRPO algorithm with PPO loss\n", | |
| " - Distributed training across multiple GPUs\n", | |
| " - 172B tokens of RL training\n", | |
| " \"\"\"\n", | |
| " \n", | |
| " def __init__(self):\n", | |
| " # Store solution templates (in real CWM, this is a neural network)\n", | |
| " self.solution_templates = {\n", | |
| " \"sum\": \"def solution(a, b):\\n return a + b\",\n", | |
| " \"reverse\": \"def solution(s):\\n return s[::-1]\",\n", | |
| " \"max\": \"def solution(lst):\\n return max(lst)\",\n", | |
| " }\n", | |
| " self.performance_history = []\n", | |
| " \n", | |
| " def generate_solution(self, problem: CompetitiveProgrammingProblem) -> str:\n", | |
| " \"\"\"Generate a code solution for the problem.\n", | |
| " \n", | |
| " In real CWM: model generates code with reasoning traces.\n", | |
| " \"\"\"\n", | |
| " # Simple heuristic to select template based on problem name\n", | |
| " if \"sum\" in problem.name.lower():\n", | |
| " return self.solution_templates[\"sum\"]\n", | |
| " elif \"reverse\" in problem.name.lower():\n", | |
| " return self.solution_templates[\"reverse\"]\n", | |
| " elif \"max\" in problem.name.lower():\n", | |
| " return self.solution_templates[\"max\"]\n", | |
| " else:\n", | |
| " return \"def solution(*args):\\n return None\" # Fallback\n", | |
| " \n", | |
| " def train_step(self, problem: CompetitiveProgrammingProblem, env: CodeExecutionEnvironment):\n", | |
| " \"\"\"Single training step: generate solution, get reward, update policy.\"\"\"\n", | |
| " # Generate solution\n", | |
| " solution = self.generate_solution(problem)\n", | |
| " \n", | |
| " # Execute and evaluate\n", | |
| " results = env.execute_solution(solution, problem.test_cases)\n", | |
| " reward = env.compute_reward(results)\n", | |
| " \n", | |
| " # In real GRPO: compute policy gradient and update model weights\n", | |
| " # Here we just track performance\n", | |
| " self.performance_history.append({\n", | |
| " 'problem': problem.name,\n", | |
| " 'reward': reward,\n", | |
| " 'tests_passed': sum(r[0] for r in results),\n", | |
| " 'total_tests': len(results)\n", | |
| " })\n", | |
| " \n", | |
| " return reward, results\n", | |
| "\n", | |
| "# Simulate RL training\n", | |
| "agent = SimpleRLAgent()\n", | |
| "env = CodeExecutionEnvironment()\n", | |
| "\n", | |
| "print(\"Simulating RL Training Loop:\\n\")\n", | |
| "for i, problem in enumerate(problems):\n", | |
| " reward, results = agent.train_step(problem, env)\n", | |
| " passed = sum(r[0] for r in results)\n", | |
| " print(f\"Problem {i+1}: {problem.name}\")\n", | |
| " print(f\" Tests passed: {passed}/{len(results)}\")\n", | |
| " print(f\" Reward: {reward:.2f}\")\n", | |
| "\n", | |
| "# Show training statistics\n", | |
| "avg_reward = np.mean([h['reward'] for h in agent.performance_history])\n", | |
| "print(f\"\\nAverage reward: {avg_reward:.2f}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## 4. Agentic Coding Workflow\n", | |
| "\n", | |
| "**Reference:** Section 5.3.3 (Workflow 1) and Section 2.3 (ForagerAgent)\n", | |
| "\n", | |
| "One of CWM's key innovations is training on **agentic trajectories** - multi-step interactions where an agent:\n", | |
| "1. Reasons about a problem\n", | |
| "2. Creates code files\n", | |
| "3. Tests the code\n", | |
| "4. Observes execution results\n", | |
| "5. Edits code to fix errors\n", | |
| "6. Iterates until successful\n", | |
| "\n", | |
| "The ForagerAgent collected **3M trajectories** from **10.2k repositories** with actions: create, edit, bash, submit." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "@dataclass\n", | |
| "class AgenticAction:\n", | |
| " \"\"\"Represents an action taken by the agent.\"\"\"\n", | |
| " action_type: str # 'create', 'edit', 'bash', 'submit'\n", | |
| " content: str\n", | |
| " observation: str = \"\" # Environment response\n", | |
| "\n", | |
| "class AgenticCodingEnvironment:\n", | |
| " \"\"\"Simulates an agentic coding environment.\n", | |
| " \n", | |
| " In CWM (Section 5.3.3 and 2.3):\n", | |
| " - Docker containers provide isolated execution\n", | |
| " - Agent has access to create, edit, bash, submit tools\n", | |
| " - Environment returns observations after each action\n", | |
| " - Trained with hybrid reward: test results + patch similarity\n", | |
| " \"\"\"\n", | |
| " \n", | |
| " def __init__(self):\n", | |
| " self.files = {} # Simulated file system\n", | |
| " self.trajectory = [] # Action-observation history\n", | |
| " \n", | |
| " def create_file(self, filename: str, content: str) -> str:\n", | |
| " \"\"\"Create a new file.\"\"\"\n", | |
| " self.files[filename] = content\n", | |
| " return f\"Created {filename}\"\n", | |
| " \n", | |
| " def edit_file(self, filename: str, new_content: str) -> str:\n", | |
| " \"\"\"Edit an existing file.\"\"\"\n", | |
| " if filename not in self.files:\n", | |
| " return f\"Error: {filename} does not exist\"\n", | |
| " self.files[filename] = new_content\n", | |
| " return f\"Edited {filename}\"\n", | |
| " \n", | |
| " def bash(self, command: str) -> str:\n", | |
| " \"\"\"Execute a bash command (simulated).\"\"\"\n", | |
| " if command.startswith(\"python\"):\n", | |
| " # Simulate running Python file\n", | |
| " parts = command.split()\n", | |
| " if len(parts) > 1:\n", | |
| " filename = parts[1]\n", | |
| " if filename in self.files:\n", | |
| " # Try to execute\n", | |
| " try:\n", | |
| " namespace = {}\n", | |
| " exec(self.files[filename], namespace)\n", | |
| " return \"Execution successful\"\n", | |
| " except Exception as e:\n", | |
| " return f\"Error: {str(e)}\"\n", | |
| " else:\n", | |
| " return f\"Error: {filename} not found\"\n", | |
| " return f\"Executed: {command}\"\n", | |
| " \n", | |
| " def execute_action(self, action: AgenticAction) -> str:\n", | |
| " \"\"\"Execute an agent action and return observation.\"\"\"\n", | |
| " if action.action_type == 'create':\n", | |
| " # Parse filename and content from action.content\n", | |
| " lines = action.content.split('\\n', 1)\n", | |
| " filename = lines[0].strip()\n", | |
| " content = lines[1] if len(lines) > 1 else \"\"\n", | |
| " obs = self.create_file(filename, content)\n", | |
| " elif action.action_type == 'edit':\n", | |
| " lines = action.content.split('\\n', 1)\n", | |
| " filename = lines[0].strip()\n", | |
| " content = lines[1] if len(lines) > 1 else \"\"\n", | |
| " obs = self.edit_file(filename, content)\n", | |
| " elif action.action_type == 'bash':\n", | |
| " obs = self.bash(action.content)\n", | |
| " elif action.action_type == 'submit':\n", | |
| " obs = \"Solution submitted\"\n", | |
| " else:\n", | |
| " obs = f\"Unknown action: {action.action_type}\"\n", | |
| " \n", | |
| " action.observation = obs\n", | |
| " self.trajectory.append(action)\n", | |
| " return obs\n", | |
| "\n", | |
| "# Simulate an agentic trajectory\n", | |
| "env = AgenticCodingEnvironment()\n", | |
| "\n", | |
| "print(\"Simulating Agentic Coding Trajectory:\\n\")\n", | |
| "print(\"Task: Implement a function to compute factorial\\n\")\n", | |
| "\n", | |
| "# Action 1: Create initial solution\n", | |
| "action1 = AgenticAction(\n", | |
| " action_type='create',\n", | |
| " content=\"\"\"factorial.py\n", | |
| "def factorial(n):\n", | |
| " if n == 0:\n", | |
| " return 1\n", | |
| " return n * factorial(n - 1)\n", | |
| "\n", | |
| "# Test\n", | |
| "print(factorial(5)) # Should print 120\n", | |
| "\"\"\"\n", | |
| ")\n", | |
| "obs1 = env.execute_action(action1)\n", | |
| "print(f\"Action 1 (Create): {obs1}\")\n", | |
| "\n", | |
| "# Action 2: Test the solution\n", | |
| "action2 = AgenticAction(\n", | |
| " action_type='bash',\n", | |
| " content='python factorial.py'\n", | |
| ")\n", | |
| "obs2 = env.execute_action(action2)\n", | |
| "print(f\"Action 2 (Test): {obs2}\")\n", | |
| "\n", | |
| "# Action 3: Edit to add error handling\n", | |
| "action3 = AgenticAction(\n", | |
| " action_type='edit',\n", | |
| " content=\"\"\"factorial.py\n", | |
| "def factorial(n):\n", | |
| " if n < 0:\n", | |
| " raise ValueError(\"n must be non-negative\")\n", | |
| " if n == 0:\n", | |
| " return 1\n", | |
| " return n * factorial(n - 1)\n", | |
| "\n", | |
| "# Test\n", | |
| "print(factorial(5)) # Should print 120\n", | |
| "\"\"\"\n", | |
| ")\n", | |
| "obs3 = env.execute_action(action3)\n", | |
| "print(f\"Action 3 (Edit): {obs3}\")\n", | |
| "\n", | |
| "# Action 4: Final test\n", | |
| "action4 = AgenticAction(\n", | |
| " action_type='bash',\n", | |
| " content='python factorial.py'\n", | |
| ")\n", | |
| "obs4 = env.execute_action(action4)\n", | |
| "print(f\"Action 4 (Test): {obs4}\")\n", | |
| "\n", | |
| "# Action 5: Submit\n", | |
| "action5 = AgenticAction(\n", | |
| " action_type='submit',\n", | |
| " content='factorial.py'\n", | |
| ")\n", | |
| "obs5 = env.execute_action(action5)\n", | |
| "print(f\"Action 5 (Submit): {obs5}\")\n", | |
| "\n", | |
| "print(f\"\\nTrajectory length: {len(env.trajectory)} actions\")" | |
| ] | |
| }, | |
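| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The environment docstring above mentions CWM's hybrid reward: test results plus patch similarity. Below is a minimal sketch using `difflib`; the 0.8/0.2 weights and the choice of \"reference patch\" are arbitrary demo values, not the paper's." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import difflib\n", | |
| "\n", | |
| "# Sketch of a hybrid reward: fraction of tests passed plus similarity of the\n", | |
| "# generated patch to a reference patch. Weights are arbitrary demo choices.\n", | |
| "def hybrid_reward(tests_passed, total_tests, generated_patch, reference_patch,\n", | |
| " w_tests=0.8, w_patch=0.2):\n", | |
| " test_term = tests_passed / total_tests if total_tests else 0.0\n", | |
| " patch_term = difflib.SequenceMatcher(None, generated_patch, reference_patch).ratio()\n", | |
| " return w_tests * test_term + w_patch * patch_term\n", | |
| "\n", | |
| "reference = env.files['factorial.py'] # treat the final file as the reference\n", | |
| "candidate = reference.replace('ValueError', 'Exception') # hypothetical near-miss patch\n", | |
| "print(f\"Near-miss patch: {hybrid_reward(3, 4, candidate, reference):.3f}\")\n", | |
| "print(f\"Exact patch: {hybrid_reward(4, 4, reference, reference):.3f}\")" | |
| ] | |
| }, | |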
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Format trajectory for training\n", | |
| "def format_agentic_trajectory(trajectory: List[AgenticAction]) -> str:\n", | |
| " \"\"\"Format trajectory in CWM training format.\n", | |
| " \n", | |
| " CWM trains on these trajectories to learn:\n", | |
| " - When to create vs edit files\n", | |
| " - How to use bash for testing\n", | |
| " - How to interpret error messages\n", | |
| " - When to iterate vs submit\n", | |
| " \"\"\"\n", | |
| " formatted = [\"# Agentic Trajectory\\n\"]\n", | |
| " \n", | |
| " for i, action in enumerate(trajectory):\n", | |
| " formatted.append(f\"\\n## Step {i+1}: {action.action_type.upper()}\")\n", | |
| " formatted.append(f\"Action:\\n{action.content}\")\n", | |
| " formatted.append(f\"Observation:\\n{action.observation}\")\n", | |
| " \n", | |
| " return '\\n'.join(formatted)\n", | |
| "\n", | |
| "formatted_trajectory = format_agentic_trajectory(env.trajectory)\n", | |
| "print(\"Formatted Trajectory for Training:\")\n", | |
| "print(formatted_trajectory)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### ForagerAgent Data Collection at Scale\n", | |
| "\n", | |
| "**How CWM generates agentic trajectories (Section 2.3):**\n", | |
| "\n", | |
| "1. **Build Executable Repository Images** (35k+ repos):\n", | |
| " - Use RepoAgent (LLM-backed setup)\n", | |
| " - Use Activ (GitHub Actions-based)\n", | |
| " - Create Docker containers that can run tests\n", | |
| "\n", | |
| "2. **Generate Tasks**:\n", | |
| " - **Mutate-Fix**: Inject synthetic bugs, agent must fix\n", | |
| " - **Issue-Fix**: Real GitHub issues and PRs\n", | |
| "\n", | |
| "3. **Collect Trajectories** (3M total):\n", | |
| " - Llama-3-70B-Instruct or Qwen3-235B-A22B as base\n", | |
| " - Agent interacts with Docker environments\n", | |
| " - Actions: create, edit, bash, submit\n", | |
| "\n", | |
| "4. **Post-Processing**:\n", | |
| " - Near-deduplication (MinHash + Jaccard similarity)\n", | |
| " - Stochastic loss masking (50% of observations)\n", | |
| " - Filter for quality\n", | |
| "\n", | |
| "5. **Self-Bootstrapping**:\n", | |
| " - Use earlier CWM iterations\n", | |
| " - Rejection sampling for high-quality traces\n", | |
| " - Include in subsequent training" | |
| ] | |
| }, | |
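| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The near-deduplication step above uses MinHash to approximate Jaccard similarity at scale. The cell below sketches the exact Jaccard computation over token sets with a greedy keep/drop loop; the 0.7 threshold is an arbitrary demo choice." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch of near-deduplication via exact Jaccard similarity over token sets.\n", | |
| "# Real pipelines use MinHash to approximate this cheaply at 3M-trajectory scale.\n", | |
| "def jaccard(a: str, b: str) -> float:\n", | |
| " ta, tb = set(a.split()), set(b.split())\n", | |
| " return len(ta & tb) / len(ta | tb) if (ta or tb) else 1.0\n", | |
| "\n", | |
| "def near_dedup(docs, threshold=0.7): # threshold is an arbitrary demo value\n", | |
| " kept = []\n", | |
| " for doc in docs:\n", | |
| " if all(jaccard(doc, k) < threshold for k in kept):\n", | |
| " kept.append(doc)\n", | |
| " return kept\n", | |
| "\n", | |
| "trajectory_summaries = [\n", | |
| " \"create factorial.py run tests submit\",\n", | |
| " \"create factorial.py run tests submit\", # exact duplicate\n", | |
| " \"create factorial.py run tests fix bug submit\", # near duplicate\n", | |
| " \"edit parser.py run tests submit\", # distinct\n", | |
| "]\n", | |
| "deduped = near_dedup(trajectory_summaries)\n", | |
| "print(f\"Kept {len(deduped)} of {len(trajectory_summaries)} trajectories:\")\n", | |
| "for d in deduped:\n", | |
| " print(f\" - {d}\")" | |
| ] | |
| }, | |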
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## 5. Code Understanding Evaluation (CruxEval-style)\n", | |
| "\n", | |
| "**Reference:** Section 7.3 (Workflow 13)\n", | |
| "\n", | |
| "CWM is evaluated on its ability to predict Python execution traces using the CruxEval benchmark. This tests whether the model has learned \"world modeling\" - predicting how code will execute.\n", | |
| "\n", | |
| "Three modes:\n", | |
| "1. **Single-step**: Predict final output directly\n", | |
| "2. **Full line-by-line trace**: Predict complete execution trace\n", | |
| "3. **Natural language reasoning**: Explain execution in natural language" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Create simple code understanding test cases\n", | |
| "cruxeval_examples = [\n", | |
| " {\n", | |
| " 'code': '''def f(x):\n", | |
| " return x * 2\n", | |
| "\n", | |
| "result = f(5)''',\n", | |
| " 'expected_output': 10,\n", | |
| " 'description': 'Simple multiplication'\n", | |
| " },\n", | |
| " {\n", | |
| " 'code': '''def f(lst):\n", | |
| " return [x + 1 for x in lst]\n", | |
| "\n", | |
| "result = f([1, 2, 3])''',\n", | |
| " 'expected_output': [2, 3, 4],\n", | |
| " 'description': 'List comprehension'\n", | |
| " },\n", | |
| " {\n", | |
| " 'code': '''def f(s):\n", | |
| " return s.upper()[:3]\n", | |
| "\n", | |
| "result = f(\"hello\")''',\n", | |
| " 'expected_output': 'HEL',\n", | |
| " 'description': 'String manipulation'\n", | |
| " },\n", | |
| " {\n", | |
| " 'code': '''def f(n):\n", | |
| " total = 0\n", | |
| " for i in range(n):\n", | |
| " total += i\n", | |
| " return total\n", | |
| "\n", | |
| "result = f(4)''',\n", | |
| " 'expected_output': 6,\n", | |
| " 'description': 'Loop accumulation'\n", | |
| " },\n", | |
| "]\n", | |
| "\n", | |
| "print(f\"Created {len(cruxeval_examples)} CruxEval-style test cases\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class CodeUnderstandingEvaluator:\n", | |
| " \"\"\"Evaluator for code understanding tasks.\n", | |
| " \n", | |
| " In CWM evaluation (Section 7.3):\n", | |
| " - Tests on CruxEval benchmark\n", | |
| " - Compares single-step vs full trace prediction\n", | |
| " - Measures pass@1 accuracy\n", | |
| " - CWM achieves 87.0% on CruxEval Output\n", | |
| " \"\"\"\n", | |
| " \n", | |
| " def execute_and_predict(self, code: str):\n", | |
| " \"\"\"Execute code and capture result.\"\"\"\n", | |
| " try:\n", | |
| " namespace = {}\n", | |
| " exec(code, namespace)\n", | |
| " return namespace.get('result', None), None\n", | |
| " except Exception as e:\n", | |
| " return None, str(e)\n", | |
| " \n", | |
| " def trace_execution(self, code: str):\n", | |
| " \"\"\"Generate execution trace for code.\"\"\"\n", | |
| " tracer = PythonExecutionTracer()\n", | |
| " try:\n", | |
| " namespace = {}\n", | |
| " sys.settrace(tracer.trace_function)\n", | |
| " exec(code, namespace)\n", | |
| " sys.settrace(None)\n", | |
| " result = namespace.get('result', None)\n", | |
| " return result, tracer.trace\n", | |
| " except Exception as e:\n", | |
| " sys.settrace(None)\n", | |
| " return None, []\n", | |
| " \n", | |
| " def evaluate(self, examples: List[Dict]):\n", | |
| " \"\"\"Evaluate on code understanding examples.\"\"\"\n", | |
| " results = []\n", | |
| " \n", | |
| " for ex in examples:\n", | |
| " # Single-step prediction\n", | |
| " predicted, error = self.execute_and_predict(ex['code'])\n", | |
| " correct = predicted == ex['expected_output']\n", | |
| " \n", | |
| " # Trace prediction\n", | |
| " traced_result, trace = self.trace_execution(ex['code'])\n", | |
| " \n", | |
| " results.append({\n", | |
| " 'description': ex['description'],\n", | |
| " 'expected': ex['expected_output'],\n", | |
| " 'predicted': predicted,\n", | |
| " 'correct': correct,\n", | |
| " 'trace_length': len(trace),\n", | |
| " 'error': error\n", | |
| " })\n", | |
| " \n", | |
| " return results\n", | |
| "\n", | |
| "# Evaluate examples\n", | |
| "evaluator = CodeUnderstandingEvaluator()\n", | |
| "results = evaluator.evaluate(cruxeval_examples)\n", | |
| "\n", | |
| "# Display results\n", | |
| "print(\"Code Understanding Evaluation Results:\\n\")\n", | |
| "for i, r in enumerate(results):\n", | |
| " status = \"✓\" if r['correct'] else \"✗\"\n", | |
| " print(f\"{status} Example {i+1}: {r['description']}\")\n", | |
| " print(f\" Expected: {r['expected']}\")\n", | |
| " print(f\" Predicted: {r['predicted']}\")\n", | |
| " if r['error']:\n", | |
| " print(f\" Error: {r['error']}\")\n", | |
| " print(f\" Trace length: {r['trace_length']} steps\")\n", | |
| " print()\n", | |
| "\n", | |
| "# Calculate accuracy\n", | |
| "accuracy = sum(r['correct'] for r in results) / len(results) * 100\n", | |
| "print(f\"Overall Accuracy: {accuracy:.1f}%\")\n", | |
| "print(f\"\\nNote: CWM achieves 87.0% on full CruxEval Output benchmark\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Visualizing Execution Traces\n", | |
| "\n", | |
| "Let's visualize how execution traces help understand code:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Generate detailed trace for one example\n", | |
| "example_code = cruxeval_examples[3]['code'] # Loop example\n", | |
| "print(\"Analyzing code:\")\n", | |
| "print(example_code)\n", | |
| "print(\"\\n\" + \"=\"*50)\n", | |
| "\n", | |
| "result, trace = evaluator.trace_execution(example_code)\n", | |
| "\n", | |
| "print(f\"\\nExecution Trace ({len(trace)} steps):\\n\")\n", | |
| "for i, step in enumerate(trace[:15]): # Show first 15 steps\n", | |
| " if 'locals' in step:\n", | |
| " print(f\"Step {i}: Line {step.get('line', '?')}\")\n", | |
| " print(f\" Variables: {step['locals']}\")\n", | |
| " elif step.get('event') == 'return':\n", | |
| " print(f\"Step {i}: RETURN {step['value']}\")\n", | |
| "\n", | |
| "print(f\"\\nFinal result: {result}\")\n", | |
| "print(f\"Expected: {cruxeval_examples[3]['expected_output']}\")\n", | |
| "print(f\"Match: {result == cruxeval_examples[3]['expected_output']}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## 6. CWM Training Pipeline Overview\n", | |
| "\n", | |
| "**Reference:** Section 4 and 5 (Workflow 3)\n", | |
| "\n", | |
| "Let's visualize the complete CWM training pipeline:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Training pipeline stages\n", | |
| "pipeline_stages = [\n", | |
| " {\n", | |
| " 'stage': 'Pre-training',\n", | |
| " 'tokens': '8T',\n", | |
| " 'context': '8k',\n", | |
| " 'data': 'General (30% code, STEM, general knowledge)',\n", | |
| " 'duration': 'Weeks',\n", | |
| " 'model_size': '32B parameters'\n", | |
| " },\n", | |
| " {\n", | |
| " 'stage': 'Mid-training',\n", | |
| " 'tokens': '5T',\n", | |
| " 'context': '131k',\n", | |
| " 'data': 'Python traces (120M+ functions) + ForagerAgent (3M trajectories) + Code',\n", | |
| " 'duration': 'Weeks',\n", | |
| " 'model_size': '32B parameters'\n", | |
| " },\n", | |
| " {\n", | |
| " 'stage': 'SFT',\n", | |
| " 'tokens': '100B',\n", | |
| " 'context': '131k',\n", | |
| " 'data': 'Instruction-following + reasoning traces + 30% rehearsal',\n", | |
| " 'duration': 'Days',\n", | |
| " 'model_size': '32B parameters'\n", | |
| " },\n", | |
| " {\n", | |
| " 'stage': 'RL',\n", | |
| " 'tokens': '172B',\n", | |
| " 'context': '131k',\n", | |
| " 'data': 'Multi-task (Code contests + Math + Agentic SWE + Agentic coding)',\n", | |
| " 'duration': 'Days',\n", | |
| " 'model_size': '32B parameters'\n", | |
| " },\n", | |
| "]\n", | |
| "\n", | |
| "# Create visualization if matplotlib available\n", | |
| "if HAS_SCIENTIFIC and plt:\n", | |
| " fig, ax = plt.subplots(figsize=(12, 6))\n", | |
| " \n", | |
| " stages = [s['stage'] for s in pipeline_stages]\n", | |
| " tokens = [float(s['tokens'].replace('T', '000').replace('B', '')) for s in pipeline_stages]\n", | |
| " \n", | |
| " colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']\n", | |
| " bars = ax.barh(stages, tokens, color=colors)\n", | |
| " \n", | |
| " ax.set_xlabel('Training Tokens (Billions)', fontsize=12)\n", | |
| " ax.set_title('CWM Training Pipeline', fontsize=14, fontweight='bold')\n", | |
| " ax.set_xscale('log')\n", | |
| " \n", | |
| " # Add labels\n", | |
| " for i, (bar, stage_info) in enumerate(zip(bars, pipeline_stages)):\n", | |
| " width = bar.get_width()\n", | |
| " ax.text(width * 1.1, bar.get_y() + bar.get_height()/2,\n", | |
| " f\"{stage_info['tokens']} tokens\\nContext: {stage_info['context']}\",\n", | |
| " ha='left', va='center', fontsize=9)\n", | |
| " \n", | |
| " plt.tight_layout()\n", | |
| " plt.show()\n", | |
| "else:\n", | |
| " print(\"Visualization requires matplotlib. Install with: pip install matplotlib\")\n", | |
| "\n", | |
| "# Print detailed information\n", | |
| "print(\"\\nCWM Training Pipeline Details:\\n\")\n", | |
| "for i, stage in enumerate(pipeline_stages):\n", | |
| " print(f\"{i+1}. {stage['stage']}:\")\n", | |
| " print(f\" - Tokens: {stage['tokens']}\")\n", | |
| " print(f\" - Context length: {stage['context']}\")\n", | |
| " print(f\" - Data: {stage['data']}\")\n", | |
| " print(f\" - Model size: {stage['model_size']}\")\n", | |
| " print()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Key Components of CWM Training\n", | |
| "\n", | |
| "**1. Data Generation Workflows:**\n", | |
| "- **Python Execution Traces**: 120M+ functions, repository tests, CodeContests solutions\n", | |
| "- **ForagerAgent Trajectories**: 3M multi-step interactions across 10.2k repos\n", | |
| "- **Executable Repository Images**: 35k+ Docker containers for testing\n", | |
| "\n", | |
| "**2. RL Environments:**\n", | |
| "- **Code Contests**: Competitive programming with unit test verification\n", | |
| "- **Agentic SWE**: Multi-turn software engineering (create, edit, bash, submit)\n", | |
| "- **Agentic Coding**: Combines reasoning with tool use\n", | |
| "- **Math**: Mathematical reasoning with verifiable answers\n", | |
| "\n", | |
| "**3. Training Algorithm:**\n", | |
| "- **GRPO** (Group Relative Policy Optimization)\n", | |
| "- PPO loss with Monte Carlo value estimation\n", | |
| "- No separate value model\n", | |
| "- Three-stage training with data resampling" | |
| ] | |
| }, | |
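| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "One training optimization listed in Section 2.3 is *stochastic loss masking*: roughly half of the environment observations are excluded from the loss, so the model learns to produce actions rather than to memorize tool output. Below is a toy illustration; the 50% rate comes from the paper, the segment layout is a demo choice." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Toy illustration of stochastic loss masking over a trajectory.\n", | |
| "# Action tokens always contribute to the loss; each observation span is\n", | |
| "# kept with probability 0.5 (the 50% rate mentioned in the paper).\n", | |
| "def build_loss_mask(segments, keep_prob=0.5):\n", | |
| " \"\"\"segments: list of (kind, n_tokens), kind in {'action', 'observation'}.\"\"\"\n", | |
| " mask = []\n", | |
| " for kind, n_tokens in segments:\n", | |
| " keep = (kind == 'action') or (random.random() < keep_prob)\n", | |
| " mask.extend([1 if keep else 0] * n_tokens)\n", | |
| " return mask\n", | |
| "\n", | |
| "segments = [('action', 4), ('observation', 6), ('action', 3), ('observation', 5)]\n", | |
| "mask = build_loss_mask(segments)\n", | |
| "print(f\"Loss mask ({sum(mask)}/{len(mask)} tokens contribute to the loss):\")\n", | |
| "print(mask)" | |
| ] | |
| }, | |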
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## 7. Benchmark Results Summary\n", | |
| "\n", | |
| "**Reference:** Section 7\n", | |
| "\n", | |
| "CWM achieves state-of-the-art results on multiple benchmarks:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Benchmark results (from paper)\n", | |
| "benchmarks = {\n", | |
| " 'SWE-bench Verified': {\n", | |
| " 'CWM-32B': 44.6,\n", | |
| " 'Description': 'Software engineering tasks (500 problems)',\n", | |
| " 'Metric': 'Pass@1 (%)'\n", | |
| " },\n", | |
| " 'LiveCodeBench': {\n", | |
| " 'CWM-32B': 68.6,\n", | |
| " 'Description': 'Competitive programming',\n", | |
| " 'Metric': 'Pass@1 (%)'\n", | |
| " },\n", | |
| " 'AIME 2024': {\n", | |
| " 'CWM-32B': 16.7,\n", | |
| " 'Description': 'Mathematical reasoning (competition math)',\n", | |
| " 'Metric': 'Accuracy (%)'\n", | |
| " },\n", | |
| " 'CruxEval Output': {\n", | |
| " 'CWM-32B': 87.0,\n", | |
| " 'Description': 'Code execution prediction',\n", | |
| " 'Metric': 'Pass@1 (%)'\n", | |
| " },\n", | |
| " 'Math-500': {\n", | |
| " 'CWM-32B': 85.0,\n", | |
| " 'Description': 'Mathematical problem solving',\n", | |
| " 'Metric': 'Accuracy (%)'\n", | |
| " },\n", | |
| "}\n", | |
| "\n", | |
| "# Visualize results if matplotlib available\n", | |
| "if HAS_SCIENTIFIC and plt:\n", | |
| " fig, ax = plt.subplots(figsize=(10, 6))\n", | |
| " \n", | |
| " benchmark_names = list(benchmarks.keys())\n", | |
| " scores = [benchmarks[b]['CWM-32B'] for b in benchmark_names]\n", | |
| " \n", | |
| " bars = ax.bar(range(len(benchmark_names)), scores, color='steelblue', alpha=0.8)\n", | |
| " ax.set_xticks(range(len(benchmark_names)))\n", | |
| " ax.set_xticklabels(benchmark_names, rotation=45, ha='right')\n", | |
| " ax.set_ylabel('Score (%)', fontsize=12)\n", | |
| " ax.set_title('CWM-32B Benchmark Results', fontsize=14, fontweight='bold')\n", | |
| " ax.set_ylim(0, 100)\n", | |
| " ax.grid(axis='y', alpha=0.3)\n", | |
| " \n", | |
| " # Add value labels on bars\n", | |
| " for bar, score in zip(bars, scores):\n", | |
| " height = bar.get_height()\n", | |
| " ax.text(bar.get_x() + bar.get_width()/2, height + 2,\n", | |
| " f'{score:.1f}%', ha='center', va='bottom', fontsize=10, fontweight='bold')\n", | |
| " \n", | |
| " plt.tight_layout()\n", | |
| " plt.show()\n", | |
| "else:\n", | |
| " print(\"Visualization requires matplotlib. Install with: pip install matplotlib\\n\")\n", | |
| "\n", | |
| "# Print detailed results\n", | |
| "print(\"\\nDetailed Benchmark Results:\\n\")\n", | |
| "for name, data in benchmarks.items():\n", | |
| " print(f\"{name}:\")\n", | |
| " print(f\" Score: {data['CWM-32B']}% ({data['Metric']})\")\n", | |
| " print(f\" Description: {data['Description']}\")\n", | |
| " print()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## 8. Scaling Guidance and Production Implementation\n", | |
| "\n", | |
| "### How to Scale These Workflows to Production\n", | |
| "\n", | |
| "This notebook demonstrates the core concepts with minimal examples. To replicate CWM's full results, you would need:\n", | |
| "\n", | |
| "#### 1. Infrastructure Requirements\n", | |
| "\n", | |
| "**Compute:**\n", | |
| "- Pre-training: Large GPU cluster (8T tokens requires weeks on hundreds of GPUs)\n", | |
| "- Mid-training: Similar scale (5T tokens with 131k context)\n", | |
| "- RL training: Distributed setup for asynchronous multi-environment training\n", | |
| "- Inference: Single 80GB H100 GPU (with FP8 quantization)\n", | |
| "\n", | |
| "**Storage:**\n", | |
| "- Training data: Terabytes for execution traces and trajectories\n", | |
| "- Docker images: 35k+ executable repository images\n", | |
| "- Model checkpoints: Multiple checkpoints at 32B parameters\n", | |
| "\n", | |
| "#### 2. Data Generation at Scale\n", | |
| "\n", | |
| "**Python Execution Traces:**\n", | |
| "```python\n", | |
| "# Instead of 5 examples, generate 120M+ function traces\n", | |
| "# - Collect Python functions from GitHub, Stack Overflow, etc.\n", | |
| "# - Use fuzzing + LLM prompting for input generation\n", | |
| "# - Run with proper timeout and sandboxing\n", | |
| "# - Post-process to observation-action format\n", | |
| "# - Generate repository-level traces from 21k+ repos\n", | |
| "```\n", | |
| "\n", | |
| "**ForagerAgent Trajectories:**\n", | |
| "```python\n", | |
| "# Instead of 1 trajectory, generate 3M trajectories\n", | |
| "# - Build 35k+ executable repository Docker images\n", | |
| "# - Generate mutate-fix and issue-fix tasks\n", | |
| "# - Run Llama-3-70B or Qwen3-235B agents\n", | |
| "# - Collect multi-step interactions\n", | |
| "# - Apply near-deduplication and filtering\n", | |
| "```\n", | |
| "\n", | |
| "#### 3. Model Training\n", | |
| "\n", | |
| "**Pre-training:**\n", | |
| "```python\n", | |
| "# Train 32B parameter dense Transformer\n", | |
| "# - 8T tokens from diverse sources (30% code)\n", | |
| "# - 8k context length\n", | |
| "# - Use scaling laws to set hyperparameters\n", | |
| "# - Distributed training with FSDP or similar\n", | |
| "```\n", | |
| "\n", | |
| "**Mid-training:**\n", | |
| "```python\n", | |
| "# Continue training with code world modeling data\n", | |
| "# - 5T additional tokens\n", | |
| "# - Increase context to 131k (alternating local-global attention)\n", | |
| "# - Mix: Python traces + ForagerAgent + code data\n", | |
| "# - Proper data balancing and rehearsal\n", | |
| "```\n", | |
| "\n", | |
| "**RL Training:**\n", | |
| "```python\n", | |
| "# Multi-task RL with GRPO\n", | |
| "# - 172B tokens across 4 environments\n", | |
| "# - Asynchronous distributed training\n", | |
| "# - Three-stage training with data resampling\n", | |
| "# - Length reward scheduling\n", | |
| "# - Rehearsal from SFT datamix\n", | |
| "```\n", | |
| "\n", | |
| "#### 4. Evaluation Infrastructure\n", | |
| "\n", | |
| "**Benchmarks to implement:**\n", | |
| "- SWE-bench Verified (500 problems, Docker environments)\n", | |
| "- LiveCodeBench (competitive programming)\n", | |
| "- CruxEval (execution prediction)\n", | |
| "- AIME 2024, Math-500 (mathematical reasoning)\n", | |
| "- BigOBench (complexity prediction)\n", | |
| "- HaltEval (termination prediction)\n", | |
| "\n", | |
| "**Agentic harnesses:**\n", | |
| "- Primary harness with create/edit/bash/submit tools\n", | |
| "- Alternative harnesses (Mini-SWE-Agent, OpenHands) for robustness testing\n", | |
| "\n", | |
| "#### 5. Key Implementation Details\n", | |
| "\n", | |
| "**Docker-based execution:**\n", | |
| "- Use RepoAgent (LLM-backed) or Activ (GitHub Actions-based)\n", | |
| "- Isolate execution environments\n", | |
| "- Handle timeouts and resource limits\n", | |
| "\n", | |
| "**Special tokens:**\n", | |
| "- `<|observation_sep|>`, `<|action_sep|>`, `<|return_sep|>` for traces\n", | |
| "- Reasoning tokens for thought processes\n", | |
| "- Tool-calling format for agentic actions\n", | |
| "\n", | |
| "**Training optimizations:**\n", | |
| "- Stochastic loss masking (50% for observations)\n", | |
| "- Near-deduplication (MinHash, Jaccard similarity)\n", | |
| "- Self-bootstrapping (rejection sampling from earlier iterations)\n", | |
| "- Hybrid rewards (test results + patch similarity)\n", | |
| "\n", | |
| "### Expected Timeline\n", | |
| "\n", | |
| "- **Data collection**: Weeks to months (parallelizable)\n", | |
| "- **Pre-training**: Weeks on large cluster\n", | |
| "- **Mid-training**: Weeks on large cluster\n", | |
| "- **SFT**: Days\n", | |
| "- **RL**: Days to weeks\n", | |
| "\n", | |
| "### Cost Estimation\n", | |
| "\n", | |
| "Training a 32B parameter model on 8T+5T+0.1T+0.172T ≈ 13.3T tokens requires:\n", | |
| "- Millions of GPU-hours\n", | |
| "- Estimated cost: $1M-$10M+ depending on infrastructure\n", | |
| "\n", | |
| "### Open Weights\n", | |
| "\n", | |
| "The paper mentions CWM will be released as open-weights, allowing researchers to:\n", | |
| "- Fine-tune on domain-specific tasks\n", | |
| "- Study world modeling capabilities\n", | |
| "- Build agentic applications\n", | |
| "- Improve code generation systems" | |
| ] | |
| }, | |
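| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "As a concrete illustration of \"handle timeouts and resource limits\" above, the cell below runs a snippet in a subprocess with a wall-clock limit. This is only a minimal sketch: a bare subprocess gives a timeout but no sandboxing, whereas CWM's infrastructure isolates execution in Docker containers." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import subprocess\n", | |
| "import sys\n", | |
| "\n", | |
| "# Minimal sketch of timeout-protected code execution in a subprocess.\n", | |
| "# No sandboxing here; real pipelines add container isolation and memory limits.\n", | |
| "def run_with_timeout(code: str, timeout: float = 2.0):\n", | |
| " try:\n", | |
| " proc = subprocess.run(\n", | |
| " [sys.executable, \"-c\", code],\n", | |
| " capture_output=True, text=True, timeout=timeout,\n", | |
| " )\n", | |
| " return proc.returncode, (proc.stdout or proc.stderr).strip()\n", | |
| " except subprocess.TimeoutExpired:\n", | |
| " return None, f\"Timed out after {timeout}s\"\n", | |
| "\n", | |
| "print(run_with_timeout(\"print(sum(range(10)))\")) # completes: (0, '45')\n", | |
| "print(run_with_timeout(\"while True: pass\", timeout=1.0)) # hangs: killed" | |
| ] | |
| }, | |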
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## 9. Conclusions and Key Takeaways\n", | |
| "\n", | |
| "### Main Contributions of CWM\n", | |
| "\n", | |
| "1. **World Modeling for Code**: Training on execution traces enables the model to predict how code will execute before running it\n", | |
| "\n", | |
| "2. **Agentic Training Data**: 3M ForagerAgent trajectories teach the model multi-step problem solving with tools\n", | |
| "\n", | |
| "3. **Multi-Task RL**: Joint training across code, math, and software engineering tasks improves generalization\n", | |
| "\n", | |
| "4. **Executable Repository Images**: 35k+ Docker containers enable large-scale code execution and testing\n", | |
| "\n", | |
| "5. **State-of-the-Art Results**: Achieves top performance on SWE-bench Verified (44.6%), LiveCodeBench (68.6%), and other benchmarks\n", | |
| "\n", | |
| "### Why World Modeling Matters\n", | |
| "\n", | |
| "- Enables the model to reason about code execution without running it\n", | |
| "- Helps predict bugs and edge cases\n", | |
| "- Improves code generation quality\n", | |
| "- Supports better test case generation\n", | |
| "- Facilitates debugging and error correction\n", | |
| "\n", | |
| "### Future Research Directions\n", | |
| "\n", | |
| "- Scaling to larger models (>100B parameters)\n", | |
| "- More programming languages beyond Python\n", | |
| "- Longer-context understanding (>128k tokens)\n", | |
| "- Better termination prediction (halting problem)\n", | |
| "- Integration with formal verification\n", | |
| "- Multi-modal code understanding (code + documentation + UI)\n", | |
| "\n", | |
| "### Using This Notebook\n", | |
| "\n", | |
| "This notebook provides:\n", | |
| "- ✓ Working implementations of core concepts\n", | |
| "- ✓ Educational demonstrations within resource constraints\n", | |
| "- ✓ Clear guidance for scaling to production\n", | |
| "- ✓ Understanding of CWM's methodology\n", | |
| "\n", | |
| "To build on this work:\n", | |
| "1. Start with the data generation workflows\n", | |
| "2. Build executable repository infrastructure\n", | |
| "3. Collect training data at scale\n", | |
| "4. Train models incrementally (small → large)\n", | |
| "5. Evaluate thoroughly on benchmarks\n", | |
| "\n", | |
| "### References\n", | |
| "\n", | |
| "**Paper:** \"CWM: An Open-Weights LLM for Research on Code Generation with World Models\"\n", | |
| "\n", | |
| "**Key Sections:**\n", | |
| "- Section 2.2: Python Execution Traces\n", | |
| "- Section 2.3: ForagerAgent Trajectories\n", | |
| "- Section 4: Training Pipeline\n", | |
| "- Section 5: RL Training\n", | |
| "- Section 7: Evaluation\n", | |
| "\n", | |
| "For the complete methodology and results, please refer to the original paper." | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.8.0" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |