GRPO training script on Gemma3-1B-IT
# GRPO Training Script for Gemma3-1B-Instruct
# Run in notebook cells

# %% Cell 1: Installs
# !pip install trl transformers datasets accelerate peft bitsandbytes
# !pip install scikit-learn sentencepiece

# %% Cell 2: Imports
import re
import torch
import numpy as np
from typing import Optional
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

# %% Cell 3: Configuration
MODEL_ID = "google/gemma-3-1b-it"  # Base SFT model
OUTPUT_DIR = "./grpo_gemma3_output"

# Reward weights for combining multiple rewards
REWARD_WEIGHTS = {
    "reasoning_length": 0.3,
    "answer_match": 0.5,
    "format_check": 0.2,
}

# Target reasoning length in characters (rewards use len() on the extracted text) - adjust based on your task
TARGET_REASONING_LENGTH = 500
MAX_REASONING_LENGTH = 1500

# %% Cell 4: Helper Functions
def extract_tag_content(text: str, tag_name: str) -> Optional[str]:
    """Extract content between XML-style tags."""
    pattern = rf"<{tag_name}>(.*?)</{tag_name}>"
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else None

def has_tag(text: str, tag_name: str) -> tuple[bool, bool]:
    """Check for presence of opening and closing tags."""
    has_open = bool(re.search(rf"<{tag_name}>", text, re.IGNORECASE))
    has_close = bool(re.search(rf"</{tag_name}>", text, re.IGNORECASE))
    return has_open, has_close

def is_numeric(value: str) -> bool:
    """Check if string can be parsed as a number."""
    try:
        float(value.replace(",", "").strip())
        return True
    except (ValueError, AttributeError):
        return False

def parse_numeric(value: str) -> float:
    """Parse string to float, handling commas."""
    return float(value.replace(",", "").strip())

def compute_text_similarity(text1: str, text2: str) -> float:
    """Compute cosine similarity between two texts using TF-IDF."""
    if not text1 or not text2:
        return 0.0
    try:
        vectorizer = TfidfVectorizer()
        tfidf_matrix = vectorizer.fit_transform([text1.lower(), text2.lower()])
        similarity = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0]
        return float(similarity)
    except Exception:
        # Fallback: exact match check
        return 1.0 if text1.lower().strip() == text2.lower().strip() else 0.0
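
# %% Optional: quick sanity check of the helpers on a made-up completion string.
# Purely illustrative; the expected values in the comments assume the exact
# demo string below and are not part of the training data.
if __name__ == "__main__":
    _demo = "<reasoning>120 miles / 2 hours = 60 mph</reasoning><answer>60</answer>"
    print(extract_tag_content(_demo, "answer"))              # "60"
    print(has_tag(_demo, "reasoning"))                       # (True, True)
    print(is_numeric("1,234.5"), parse_numeric("1,234.5"))   # True 1234.5
    print(compute_text_similarity("average speed", "the average speed"))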

# %% Cell 5: Reward Functions
def reasoning_length_reward(completions: list[str], **kwargs) -> list[float]:
    """
    Reward based on reasoning length.
    Encourages detailed reasoning up to target length, penalizes very short or excessive.
    """
    rewards = []
    for completion in completions:
        reasoning = extract_tag_content(completion, "reasoning")
        if reasoning is None:
            rewards.append(-1.0)  # No reasoning tag found
            continue
        length = len(reasoning)
        if length < 50:
            # Too short - penalize
            reward = -0.5
        elif length <= TARGET_REASONING_LENGTH:
            # Reward scales with length up to target
            reward = length / TARGET_REASONING_LENGTH
        elif length <= MAX_REASONING_LENGTH:
            # Slightly reduce reward for going over target
            reward = 1.0 - 0.3 * (length - TARGET_REASONING_LENGTH) / (MAX_REASONING_LENGTH - TARGET_REASONING_LENGTH)
        else:
            # Penalize excessive length
            reward = 0.5 - 0.5 * min(1.0, (length - MAX_REASONING_LENGTH) / MAX_REASONING_LENGTH)
        rewards.append(reward)
    return rewards

def answer_match_reward(
    completions: list[str],
    ground_truth: list[str],
    **kwargs
) -> list[float]:
    """
    Reward based on answer correctness.
    - Numeric: ratio-based (predicted/ground-truth ratio closer to 1 is better; reward clipped to [-1, 1])
    - Text: cosine similarity
    """
    rewards = []
    for completion, gt in zip(completions, ground_truth):
        predicted_answer = extract_tag_content(completion, "answer")
        if predicted_answer is None:
            rewards.append(-1.0)
            continue
        pred_clean = predicted_answer.strip()
        gt_clean = gt.strip()
        # Both numeric: use ratio
        if is_numeric(pred_clean) and is_numeric(gt_clean):
            pred_val = parse_numeric(pred_clean)
            gt_val = parse_numeric(gt_clean)
            if gt_val == 0:
                # Handle division by zero
                reward = 1.0 if pred_val == 0 else -1.0
            else:
                ratio = abs(pred_val / gt_val)
                # Perfect match = 1.0, ratio of 2 or 0.5 = 0.0, worse = negative
                if 0.99 <= ratio <= 1.01:
                    reward = 1.0  # Near-exact match
                elif ratio > 1:
                    reward = max(-1.0, 2.0 - ratio)  # Overestimate penalty
                else:
                    reward = max(-1.0, ratio * 2 - 1)  # Underestimate penalty
        else:
            # Text comparison: cosine similarity
            reward = compute_text_similarity(pred_clean, gt_clean)
            # Scale from [0,1] to [-0.5, 1.0] to penalize low similarity
            reward = reward * 1.5 - 0.5
        rewards.append(reward)
    return rewards
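
# Worked examples of the numeric branch above (values invented for illustration):
#   predicted 90 vs ground truth 60 -> ratio 1.5 -> reward = 2.0 - 1.5 = 0.5
#   predicted 30 vs ground truth 60 -> ratio 0.5 -> reward = 0.5 * 2 - 1 = 0.0
#   predicted 60 vs ground truth 60 -> ratio 1.0 -> reward = 1.0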

def format_check_reward(completions: list[str], **kwargs) -> list[float]:
    """
    Reward for proper XML tag structure.
    Checks for both opening and closing tags for 'reasoning' and 'answer'.
    """
    rewards = []
    for completion in completions:
        score = 0.0
        # Check reasoning tags (0.5 total)
        reasoning_open, reasoning_close = has_tag(completion, "reasoning")
        if reasoning_open:
            score += 0.25
        if reasoning_close:
            score += 0.25
        # Check answer tags (0.5 total)
        answer_open, answer_close = has_tag(completion, "answer")
        if answer_open:
            score += 0.25
        if answer_close:
            score += 0.25
        # Bonus for proper nesting (all tags present)
        if all([reasoning_open, reasoning_close, answer_open, answer_close]):
            score += 0.2
        # Normalize to [-1, 1] range
        reward = score * 2 - 1  # 0 -> -1, 0.5 -> 0, 1.0 -> 1, 1.2 -> 1.4 (capped to 1)
        rewards.append(min(1.0, reward))
    return rewards

# %% Cell 6: Combined Reward Function
def combined_reward_function(
    completions: list[str],
    ground_truth: list[str],
    **kwargs
) -> list[float]:
    """Combine all reward signals with weights."""
    r_length = reasoning_length_reward(completions)
    r_answer = answer_match_reward(completions, ground_truth=ground_truth)
    r_format = format_check_reward(completions)
    combined = []
    for i in range(len(completions)):
        reward = (
            REWARD_WEIGHTS["reasoning_length"] * r_length[i] +
            REWARD_WEIGHTS["answer_match"] * r_answer[i] +
            REWARD_WEIGHTS["format_check"] * r_format[i]
        )
        combined.append(reward)
    return combined
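
# %% Optional: smoke-test the reward functions on toy completions before training.
# The completions and ground truths below are invented purely for illustration;
# each reward function returns a list of floats aligned with its inputs.
if __name__ == "__main__":
    _completions = [
        "<reasoning>" + "The train covers 120 miles in 2 hours, so 120 / 2 = 60 mph. " * 3
        + "</reasoning><answer>60</answer>",
        "The answer is 45.",  # deliberately missing both tags
    ]
    _ground_truth = ["60", "60"]
    print("length  :", reasoning_length_reward(_completions))
    print("answer  :", answer_match_reward(_completions, ground_truth=_ground_truth))
    print("format  :", format_check_reward(_completions))
    print("combined:", combined_reward_function(_completions, ground_truth=_ground_truth))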

# %% Cell 7: Dataset Preparation
def prepare_dataset(dataset_name: str = "gsm8k", split: str = "train"):
    """
    Load and prepare dataset. Adjust based on your actual dataset.
    Expected format: each example should have 'question' and 'answer' fields.
    """
    # Example with GSM8K - modify for your dataset
    dataset = load_dataset(dataset_name, "main", split=split)

    def format_example(example):
        # Format prompt for the model
        prompt = f"""Solve the following problem. Show your reasoning in <reasoning></reasoning> tags, then provide the final answer in <answer></answer> tags.
Problem: {example['question']}"""
        # Extract ground truth answer (GSM8K specific - adjust for your data)
        gt_answer = example['answer'].split("####")[-1].strip() if "####" in example['answer'] else example['answer']
        return {
            "prompt": prompt,
            "ground_truth": gt_answer,
        }

    dataset = dataset.map(format_example, remove_columns=dataset.column_names)
    return dataset
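
# %% Optional: inspect one prepared example (downloads GSM8K on first run).
# Kept commented out so a straight run of this script does not trigger the
# download; the "train[:8]" slice is only there to keep the check fast.
# _ds = prepare_dataset(split="train[:8]")
# print(_ds[0]["prompt"][:200])
# print("ground_truth:", _ds[0]["ground_truth"])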

# %% Cell 8: Model and Tokenizer Setup
def setup_model_and_tokenizer():
    """Initialize model with quantization and LoRA for efficient training."""
    # 4-bit quantization config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",  # Use "flash_attention_2" if available
    )
    return model, tokenizer
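
# %% Optional: rough size check of the quantized model. Kept commented out so a
# straight run of this script does not load the model twice before training;
# num_parameters() and get_memory_footprint() are standard transformers methods,
# and the numbers will vary with your setup.
# _model, _tokenizer = setup_model_and_tokenizer()
# print(f"Parameters: {_model.num_parameters() / 1e6:.0f}M")
# print(f"Memory footprint: {_model.get_memory_footprint() / 1e9:.2f} GB")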

# %% Cell 9: Training Configuration
def get_training_config():
    """GRPO training configuration."""
    # LoRA config for PEFT
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )
    # GRPO specific config
    training_config = GRPOConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=3,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=1e-5,
        max_completion_length=1024,  # Max tokens for generated completion
        num_generations=4,  # Number of completions per prompt for GRPO
        logging_steps=10,
        save_steps=100,
        bf16=True,
        gradient_checkpointing=True,
        optim="adamw_8bit",
        warmup_ratio=0.1,
        report_to="none",  # Set to "wandb" if using weights & biases
    )
    return training_config, peft_config
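
# %% Optional: sanity-check the GRPO batch constraint before launching training.
# Assumption: recent TRL versions require the global batch size
# (per_device_train_batch_size x number of processes) to be divisible by
# num_generations. Treat this as advisory for your TRL version and bump
# per_device_train_batch_size (or lower num_generations) above if it warns.
if __name__ == "__main__":
    _cfg, _peft = get_training_config()
    _num_processes = 1  # adjust if launching with accelerate on multiple GPUs
    if (_cfg.per_device_train_batch_size * _num_processes) % _cfg.num_generations != 0:
        print("Warning: global batch size is not divisible by num_generations; "
              "GRPOTrainer may reject this configuration.")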

# %% Cell 10: Main Training Loop
def train():
    """Main training function."""
    print("Loading dataset...")
    dataset = prepare_dataset()
    print(f"Dataset size: {len(dataset)}")
    print("Setting up model and tokenizer...")
    model, tokenizer = setup_model_and_tokenizer()
    print("Configuring training...")
    training_config, peft_config = get_training_config()
    print("Initializing GRPO Trainer...")
    trainer = GRPOTrainer(
        model=model,
        args=training_config,
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
        reward_funcs=[
            reasoning_length_reward,
            answer_match_reward,
            format_check_reward,
        ],
        # Alternatively, use the combined function:
        # reward_funcs=combined_reward_function,
    )
    print("Starting training...")
    trainer.train()
    print("Saving model...")
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print(f"Training complete! Model saved to {OUTPUT_DIR}")
    return trainer

# %% Cell 11: Run Training
if __name__ == "__main__":
    trainer = train()

# %% Cell 12: Quick Inference Test (Optional)
def test_inference(prompt: str, model_path: str = OUTPUT_DIR):
    """Test the trained model."""
    from peft import PeftModel
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model = PeftModel.from_pretrained(base_model, model_path)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.7,
            do_sample=True,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response

# Example usage:
# test_prompt = """Solve the following problem. Show your reasoning in <reasoning></reasoning> tags, then provide the final answer in <answer></answer> tags.
#
# Problem: If a train travels 120 miles in 2 hours, what is its average speed?"""
# print(test_inference(test_prompt))