@rouseguy
Last active December 13, 2025 12:11
GRPO training script on Gemma3-1B-IT
# GRPO Training Script for Gemma3-1B-Instruct
# Run in notebook cells
# %% Cell 1: Installs
# !pip install trl transformers datasets accelerate peft bitsandbytes
# !pip install scikit-learn sentencepiece
# %% Cell 2: Imports
import re
import torch
import numpy as np
from typing import Optional
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
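# %% (optional) Print library versions
# Illustrative sanity check: TRL/PEFT APIs change between releases, so recording the
# versions in the notebook output helps when debugging.
import trl, transformers, peft, datasets as hf_datasets
print(trl.__version__, transformers.__version__, peft.__version__, hf_datasets.__version__)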
# %% Cell 3: Configuration
MODEL_ID = "google/gemma-3-1b-it" # Base SFT model
OUTPUT_DIR = "./grpo_gemma3_output"
# Reward weights for combining multiple rewards
REWARD_WEIGHTS = {
    "reasoning_length": 0.3,
    "answer_match": 0.5,
    "format_check": 0.2,
}
# Target reasoning length (in characters, as measured by len()) - adjust based on your task
TARGET_REASONING_LENGTH = 500
MAX_REASONING_LENGTH = 1500
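# %% (optional) Illustrate how the reward weights combine
# Toy per-signal rewards only: 0.8 (length), 1.0 (answer), 1.0 (format).
toy_rewards = {"reasoning_length": 0.8, "answer_match": 1.0, "format_check": 1.0}
print(sum(REWARD_WEIGHTS[k] * toy_rewards[k] for k in REWARD_WEIGHTS))  # 0.3*0.8 + 0.5 + 0.2 ~= 0.94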
# %% Cell 4: Helper Functions
def extract_tag_content(text: str, tag_name: str) -> Optional[str]:
    """Extract content between XML-style tags."""
    pattern = rf"<{tag_name}>(.*?)</{tag_name}>"
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else None

def has_tag(text: str, tag_name: str) -> tuple[bool, bool]:
    """Check for presence of opening and closing tags."""
    has_open = bool(re.search(rf"<{tag_name}>", text, re.IGNORECASE))
    has_close = bool(re.search(rf"</{tag_name}>", text, re.IGNORECASE))
    return has_open, has_close

def is_numeric(value: str) -> bool:
    """Check if string can be parsed as a number."""
    try:
        float(value.replace(",", "").strip())
        return True
    except (ValueError, AttributeError):
        return False

def parse_numeric(value: str) -> float:
    """Parse string to float, handling commas."""
    return float(value.replace(",", "").strip())

def compute_text_similarity(text1: str, text2: str) -> float:
    """Compute cosine similarity between two texts using TF-IDF."""
    if not text1 or not text2:
        return 0.0
    try:
        vectorizer = TfidfVectorizer()
        tfidf_matrix = vectorizer.fit_transform([text1.lower(), text2.lower()])
        similarity = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0]
        return float(similarity)
    except Exception:
        # Fallback: exact match check
        return 1.0 if text1.lower().strip() == text2.lower().strip() else 0.0
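
# %% (optional) Quick sanity checks for the helper functions
# Illustrative only - toy strings, not taken from any dataset.
sample = "<reasoning>Add the two distances.</reasoning><answer>1,200</answer>"
print(extract_tag_content(sample, "answer"))        # "1,200"
print(has_tag(sample, "reasoning"))                 # (True, True)
print(is_numeric("1,200"), parse_numeric("1,200"))  # True 1200.0
print(round(compute_text_similarity("blue car", "a blue car"), 2))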
# %% Cell 5: Reward Functions
def reasoning_length_reward(completions: list[str], **kwargs) -> list[float]:
    """
    Reward based on reasoning length.
    Encourages detailed reasoning up to the target length, penalizes very short or excessive reasoning.
    """
    rewards = []
    for completion in completions:
        reasoning = extract_tag_content(completion, "reasoning")
        if reasoning is None:
            rewards.append(-1.0)  # No reasoning tag found
            continue
        length = len(reasoning)
        if length < 50:
            # Too short - penalize
            reward = -0.5
        elif length <= TARGET_REASONING_LENGTH:
            # Reward scales with length up to target
            reward = length / TARGET_REASONING_LENGTH
        elif length <= MAX_REASONING_LENGTH:
            # Slightly reduce reward for going over target
            reward = 1.0 - 0.3 * (length - TARGET_REASONING_LENGTH) / (MAX_REASONING_LENGTH - TARGET_REASONING_LENGTH)
        else:
            # Penalize excessive length
            reward = 0.5 - 0.5 * min(1.0, (length - MAX_REASONING_LENGTH) / MAX_REASONING_LENGTH)
        rewards.append(reward)
    return rewards
def answer_match_reward(
    completions: list[str],
    ground_truth: list[str],
    **kwargs
) -> list[float]:
    """
    Reward based on answer correctness.
    - Numeric: ratio-based (ratio close to 1 = better, capped reward)
    - Text: cosine similarity
    """
    rewards = []
    for completion, gt in zip(completions, ground_truth):
        predicted_answer = extract_tag_content(completion, "answer")
        if predicted_answer is None:
            rewards.append(-1.0)
            continue
        pred_clean = predicted_answer.strip()
        gt_clean = gt.strip()
        # Both numeric: use ratio
        if is_numeric(pred_clean) and is_numeric(gt_clean):
            pred_val = parse_numeric(pred_clean)
            gt_val = parse_numeric(gt_clean)
            if gt_val == 0:
                # Handle division by zero
                reward = 1.0 if pred_val == 0 else -1.0
            else:
                ratio = abs(pred_val / gt_val)
                # Perfect match = 1.0, ratio of 2 or 0.5 = 0.0, worse = negative
                if 0.99 <= ratio <= 1.01:
                    reward = 1.0  # Near-exact match
                elif ratio > 1:
                    reward = max(-1.0, 2.0 - ratio)  # Overestimate penalty
                else:
                    reward = max(-1.0, ratio * 2 - 1)  # Underestimate penalty
        else:
            # Text comparison: cosine similarity
            reward = compute_text_similarity(pred_clean, gt_clean)
            # Scale from [0,1] to [-0.5, 1.0] to penalize low similarity
            reward = reward * 1.5 - 0.5
        rewards.append(reward)
    return rewards
def format_check_reward(completions: list[str], **kwargs) -> list[float]:
    """
    Reward for proper XML tag structure.
    Checks for both opening and closing tags for 'reasoning' and 'answer'.
    """
    rewards = []
    for completion in completions:
        score = 0.0
        # Check reasoning tags (0.5 total)
        reasoning_open, reasoning_close = has_tag(completion, "reasoning")
        if reasoning_open:
            score += 0.25
        if reasoning_close:
            score += 0.25
        # Check answer tags (0.5 total)
        answer_open, answer_close = has_tag(completion, "answer")
        if answer_open:
            score += 0.25
        if answer_close:
            score += 0.25
        # Bonus for proper nesting (all tags present)
        if all([reasoning_open, reasoning_close, answer_open, answer_close]):
            score += 0.2
        # Normalize to [-1, 1] range
        reward = score * 2 - 1  # 0->-1, 0.5->0, 1.0->1, 1.2->1.4 (capped to 1)
        rewards.append(min(1.0, reward))
    return rewards
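
# %% (optional) Sanity-check the three reward functions on toy completions
# Illustrative only - toy strings, not model outputs.
good = (
    "<reasoning>The train covers 120 miles in 2 hours, so its speed is 120 / 2 = 60 mph."
    "</reasoning><answer>60</answer>"
)
bad = "The answer is 60."  # no tags at all
print("length:", [round(r, 2) for r in reasoning_length_reward([good, bad])])
print("answer:", [round(r, 2) for r in answer_match_reward([good, bad], ground_truth=["60", "60"])])
print("format:", [round(r, 2) for r in format_check_reward([good, bad])])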
# %% Cell 6: Combined Reward Function
def combined_reward_function(
    completions: list[str],
    ground_truth: list[str],
    **kwargs
) -> list[float]:
    """Combine all reward signals with weights."""
    r_length = reasoning_length_reward(completions)
    r_answer = answer_match_reward(completions, ground_truth=ground_truth)
    r_format = format_check_reward(completions)
    combined = []
    for i in range(len(completions)):
        reward = (
            REWARD_WEIGHTS["reasoning_length"] * r_length[i] +
            REWARD_WEIGHTS["answer_match"] * r_answer[i] +
            REWARD_WEIGHTS["format_check"] * r_format[i]
        )
        combined.append(reward)
    return combined
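
# %% (optional) Combined reward on a toy completion
# Note: REWARD_WEIGHTS is only used by combined_reward_function. Cell 10 instead passes
# the three reward functions to GRPOTrainer as a list, in which case TRL combines them
# itself (summed with equal weight by default; newer TRL versions also accept a
# reward_weights field in GRPOConfig).
toy = "<reasoning>120 miles over 2 hours gives 120 / 2 = 60 mph on average.</reasoning><answer>60</answer>"
print(round(combined_reward_function([toy], ground_truth=["60"])[0], 2))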
# %% Cell 7: Dataset Preparation
def prepare_dataset(dataset_name: str = "gsm8k", split: str = "train"):
    """
    Load and prepare dataset. Adjust based on your actual dataset.
    Expected format: each example should have 'question' and 'answer' fields.
    """
    # Example with GSM8K - modify for your dataset
    dataset = load_dataset(dataset_name, "main", split=split)

    def format_example(example):
        # Format prompt for the model
        prompt = f"""Solve the following problem. Show your reasoning in <reasoning></reasoning> tags, then provide the final answer in <answer></answer> tags.

Problem: {example['question']}"""
        # Extract ground truth answer (GSM8K specific - adjust for your data)
        gt_answer = example['answer'].split("####")[-1].strip() if "####" in example['answer'] else example['answer']
        return {
            "prompt": prompt,
            "ground_truth": gt_answer,
        }

    dataset = dataset.map(format_example, remove_columns=dataset.column_names)
    return dataset
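
# %% (optional) Preview one formatted training example
# Commented out because it downloads GSM8K; uncomment to inspect the fields.
# preview = prepare_dataset(split="train[:4]")
# print(preview[0]["prompt"])
# print("ground_truth:", preview[0]["ground_truth"])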
# %% Cell 8: Model and Tokenizer Setup
def setup_model_and_tokenizer():
    """Initialize model with quantization and LoRA for efficient training."""
    # 4-bit quantization config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",  # Use "flash_attention_2" if available
    )
    return model, tokenizer
# %% Cell 9: Training Configuration
def get_training_config():
    """GRPO training configuration."""
    # LoRA config for PEFT
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )
    # GRPO specific config
    training_config = GRPOConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=3,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=1e-5,
        max_completion_length=1024,  # Max tokens for generated completion
        num_generations=4,  # Number of completions per prompt for GRPO
        logging_steps=10,
        save_steps=100,
        bf16=True,
        gradient_checkpointing=True,
        optim="adamw_8bit",
        warmup_ratio=0.1,
        report_to="none",  # Set to "wandb" if using Weights & Biases
    )
    return training_config, peft_config
# %% Cell 10: Main Training Loop
def train():
    """Main training function."""
    print("Loading dataset...")
    dataset = prepare_dataset()
    print(f"Dataset size: {len(dataset)}")

    print("Setting up model and tokenizer...")
    model, tokenizer = setup_model_and_tokenizer()

    print("Configuring training...")
    training_config, peft_config = get_training_config()

    print("Initializing GRPO Trainer...")
    trainer = GRPOTrainer(
        model=model,
        args=training_config,
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
        reward_funcs=[
            reasoning_length_reward,
            answer_match_reward,
            format_check_reward,
        ],
        # Alternatively, use the combined function:
        # reward_funcs=combined_reward_function,
    )

    print("Starting training...")
    trainer.train()

    print("Saving model...")
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print(f"Training complete! Model saved to {OUTPUT_DIR}")
    return trainer
# %% Cell 11: Run Training
if __name__ == "__main__":
    trainer = train()
# %% Cell 12: Quick Inference Test (Optional)
def test_inference(prompt: str, model_path: str = OUTPUT_DIR):
    """Test the trained model."""
    from peft import PeftModel

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model = PeftModel.from_pretrained(base_model, model_path)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.7,
            do_sample=True,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
# Example usage:
# test_prompt = """Solve the following problem. Show your reasoning in <reasoning></reasoning> tags, then provide the final answer in <answer></answer> tags.
#
# Problem: If a train travels 120 miles in 2 hours, what is its average speed?"""
# print(test_inference(test_prompt))