fine-tune-gemma-3-270m-for-sentiment-analysis.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/lmassaron/884c39c2b063882fd33af4046212f476/fine-tune-gemma-3-270m-for-sentiment-analysis.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"id": "8b3d2e26",
"metadata": {
"papermill": {
"duration": 0.020687,
"end_time": "2025-09-11T10:55:58.198718",
"exception": false,
"start_time": "2025-09-11T10:55:58.178031",
"status": "completed"
},
"tags": [],
"id": "8b3d2e26"
},
"source": [
"## Fine-tune Gemma 3 270M-it for Sentiment Analysis\n",
"\n",
"This notebook provides a hands-on tutorial for fine-tuning the Gemma 3 270M-it model for sentiment analysis on financial and economic information. Analyzing sentiment in this domain is crucial for businesses to gain market insights, manage risks, and inform investment decisions.\n",
"\n",
"To demonstrate the fine-tuning process, we use the FinancialPhraseBank dataset. It is particularly valuable because annotated datasets of financial and economic text are rare, and many of those that exist are proprietary. To address this shortage of training data, researchers from the Aalto University School of Business released in 2014 a set of approximately 5,000 sentences, intended as a human-annotated benchmark for comparing alternative modeling techniques. The 16 annotators, all with adequate background knowledge of financial markets, were instructed to judge each sentence solely from an investor's perspective: would the news have a positive, negative, or neutral impact on the stock price?\n",
"\n",
"The FinancialPhraseBank dataset captures the sentiment of sentences from financial news, judged from the viewpoint of a retail investor. Each record pairs a sentence with a sentiment label (negative, neutral, or positive); in the Hugging Face version used in this notebook, these are the `sentence` and `sentiment` columns. The dataset was introduced by Malo, P., Sinha, A., Korhonen, P., Wallenius, J., and Takala, P. in \"Good debt or bad debt: Detecting semantic orientations in economic texts\", published in the Journal of the Association for Information Science and Technology in 2014, and has been used in various studies since then."
]
},
{
"cell_type": "markdown",
"id": "24525812",
"metadata": {
"papermill": {
"duration": 0.018415,
"end_time": "2025-09-11T10:55:58.236940",
"exception": false,
"start_time": "2025-09-11T10:55:58.218525",
"status": "completed"
},
"tags": [],
"id": "24525812"
},
"source": [
"### 1. Setup Environment\n",
"\n",
"First, we install the required libraries.\n",
"* accelerate: A library by Hugging Face for efficient PyTorch training on any hardware configuration, including multi-GPU setups.\n",
"* peft: A library for Parameter-Efficient Fine-Tuning. It enables us to adapt pre-trained models by fine-tuning only a small fraction of their parameters, thereby significantly reducing computational costs.\n",
"* trl: A library by Hugging Face for training transformer language models with techniques like Supervised Fine-tuning (SFT), which we will use here.\n",
"* flash-attn: An optional library that provides a highly optimized attention mechanism, which can speed up training and reduce memory usage on compatible GPUs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95fc11e6",
"metadata": {
"id": "95fc11e6"
},
"outputs": [],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "73085e50",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:55:58.274934Z",
"iopub.status.busy": "2025-09-11T10:55:58.274647Z",
"iopub.status.idle": "2025-09-11T10:57:29.890851Z",
"shell.execute_reply": "2025-09-11T10:57:29.889862Z"
},
"papermill": {
"duration": 91.636492,
"end_time": "2025-09-11T10:57:29.892294",
"exception": false,
"start_time": "2025-09-11T10:55:58.255802",
"status": "completed"
},
"tags": [],
"id": "73085e50"
},
"outputs": [],
"source": [
"# Install PyTorch & other libraries\n",
"%pip -q install torch tensorboard\n",
"\n",
"# Install Hugging Face libraries (peft is required for the LoRA configuration used below)\n",
"%pip -q install transformers datasets accelerate evaluate trl peft protobuf sentencepiece"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ffbb5e56",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:57:29.968713Z",
"iopub.status.busy": "2025-09-11T10:57:29.968082Z",
"iopub.status.idle": "2025-09-11T10:57:33.566041Z",
"shell.execute_reply": "2025-09-11T10:57:33.565216Z"
},
"papermill": {
"duration": 3.63737,
"end_time": "2025-09-11T10:57:33.567211",
"exception": false,
"start_time": "2025-09-11T10:57:29.929841",
"status": "completed"
},
"tags": [],
"id": "ffbb5e56"
},
"outputs": [],
"source": [
"import subprocess\n",
"import sys\n",
"import torch\n",
"\n",
"def install_flash_attn_conditionally():\n",
"    \"\"\"\n",
"    Checks the GPU's compute capability and installs the appropriate version of flash-attn.\n",
"    Returns True only if a compatible version was installed successfully.\n",
"    \"\"\"\n",
"    if not torch.cuda.is_available():\n",
"        print(\"No CUDA-enabled GPU found. Skipping flash-attn installation.\")\n",
"        return False\n",
"\n",
"    try:\n",
"        # Get the compute capability of the first available GPU\n",
"        major, minor = torch.cuda.get_device_capability(0)\n",
"        gpu_name = torch.cuda.get_device_name(0)\n",
"        print(f\"Found GPU: {gpu_name} with Compute Capability: {major}.{minor}\")\n",
"\n",
"        # Ampere, Ada, Hopper, or newer architectures (compute capability >= 8.0): FlashAttention 2\n",
"        if (major, minor) >= (8, 0):\n",
"            if torch.cuda.is_bf16_supported():\n",
"                print(\"GPU supports BF16 and is compatible with FlashAttention 2.\")\n",
"                print(\"Proceeding with installation of the latest 'flash-attn'...\")\n",
"                # The flash-attn README recommends installing with --no-build-isolation\n",
"                return install_package(\"flash-attn\", \"--no-build-isolation\", \"-q\")\n",
"            print(\"GPU architecture is compatible, but BF16 is not supported. Skipping installation.\")\n",
"            return False\n",
"        # Turing architecture (compute capability 7.5): original FlashAttention only\n",
"        elif (major, minor) == (7, 5):\n",
"            print(\"Turing architecture GPU detected. Compatible with original FlashAttention (v1.x).\")\n",
"            print(\"Proceeding with installation of 'flash-attn<2'...\")\n",
"            # Pin to v1.x, since FlashAttention 2 requires Ampere or newer GPUs\n",
"            return install_package(\"flash-attn<2\", \"--no-build-isolation\", \"-q\")\n",
"        else:\n",
"            print(f\"GPU with compute capability {major}.{minor} is not supported by flash-attn. Skipping installation.\")\n",
"            return False\n",
"    except Exception as e:\n",
"        print(f\"An error occurred during GPU check or installation: {e}\")\n",
"        return False\n",
"\n",
"def install_package(package_name, *pip_args):\n",
"    \"\"\"\n",
"    A helper function to install a pip package using subprocess.\n",
"    Returns True on success and False otherwise.\n",
"    \"\"\"\n",
"    try:\n",
"        command = [sys.executable, \"-m\", \"pip\", \"install\", package_name]\n",
"        command.extend(pip_args)\n",
"        subprocess.check_call(command)\n",
"        print(f\"Successfully installed {package_name}.\")\n",
"        return True\n",
"    except subprocess.CalledProcessError as e:\n",
"        print(f\"Error installing {package_name}: {e}\")\n",
"    except Exception as e:\n",
"        print(f\"An unexpected error occurred: {e}\")\n",
"    return False\n",
"\n",
"is_flash_attn_available = install_flash_attn_conditionally()"
]
},
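{
"cell_type": "markdown",
"id": "sketch-01-md",
"metadata": {},
"source": [
"The cell above decides whether to install flash-attn from the GPU's compute capability. The sketch below restates that decision table as a pure function, `sketch_flash_attn_choice`, a hypothetical helper that is not part of the original notebook or of any library. It compares `(major, minor)` tuples instead of converting them to floats, needs no GPU, and installs nothing; the Turing entry follows the cell's own message that Turing GPUs only support FlashAttention 1.x."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-01-code",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper: maps a CUDA compute capability to a flash-attn decision.\n",
"# Pure Python, no GPU required.\n",
"def sketch_flash_attn_choice(capability, bf16_supported):\n",
"    major_minor = tuple(capability)\n",
"    if major_minor >= (8, 0):\n",
"        # Ampere (8.0, 8.6), Ada (8.9), Hopper (9.0) or newer\n",
"        return 'flash-attn (v2)' if bf16_supported else 'skip'\n",
"    if major_minor == (7, 5):\n",
"        # Turing, e.g. the T4: FlashAttention 1.x only\n",
"        return 'flash-attn<2'\n",
"    return 'skip'\n",
"\n",
"for cap, bf16 in [((7, 5), False), ((8, 0), True), ((8, 9), True), ((6, 0), False)]:\n",
"    print(cap, '->', sketch_flash_attn_choice(cap, bf16))"
]
},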
{
"cell_type": "markdown",
"id": "6ae4e32d",
"metadata": {
"papermill": {
"duration": 0.036411,
"end_time": "2025-09-11T10:57:33.696417",
"exception": false,
"start_time": "2025-09-11T10:57:33.660006",
"status": "completed"
},
"tags": [],
"id": "6ae4e32d"
},
"source": [
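"We set environment variables to specify the GPU and manage tokenizer parallelism.\n",
"* `CUDA_VISIBLE_DEVICES`: the ID of the GPU to expose; \"0\" selects the first available GPU. This variable is only read when CUDA is initialized, and the flash-attn check above already queries the GPU through PyTorch, so it may need to be set before that cell runs to take effect.\n",
"* `TOKENIZERS_PARALLELISM`: set to \"false\" to disable the internal parallelism of the Hugging Face tokenizers, which can otherwise emit warnings or deadlock when the process is forked."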
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1c0ef72",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:57:33.771143Z",
"iopub.status.busy": "2025-09-11T10:57:33.770790Z",
"iopub.status.idle": "2025-09-11T10:57:33.774810Z",
"shell.execute_reply": "2025-09-11T10:57:33.774082Z"
},
"papermill": {
"duration": 0.042876,
"end_time": "2025-09-11T10:57:33.776043",
"exception": false,
"start_time": "2025-09-11T10:57:33.733167",
"status": "completed"
},
"tags": [],
"id": "f1c0ef72"
},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\""
]
},
{
"cell_type": "markdown",
"id": "69ccc08a",
"metadata": {
"papermill": {
"duration": 0.037768,
"end_time": "2025-09-11T10:57:33.850944",
"exception": false,
"start_time": "2025-09-11T10:57:33.813176",
"status": "completed"
},
"tags": [],
"id": "69ccc08a"
},
"source": [
"During training, Hugging Face libraries can produce numerous non-critical warnings. We'll suppress these to keep the output clean."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "be03cb3a",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:57:33.927378Z",
"iopub.status.busy": "2025-09-11T10:57:33.926578Z",
"iopub.status.idle": "2025-09-11T10:57:33.930468Z",
"shell.execute_reply": "2025-09-11T10:57:33.929744Z"
},
"papermill": {
"duration": 0.043753,
"end_time": "2025-09-11T10:57:33.931600",
"exception": false,
"start_time": "2025-09-11T10:57:33.887847",
"status": "completed"
},
"tags": [],
"id": "be03cb3a"
},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings(\"ignore\")"
]
},
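{
"cell_type": "markdown",
"id": "sketch-02-md",
"metadata": {},
"source": [
"`warnings.filterwarnings(\"ignore\")` silences every warning for the rest of the session, including potentially useful ones. If you would rather silence warnings only around a specific noisy call, the standard library's `warnings.catch_warnings()` context manager restores the previous filters when the block exits. A minimal sketch, where `noisy_function` is a hypothetical stand-in for such a call:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-02-code",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"def noisy_function():\n",
"    # Hypothetical stand-in for a library call that emits a non-critical warning\n",
"    warnings.warn('non-critical message', UserWarning)\n",
"    return 42\n",
"\n",
"# Warnings raised inside the block are ignored; the previous filters are restored on exit\n",
"with warnings.catch_warnings():\n",
"    warnings.simplefilter('ignore')\n",
"    demo_result = noisy_function()\n",
"\n",
"print(demo_result)  # 42"
]
},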
{
"cell_type": "markdown",
"id": "35aae894",
"metadata": {
"papermill": {
"duration": 0.037123,
"end_time": "2025-09-11T10:57:34.005930",
"exception": false,
"start_time": "2025-09-11T10:57:33.968807",
"status": "completed"
},
"tags": [],
"id": "35aae894"
},
"source": [
"### 2. Load Model and Tokenizer\n",
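"Now, we load the Gemma 3 270M-it model and its corresponding tokenizer. We pass `dtype=\"auto\"`, which keeps the dtype stored in the checkpoint's configuration (typically bfloat16 for Gemma 3) and uses half the memory of float32; bfloat16 is natively supported on Ampere and newer GPUs."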
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8869bc6b",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:57:34.081789Z",
"iopub.status.busy": "2025-09-11T10:57:34.081469Z",
"iopub.status.idle": "2025-09-11T10:58:03.317102Z",
"shell.execute_reply": "2025-09-11T10:58:03.316454Z"
},
"papermill": {
"duration": 29.274991,
"end_time": "2025-09-11T10:58:03.318479",
"exception": false,
"start_time": "2025-09-11T10:57:34.043488",
"status": "completed"
},
"tags": [],
"id": "8869bc6b"
},
"outputs": [],
"source": [
"# General imports\n",
"import os\n",
"import random\n",
"import numpy as np\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"\n",
"# Hugging Face imports\n",
"import transformers\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed\n",
"from datasets import Dataset\n",
"from peft import LoraConfig\n",
"from trl import SFTTrainer, SFTConfig\n",
"\n",
"# Scikit-learn for evaluation\n",
"from sklearn.metrics import accuracy_score, classification_report, confusion_matrix\n",
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "de50a41a",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:58:03.395187Z",
"iopub.status.busy": "2025-09-11T10:58:03.394889Z",
"iopub.status.idle": "2025-09-11T10:58:03.399173Z",
"shell.execute_reply": "2025-09-11T10:58:03.398388Z"
},
"papermill": {
"duration": 0.044014,
"end_time": "2025-09-11T10:58:03.400269",
"exception": false,
"start_time": "2025-09-11T10:58:03.356255",
"status": "completed"
},
"tags": [],
"id": "de50a41a"
},
"outputs": [],
"source": [
"print(f\"transformers=={transformers.__version__}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e645cb9e",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:58:03.476979Z",
"iopub.status.busy": "2025-09-11T10:58:03.476289Z",
"iopub.status.idle": "2025-09-11T10:58:03.485458Z",
"shell.execute_reply": "2025-09-11T10:58:03.484909Z"
},
"papermill": {
"duration": 0.048836,
"end_time": "2025-09-11T10:58:03.486636",
"exception": false,
"start_time": "2025-09-11T10:58:03.437800",
"status": "completed"
},
"tags": [],
"id": "e645cb9e"
},
"outputs": [],
"source": [
"def set_deterministic(seed):\n",
"    \"\"\"Sets all seeds and CUDA settings for deterministic results.\"\"\"\n",
"    random.seed(seed)\n",
"    np.random.seed(seed)\n",
"    torch.manual_seed(seed)\n",
"    torch.cuda.manual_seed_all(seed)  # seeds all GPUs (safe to call without CUDA)\n",
"    torch.backends.cudnn.deterministic = True\n",
"    torch.backends.cudnn.benchmark = False\n",
"    set_seed(seed)\n",
"\n",
"SEED = 0\n",
"set_deterministic(SEED)"
]
},
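{
"cell_type": "markdown",
"id": "sketch-03-md",
"metadata": {},
"source": [
"Seeding makes random draws repeatable. The quick check below confirms this for Python's `random` module and NumPy, using independent seeded generators (`random.Random` and `numpy.random.default_rng`) so that the global state set by `set_deterministic` is left untouched. It is a sketch covering only these two generators; some GPU kernels can still be nondeterministic even with fixed seeds."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-03-code",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import numpy as np\n",
"\n",
"def seeded_draws(seed):\n",
"    # Independent generators: the notebook's global random state is not modified\n",
"    py_rng = random.Random(seed)\n",
"    np_rng = np.random.default_rng(seed)\n",
"    return [py_rng.random() for _ in range(3)], np_rng.random(3).tolist()\n",
"\n",
"print(seeded_draws(0) == seeded_draws(0))  # True: same seed, same sequence\n",
"print(seeded_draws(0) == seeded_draws(1))  # False: different seed, different sequence"
]
},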
{
"cell_type": "code",
"execution_count": null,
"id": "33c0730b",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:58:03.571001Z",
"iopub.status.busy": "2025-09-11T10:58:03.570267Z",
"iopub.status.idle": "2025-09-11T10:58:11.556304Z",
"shell.execute_reply": "2025-09-11T10:58:11.555425Z"
},
"papermill": {
"duration": 8.029341,
"end_time": "2025-09-11T10:58:11.557454",
"exception": false,
"start_time": "2025-09-11T10:58:03.528113",
"status": "completed"
},
"tags": [],
"id": "33c0730b"
},
"outputs": [],
"source": [
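"# Hugging Face Hub ID of the checkpoint. Note that this is a derived checkpoint\n",
"# (gemma-3-270m-it-grpo-finsent); use \"google/gemma-3-270m-it\" for the original model.\n",
"GEMMA_PATH = \"lmassaron/gemma-3-270m-it-grpo-finsent\"\n",
"\n",
"# Attention implementation: this notebook uses the eager implementation.\n",
"# If flash-attn was installed on an Ampere or newer GPU, \"flash_attention_2\" can be used instead.\n",
"attn_implementation = \"eager\"\n",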
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    GEMMA_PATH,\n",
"    dtype=\"auto\",  # Use the dtype stored in the checkpoint's config\n",
"    device_map=\"auto\",\n",
"    attn_implementation=attn_implementation\n",
")\n",
"\n",
"max_seq_length = 2048\n",
"# model_max_length is the tokenizer argument that sets the maximum sequence length\n",
"tokenizer = AutoTokenizer.from_pretrained(GEMMA_PATH, model_max_length=max_seq_length)\n",
"\n",
"# Explicitly enable use_cache for faster inference\n",
"model.config.use_cache = True\n",
"\n",
"# We use the end-of-sequence token as the padding token.\n",
"# Padding on the left is a common practice for decoder-only models.\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"tokenizer.padding_side = \"left\"\n",
"model.config.pad_token_id = tokenizer.pad_token_id\n",
"model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
"model.config.bos_token_id = tokenizer.bos_token_id\n",
"model.generation_config.bos_token_id = tokenizer.bos_token_id\n",
"\n",
"# Store the End-Of-Sequence token for use in prompt formatting\n",
"EOS_TOKEN = tokenizer.eos_token\n",
"\n",
"print(f\"Device: {model.device}\")\n",
"print(f\"DType: {model.dtype}\")\n",
"print(f\"Attention Implementation: {attn_implementation}\")"
]
},
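{
"cell_type": "markdown",
"id": "sketch-04-md",
"metadata": {},
"source": [
"The tokenizer above is configured to pad on the left. In batched generation with a decoder-only model, left padding puts every prompt's last real token at the same final position, so newly generated tokens follow the prompt directly. The pure-Python sketch below illustrates the idea with a hypothetical `left_pad` helper and made-up token IDs; it is not the tokenizer's actual implementation. The attention mask marks padding positions with 0 so the model ignores them."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-04-code",
"metadata": {},
"outputs": [],
"source": [
"def left_pad(sequences, pad_id):\n",
"    # Pad each list of token IDs on the left to the length of the longest one\n",
"    max_len = max(len(seq) for seq in sequences)\n",
"    input_ids, attention_mask = [], []\n",
"    for seq in sequences:\n",
"        n_pad = max_len - len(seq)\n",
"        input_ids.append([pad_id] * n_pad + list(seq))\n",
"        # 0 marks padding to ignore, 1 marks real tokens\n",
"        attention_mask.append([0] * n_pad + [1] * len(seq))\n",
"    return input_ids, attention_mask\n",
"\n",
"demo_ids, demo_mask = left_pad([[5, 6, 7], [8, 9]], pad_id=1)\n",
"print(demo_ids)   # [[5, 6, 7], [1, 8, 9]]\n",
"print(demo_mask)  # [[1, 1, 1], [0, 1, 1]]"
]
},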
{
"cell_type": "markdown",
"id": "ec0ff91d",
"metadata": {
"papermill": {
"duration": 0.037256,
"end_time": "2025-09-11T10:58:11.631892",
"exception": false,
"start_time": "2025-09-11T10:58:11.594636",
"status": "completed"
},
"tags": [],
"id": "ec0ff91d"
},
"source": [
"### 3. Prepare the Dataset\n",
"\n",
"We perform the following steps to prepare our data for fine-tuning:\n",
"\n",
"1. Load Data: Read the training split of the FinancialPhraseBank dataset (`lmassaron/FinancialPhraseBank`) from the Hugging Face Hub as a Parquet file.\n",
"2. Create Splits:\n",
"   * Training Set: A balanced set of 200 examples for each sentiment (positive, neutral, negative).\n",
"   * Test Set: A balanced set of 200 examples for each sentiment, disjoint from the training set.\n",
"   * Evaluation Set: A smaller, balanced set of 50 examples per sentiment, sampled with replacement from the rows left over after the training and test splits. This is used for validation during training.\n",
"3. Format Prompts: We convert the raw text into structured prompts that guide the model to perform the sentiment analysis task. One prompt format is used for training (including the answer) and another for testing (without the answer).\n",
"4. Create Datasets: The prepared data is converted into Hugging Face Dataset objects, which are the standard format for the SFTTrainer."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d89cbd3a",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:58:11.706862Z",
"iopub.status.busy": "2025-09-11T10:58:11.706238Z",
"iopub.status.idle": "2025-09-11T10:58:11.974786Z",
"shell.execute_reply": "2025-09-11T10:58:11.974070Z"
},
"papermill": {
"duration": 0.307781,
"end_time": "2025-09-11T10:58:11.976319",
"exception": false,
"start_time": "2025-09-11T10:58:11.668538",
"status": "completed"
},
"tags": [],
"id": "d89cbd3a"
},
"outputs": [],
"source": [
"# Load the training split of the dataset from the Hugging Face Hub (Parquet format)\n",
"splits = {'train': 'data/train-00000-of-00001.parquet', 'validation': 'data/validation-00000-of-00001.parquet', 'test': 'data/test-00000-of-00001.parquet'}\n",
"df = pd.read_parquet(\"hf://datasets/lmassaron/FinancialPhraseBank/\" + splits[\"train\"])\n",
"df = df.rename({\"sentence\": \"text\"}, axis=1).loc[:,[\"sentiment\", \"text\"]]\n",
"\n",
"# Stratified sampling to create balanced train and test sets\n",
"X_train, X_test = [], []\n",
"for sentiment in [\"positive\", \"neutral\", \"negative\"]:\n",
"    train, test = train_test_split(df[df.sentiment==sentiment],\n",
"                                   train_size=200,\n",
"                                   test_size=200,\n",
"                                   random_state=42)\n",
"    X_train.append(train)\n",
"    X_test.append(test)\n",
"\n",
"# Concatenate and shuffle the training data\n",
"X_train = pd.concat(X_train).sample(frac=1, random_state=10)\n",
"X_test = pd.concat(X_test)\n",
"\n",
"# Create a balanced evaluation set from the rows used in neither the training nor the test set\n",
"# (X_train and X_test still carry the original df indices at this point)\n",
"used_idx = set(X_train.index) | set(X_test.index)\n",
"X_eval = df[~df.index.isin(used_idx)]\n",
"X_eval = (X_eval\n",
"          .groupby('sentiment', group_keys=False)\n",
"          .apply(lambda x: x.sample(n=50, random_state=10, replace=True)))\n",
"X_train = X_train.reset_index(drop=True)"
]
},
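{
"cell_type": "markdown",
"id": "sketch-05-md",
"metadata": {},
"source": [
"A useful sanity check for a split like the one above is that the training, test, and evaluation rows do not overlap and that each split is balanced. The sketch below repeats the same sampling pattern on a small synthetic DataFrame (the rows and sizes are made up for illustration) and asserts that the splits are disjoint."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-05-code",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Synthetic stand-in for the dataset: 30 made-up rows per sentiment\n",
"toy = pd.DataFrame({\n",
"    'sentiment': ['positive', 'neutral', 'negative'] * 30,\n",
"    'text': [f'sentence {i}' for i in range(90)],\n",
"})\n",
"\n",
"toy_train, toy_test = [], []\n",
"for toy_label in ['positive', 'neutral', 'negative']:\n",
"    tr, te = train_test_split(toy[toy.sentiment == toy_label],\n",
"                              train_size=10, test_size=10, random_state=42)\n",
"    toy_train.append(tr)\n",
"    toy_test.append(te)\n",
"toy_train = pd.concat(toy_train)\n",
"toy_test = pd.concat(toy_test)\n",
"\n",
"# Evaluation rows come only from rows used in neither split\n",
"toy_used = set(toy_train.index) | set(toy_test.index)\n",
"toy_eval = toy[~toy.index.isin(toy_used)]\n",
"\n",
"assert set(toy_train.index).isdisjoint(toy_test.index)\n",
"assert toy_used.isdisjoint(toy_eval.index)\n",
"print(toy_train.sentiment.value_counts().to_dict())  # 10 rows for each label\n",
"print(len(toy_eval))  # 30"
]
},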
{
"cell_type": "code",
"execution_count": null,
"id": "a75bb99b",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:58:12.053624Z",
"iopub.status.busy": "2025-09-11T10:58:12.052965Z",
"iopub.status.idle": "2025-09-11T10:58:12.057439Z",
"shell.execute_reply": "2025-09-11T10:58:12.056736Z"
},
"papermill": {
"duration": 0.044411,
"end_time": "2025-09-11T10:58:12.058693",
"exception": false,
"start_time": "2025-09-11T10:58:12.014282",
"status": "completed"
},
"tags": [],
"id": "a75bb99b"
},
"outputs": [],
"source": [
"# Prompt engineering for training and inference\n",
"\n",
"def create_training_prompt(data_point):\n",
"    \"\"\"Formats a data point for training, including the expected sentiment.\"\"\"\n",
"    return f\"\"\"\n",
"    Analyze the sentiment of the news headline enclosed in square brackets,\n",
"    determine if it is positive, neutral, or negative, and return the answer as\n",
"    the corresponding sentiment label \"positive\" or \"neutral\" or \"negative\"\n",
"\n",
"    [{data_point['text']}] = {data_point['sentiment']}\n",
"    \"\"\".strip() + EOS_TOKEN\n",
"\n",
"def create_test_prompt(data_point):\n",
"    \"\"\"Formats a data point for inference, leaving the sentiment for the model to generate.\"\"\"\n",
"    return f\"\"\"\n",
"    Analyze the sentiment of the news headline enclosed in square brackets,\n",
"    determine if it is positive, neutral, or negative, and return the answer as\n",
"    the corresponding sentiment label \"positive\" or \"neutral\" or \"negative\"\n",
"\n",
"    [{data_point['text']}] =\n",
"\n",
"    \"\"\".strip()"
]
},
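{
"cell_type": "markdown",
"id": "sketch-06-md",
"metadata": {},
"source": [
"The training and inference prompts should be identical up to the answer; otherwise the model sees a different prefix at test time than during training. The sketch below checks this property with simplified, hypothetical templates (`train_template`, `test_template`) and a placeholder `<eos>` string; neither is the notebook's exact prompt text nor Gemma's real EOS token."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-06-code",
"metadata": {},
"outputs": [],
"source": [
"# Simplified, hypothetical templates mirroring the structure of the prompt functions above\n",
"INSTRUCTION = 'Analyze the sentiment of the news headline enclosed in square brackets.'\n",
"EOS_PLACEHOLDER = '<eos>'  # placeholder, not the tokenizer's actual EOS token\n",
"\n",
"def train_template(text, sentiment):\n",
"    return f'{INSTRUCTION} [{text}] = {sentiment}' + EOS_PLACEHOLDER\n",
"\n",
"def test_template(text):\n",
"    return f'{INSTRUCTION} [{text}] ='\n",
"\n",
"headline = 'Company X reports higher quarterly sales'  # made-up example\n",
"demo_train_prompt = train_template(headline, 'positive')\n",
"demo_test_prompt = test_template(headline)\n",
"\n",
"# The inference prompt must be an exact prefix of the training prompt\n",
"print(demo_train_prompt.startswith(demo_test_prompt))  # True"
]
},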
{
"cell_type": "code",
"execution_count": null,
"id": "ebf63625",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:58:12.134984Z",
"iopub.status.busy": "2025-09-11T10:58:12.134386Z",
"iopub.status.idle": "2025-09-11T10:58:12.169769Z",
"shell.execute_reply": "2025-09-11T10:58:12.168998Z"
},
"papermill": {
"duration": 0.075233,
"end_time": "2025-09-11T10:58:12.171344",
"exception": false,
"start_time": "2025-09-11T10:58:12.096111",
"status": "completed"
},
"tags": [],
"id": "ebf63625"
},
"outputs": [],
"source": [
"# Apply prompt formatting\n",
"X_train[\"text\"] = X_train.apply(create_training_prompt, axis=1)\n",
"X_eval[\"text\"] = X_eval.apply(create_training_prompt, axis=1)\n",
"\n",
"# Store true labels for final evaluation and format test set for inference\n",
"y_true = X_test.sentiment\n",
"X_test = pd.DataFrame(X_test.apply(create_test_prompt, axis=1), columns=[\"text\"])\n",
"\n",
"# Convert pandas DataFrames to Hugging Face Dataset objects\n",
"train_data = Dataset.from_pandas(X_train)\n",
"eval_data = Dataset.from_pandas(X_eval)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "606fe49d",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:58:12.247683Z",
"iopub.status.busy": "2025-09-11T10:58:12.247414Z",
"iopub.status.idle": "2025-09-11T10:58:12.251666Z",
"shell.execute_reply": "2025-09-11T10:58:12.250756Z"
},
"papermill": {
"duration": 0.044318,
"end_time": "2025-09-11T10:58:12.252812",
"exception": false,
"start_time": "2025-09-11T10:58:12.208494",
"status": "completed"
},
"tags": [],
"id": "606fe49d"
},
"outputs": [],
"source": [
"print(f\"Training samples: {len(train_data)}\")\n",
"print(f\"Evaluation samples: {len(eval_data)}\")\n",
"print(f\"Test samples: {len(X_test)}\")"
]
},
{
"cell_type": "markdown",
"id": "c12678a6",
"metadata": {
"papermill": {
"duration": 0.037709,
"end_time": "2025-09-11T10:58:12.327428",
"exception": false,
"start_time": "2025-09-11T10:58:12.289719",
"status": "completed"
},
"tags": [],
"id": "c12678a6"
},
"source": [
"### 4. Define Evaluation Metrics\n",
"\n",
"We create a function to evaluate the model's predictions. This function will calculate and display:\n",
"* Overall accuracy.\n",
"* Accuracy per sentiment class.\n",
"* A detailed classification report with precision, recall, and F1-score.\n",
"* A confusion matrix to visualize where the model is making errors."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3deb45c",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:58:12.403615Z",
"iopub.status.busy": "2025-09-11T10:58:12.403344Z",
"iopub.status.idle": "2025-09-11T10:58:12.409857Z",
"shell.execute_reply": "2025-09-11T10:58:12.409097Z"
},
"papermill": {
"duration": 0.046253,
"end_time": "2025-09-11T10:58:12.411112",
"exception": false,
"start_time": "2025-09-11T10:58:12.364859",
"status": "completed"
},
"tags": [],
"id": "f3deb45c"
},
"outputs": [],
"source": [
"def evaluate(y_true, y_pred):\n",
"    \"\"\"Calculates and prints comprehensive evaluation metrics.\"\"\"\n",
"\n",
"    # Label names indexed by their numeric id (0 = negative, 1 = neutral, 2 = positive)\n",
"    labels = ['negative', 'neutral', 'positive']\n",
"    # Unparsable predictions ('none') and unknown strings are counted as neutral\n",
"    mapping = {'positive': 2, 'neutral': 1, 'none': 1, 'negative': 0}\n",
"    def map_func(x):\n",
"        return mapping.get(x, 1)\n",
"\n",
"    y_true = np.vectorize(map_func)(y_true)\n",
"    y_pred = np.vectorize(map_func)(y_pred)\n",
"\n",
"    # Calculate accuracy\n",
"    accuracy = accuracy_score(y_true=y_true, y_pred=y_pred)\n",
"    print(f'Accuracy: {accuracy:.3f}')\n",
"\n",
"    # Accuracy per true label (equivalent to per-class recall)\n",
"    for label in sorted(set(y_true)):\n",
"        label_indices = [i for i in range(len(y_true))\n",
"                         if y_true[i] == label]\n",
"        label_y_true = [y_true[i] for i in label_indices]\n",
"        label_y_pred = [y_pred[i] for i in label_indices]\n",
"        accuracy = accuracy_score(label_y_true, label_y_pred)\n",
"        print(f'Accuracy for label {labels[label]}: {accuracy:.3f}')\n",
"\n",
"    # Generate classification report\n",
"    class_report = classification_report(y_true=y_true, y_pred=y_pred,\n",
"                                         labels=[0, 1, 2], target_names=labels)\n",
"    print('\\nClassification Report:')\n",
"    print(class_report)\n",
"\n",
"    # Generate confusion matrix (rows: true labels, columns: predicted labels)\n",
"    conf_matrix = confusion_matrix(y_true=y_true, y_pred=y_pred, labels=[0, 1, 2])\n",
"    print('\\nConfusion Matrix:')\n",
"    print(conf_matrix)"
]
},
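{
"cell_type": "markdown",
"id": "sketch-07-md",
"metadata": {},
"source": [
"The per-class accuracy printed by `evaluate` is computed only over the examples whose true label is that class, so it equals that class's recall: the diagonal entry of the confusion matrix divided by the sum of its row. The worked example below uses made-up labels to show the equivalence with scikit-learn's `recall_score(..., average=None)`; here the three classes score 1/2, 2/3, and 2/3."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-07-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import confusion_matrix, recall_score\n",
"\n",
"# Made-up labels for illustration (0 = negative, 1 = neutral, 2 = positive)\n",
"toy_true = np.array([0, 0, 1, 1, 1, 2, 2, 2])\n",
"toy_pred = np.array([0, 1, 1, 1, 0, 2, 2, 1])\n",
"\n",
"cm = confusion_matrix(toy_true, toy_pred, labels=[0, 1, 2])\n",
"# Rows are true labels, columns are predictions: diagonal / row sum = per-class accuracy\n",
"per_class_acc = cm.diagonal() / cm.sum(axis=1)\n",
"\n",
"print(cm)\n",
"print(per_class_acc.round(3))\n",
"print(np.allclose(per_class_acc, recall_score(toy_true, toy_pred, average=None)))  # True"
]
},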
{
"cell_type": "markdown",
"id": "fc2ea81a",
"metadata": {
"papermill": {
"duration": 0.036749,
"end_time": "2025-09-11T10:58:12.485049",
"exception": false,
"start_time": "2025-09-11T10:58:12.448300",
"status": "completed"
},
"tags": [],
"id": "fc2ea81a"
},
"source": [
"### 5. Baseline Performance (Zero-Shot)\n",
"\n",
"Before fine-tuning, let's establish a baseline by evaluating the Gemma 3 270M-it model loaded above on our test set. This \"zero-shot\" performance shows how well the model handles the task before any LoRA fine-tuning.\n",
"\n",
"The following prediction function processes the test set in batches, which is significantly faster than predicting one example at a time. It tokenizes the prompts, moves them to the GPU, generates responses with greedy decoding, and maps each response to a sentiment label."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "298f7f0c",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:58:12.560610Z",
"iopub.status.busy": "2025-09-11T10:58:12.559941Z",
"iopub.status.idle": "2025-09-11T10:58:12.566016Z",
"shell.execute_reply": "2025-09-11T10:58:12.565407Z"
},
"papermill": {
"duration": 0.045201,
"end_time": "2025-09-11T10:58:12.567112",
"exception": false,
"start_time": "2025-09-11T10:58:12.521911",
"status": "completed"
},
"tags": [],
"id": "298f7f0c"
},
"outputs": [],
"source": [
"def predict(X_test, model, tokenizer):\n",
"    \"\"\"Performs batch inference on the test set.\"\"\"\n",
"\n",
"    y_pred = []\n",
"    # Convert DataFrame column to a list of prompts\n",
"    prompts = X_test[\"text\"].tolist()\n",
"\n",
"    # Set batch size depending on GPU memory\n",
"    batch_size = 8\n",
"\n",
"    for i in tqdm(range(0, len(prompts), batch_size)):\n",
"        batch = prompts[i:i+batch_size]\n",
"        inputs = tokenizer(batch,\n",
"                           return_tensors=\"pt\",\n",
"                           padding=True,\n",
"                           truncation=True,\n",
"                           max_length=max_seq_length).to(model.device)\n",
"\n",
"        outputs = model.generate(\n",
"            **inputs,\n",
"            # A few new tokens are enough for the model to emit a full label word\n",
"            max_new_tokens=10,\n",
"            do_sample=False,  # Greedy decoding for deterministic output\n",
"            pad_token_id=tokenizer.eos_token_id\n",
"        )\n",
"\n",
"        # Decode only the newly generated tokens: with left padding, every prompt\n",
"        # in the batch ends at the same position, input_ids.shape[1]\n",
"        new_tokens = outputs[:, inputs[\"input_ids\"].shape[1]:]\n",
"        decoded_outputs = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)\n",
"\n",
"        for output in decoded_outputs:\n",
"            answer = output.lower().strip()\n",
"\n",
"            if \"positive\" in answer:\n",
"                y_pred.append(\"positive\")\n",
"            elif \"negative\" in answer:\n",
"                y_pred.append(\"negative\")\n",
"            elif \"neutral\" in answer:\n",
"                y_pred.append(\"neutral\")\n",
"            else:\n",
"                # Fallback for unexpected or empty outputs\n",
"                y_pred.append(\"none\")\n",
"\n",
"    return y_pred"
]
},
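{
"cell_type": "markdown",
"id": "sketch-08-md",
"metadata": {},
"source": [
"The label parsing at the end of `predict` is order-sensitive: the first keyword found wins, and anything unrecognized becomes `none` (which `evaluate` counts as neutral). The standalone sketch below copies that keyword-matching rule into a hypothetical `parse_label` helper and runs it on a few made-up generations. The last sample shows a limitation of substring matching: a negated phrase such as `not positive` still maps to `positive`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-08-code",
"metadata": {},
"outputs": [],
"source": [
"def parse_label(generated):\n",
"    # Same keyword order as in predict(): positive, then negative, then neutral\n",
"    answer = generated.lower().strip()\n",
"    for candidate in ('positive', 'negative', 'neutral'):\n",
"        if candidate in answer:\n",
"            return candidate\n",
"    return 'none'  # fallback for empty or unexpected output\n",
"\n",
"samples = [' Positive', 'NEGATIVE ', 'neutral.', '', 'not positive']\n",
"print([parse_label(s) for s in samples])\n",
"# ['positive', 'negative', 'neutral', 'none', 'positive']"
]
},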
{
"cell_type": "code",
"execution_count": null,
"id": "9b7c4fce",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:58:12.644397Z",
"iopub.status.busy": "2025-09-11T10:58:12.643641Z",
"iopub.status.idle": "2025-09-11T10:59:08.610417Z",
"shell.execute_reply": "2025-09-11T10:59:08.609604Z"
},
"papermill": {
"duration": 56.006568,
"end_time": "2025-09-11T10:59:08.611642",
"exception": false,
"start_time": "2025-09-11T10:58:12.605074",
"status": "completed"
},
"tags": [],
"id": "9b7c4fce"
},
"outputs": [],
"source": [
"# Evaluate the base model\n",
"y_pred = predict(X_test, model, tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18231881",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:59:08.696717Z",
"iopub.status.busy": "2025-09-11T10:59:08.696421Z",
"iopub.status.idle": "2025-09-11T10:59:08.712166Z",
"shell.execute_reply": "2025-09-11T10:59:08.711203Z"
},
"papermill": {
"duration": 0.059492,
"end_time": "2025-09-11T10:59:08.713394",
"exception": false,
"start_time": "2025-09-11T10:59:08.653902",
"status": "completed"
},
"tags": [],
"id": "18231881"
},
"outputs": [],
"source": [
"evaluate(y_true, y_pred)"
]
},
{
"cell_type": "markdown",
"id": "acd626d9",
"metadata": {
"papermill": {
"duration": 0.039757,
"end_time": "2025-09-11T10:59:08.794558",
"exception": false,
"start_time": "2025-09-11T10:59:08.754801",
"status": "completed"
},
"tags": [],
"id": "acd626d9"
},
"source": [
"As expected, the base model's performance is poor. It often defaults to a single sentiment (like neutral) because it hasn't been specifically trained for this nuanced financial analysis task. This result highlights the need for fine-tuning."
]
},
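{
"cell_type": "markdown",
"id": "sketch-09-md",
"metadata": {},
"source": [
"For reference, a classifier that always outputs the same label scores exactly one third on a balanced three-class test set, so a zero-shot accuracy close to 0.333 is no better than constant guessing. The arithmetic, using the 200 test examples per class built by the data-preparation code:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-09-code",
"metadata": {},
"outputs": [],
"source": [
"# Balanced test set: 200 examples per class, as built in the data-preparation code\n",
"class_counts = {'negative': 200, 'neutral': 200, 'positive': 200}\n",
"n_test = sum(class_counts.values())\n",
"\n",
"# Always predicting 'neutral' is correct only on the neutral examples\n",
"constant_accuracy = class_counts['neutral'] / n_test\n",
"print(n_test, round(constant_accuracy, 3))  # 600 0.333"
]
},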
{
"cell_type": "markdown",
"id": "f2d2c5ca",
"metadata": {
"papermill": {
"duration": 0.039871,
"end_time": "2025-09-11T10:59:08.874079",
"exception": false,
"start_time": "2025-09-11T10:59:08.834208",
"status": "completed"
},
"tags": [],
"id": "f2d2c5ca"
},
"source": [
"### 6. Fine-Tuning with PEFT (LoRA)\n",
"\n",
"We will use the SFTTrainer from the TRL library to perform Supervised Fine-tuning. To make this process efficient, we'll use a PEFT method called LoRA (Low-Rank Adaptation).\n",
"\n",
"LoRA freezes the pre-trained model weights and injects trainable, low-rank matrices into selected linear layers; here, these are the attention projections (`q_proj`, `k_proj`, `v_proj`, `o_proj`) and the MLP projections (`gate_proj`, `up_proj`, `down_proj`). We only train these small matrices, drastically reducing the number of trainable parameters and memory requirements.\n",
"\n",
"Below, we define the configurations for LoRA and the trainer."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fccbb739",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:59:08.954970Z",
"iopub.status.busy": "2025-09-11T10:59:08.954276Z",
"iopub.status.idle": "2025-09-11T10:59:11.026809Z",
"shell.execute_reply": "2025-09-11T10:59:11.026167Z"
},
"papermill": {
"duration": 2.114338,
"end_time": "2025-09-11T10:59:11.028173",
"exception": false,
"start_time": "2025-09-11T10:59:08.913835",
"status": "completed"
},
"tags": [],
"id": "fccbb739"
},
"outputs": [],
"source": [
"# LoRA configuration\n",
"peft_config = LoraConfig(\n",
" lora_alpha=16,\n",
" lora_dropout=0,\n",
" r=64,\n",
" bias=\"none\",\n",
" task_type=\"CAUSAL_LM\",\n",
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
" \"gate_proj\", \"up_proj\", \"down_proj\",],\n",
")\n",
"\n",
"# SFT (Supervised Fine-tuning) configuration\n",
"training_arguments = SFTConfig(\n",
" output_dir=\"logs\",\n",
" seed=SEED,\n",
" num_train_epochs=5,\n",
" gradient_checkpointing=True,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=8,\n",
" optim=\"adamw_torch_fused\",\n",
" save_steps=0,\n",
" logging_steps=25,\n",
" learning_rate=2e-4,\n",
" weight_decay=0.001,\n",
" fp16=True,\n",
" bf16=False,\n",
" max_grad_norm=0.3,\n",
" max_steps=-1,\n",
" warmup_ratio=0.03,\n",
" group_by_length=False,\n",
" eval_strategy='steps',\n",
" eval_steps = 112,\n",
" eval_accumulation_steps=1,\n",
" lr_scheduler_type=\"cosine\",\n",
" dataset_text_field=\"text\",\n",
" packing=False,\n",
" max_length=max_seq_length,\n",
" report_to=\"tensorboard\",\n",
")\n",
"\n",
"# Initialize the trainer\n",
"trainer = SFTTrainer(\n",
" model=model,\n",
" train_dataset=train_data,\n",
" eval_dataset=eval_data,\n",
" peft_config=peft_config,\n",
" processing_class=tokenizer,\n",
" args=training_arguments,\n",
"\n",
")"
]
},
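{
"cell_type": "markdown",
"id": "sketch-10-md",
"metadata": {},
"source": [
"To make the parameter savings concrete: for a frozen weight matrix W of shape d x k, LoRA trains two smaller matrices, B (d x r) and A (r x k), and uses W + (lora_alpha / r) * B A in the forward pass, so it adds r * (d + k) trainable parameters instead of updating all d * k. The NumPy sketch below uses this notebook's r = 64 and lora_alpha = 16 with an illustrative 640 x 640 matrix (an assumed size, not read from the model). It checks that the adapted weight equals W at initialization, because PEFT's default initialization sets B to zeros, and that the update has rank at most r. With a rank this large relative to such narrow layers, the adapter is a sizeable fraction of each targeted matrix; smaller values of r shrink it proportionally."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-10-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"d, k = 640, 640            # illustrative layer shape (assumption, not read from the Gemma config)\n",
"r, lora_alpha = 64, 16     # values from the LoraConfig above\n",
"scaling = lora_alpha / r   # 0.25\n",
"\n",
"rng = np.random.default_rng(0)\n",
"W = rng.standard_normal((d, k))         # frozen pre-trained weight\n",
"A = rng.standard_normal((r, k)) * 0.01  # trainable, random initialization\n",
"B = np.zeros((d, r))                    # trainable, zero initialization\n",
"\n",
"# At initialization the adapted weight equals the frozen weight\n",
"assert np.allclose(W + scaling * (B @ A), W)\n",
"\n",
"# After training changes B (simulated here with random values), the update has rank <= r\n",
"B = rng.standard_normal((d, r))\n",
"delta = scaling * (B @ A)\n",
"print(np.linalg.matrix_rank(delta))  # 64\n",
"\n",
"full_params, lora_params = d * k, r * (d + k)\n",
"print(full_params, lora_params, round(lora_params / full_params, 3))  # 409600 81920 0.2"
]
},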
{
"cell_type": "markdown",
"id": "b5e20a16",
"metadata": {
"papermill": {
"duration": 0.042127,
"end_time": "2025-09-11T10:59:11.114329",
"exception": false,
"start_time": "2025-09-11T10:59:11.072202",
"status": "completed"
},
"tags": [],
"id": "b5e20a16"
},
"source": [
"### 7. Start Training\n",
"\n",
"We can now start the fine-tuning process. With a T4 GPU on Kaggle, this should take around 15-20 minutes. The training loss is logged every 25 steps (`logging_steps`) and the validation loss every 112 steps (`eval_steps`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0b6efc1",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T10:59:11.200596Z",
"iopub.status.busy": "2025-09-11T10:59:11.200076Z",
"iopub.status.idle": "2025-09-11T11:19:12.394339Z",
"shell.execute_reply": "2025-09-11T11:19:12.393707Z"
},
"papermill": {
"duration": 1201.239534,
"end_time": "2025-09-11T11:19:12.395840",
"exception": false,
"start_time": "2025-09-11T10:59:11.156306",
"status": "completed"
},
"tags": [],
"id": "c0b6efc1"
},
"outputs": [],
"source": [
"# Train model\n",
"trainer.train()\n",
"\n",
"# Save the fine-tuned LoRA adapter (adapter weights only, not the full model)\n",
"trainer.model.save_pretrained(\"trained-model\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5bab5912",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T11:19:12.483267Z",
"iopub.status.busy": "2025-09-11T11:19:12.483002Z",
"iopub.status.idle": "2025-09-11T11:19:12.775231Z",
"shell.execute_reply": "2025-09-11T11:19:12.774357Z"
},
"papermill": {
"duration": 0.337226,
"end_time": "2025-09-11T11:19:12.776378",
"exception": false,
"start_time": "2025-09-11T11:19:12.439152",
"status": "completed"
},
"tags": [],
"id": "5bab5912"
},
"outputs": [],
"source": [
"# Access the log history\n",
"log_history = trainer.state.log_history\n",
"\n",
"# Extract training / validation loss\n",
"train_losses = [log[\"loss\"] for log in log_history if \"loss\" in log]\n",
"epoch_train = [log[\"epoch\"] for log in log_history if \"loss\" in log]\n",
"eval_losses = [log[\"eval_loss\"] for log in log_history if \"eval_loss\" in log]\n",
"epoch_eval = [log[\"epoch\"] for log in log_history if \"eval_loss\" in log]\n",
"\n",
"# Plot the training and validation loss curves\n",
"plt.plot(epoch_train, train_losses, label=\"Training Loss\")\n",
"plt.plot(epoch_eval, eval_losses, label=\"Validation Loss\")\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.title(\"Training and Validation Loss per Epoch\")\n",
"plt.legend()\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "41c0420f",
"metadata": {
"papermill": {
"duration": 0.043489,
"end_time": "2025-09-11T11:19:12.863156",
"exception": false,
"start_time": "2025-09-11T11:19:12.819667",
"status": "completed"
},
"tags": [],
"id": "41c0420f"
},
"source": [
"Because `report_to` is set to `\"tensorboard\"`, the trainer also writes its metrics, such as the training and validation loss, under `logs/runs`. You can browse them interactively with TensorBoard."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f3b1b70",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T11:19:12.950904Z",
"iopub.status.busy": "2025-09-11T11:19:12.950311Z",
"iopub.status.idle": "2025-09-11T11:19:19.009282Z",
"shell.execute_reply": "2025-09-11T11:19:19.008562Z"
},
"papermill": {
"duration": 6.104437,
"end_time": "2025-09-11T11:19:19.010453",
"exception": false,
"start_time": "2025-09-11T11:19:12.906016",
"status": "completed"
},
"tags": [],
"id": "9f3b1b70"
},
"outputs": [],
"source": [
"%load_ext tensorboard\n",
"%tensorboard --logdir logs/runs"
]
},
{
"cell_type": "markdown",
"id": "55cf7fe2",
"metadata": {
"papermill": {
"duration": 0.043524,
"end_time": "2025-09-11T11:19:19.098247",
"exception": false,
"start_time": "2025-09-11T11:19:19.054723",
"status": "completed"
},
"tags": [],
"id": "55cf7fe2"
},
"source": [
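"### 8. Evaluate the Fine-Tuned Model\n",
"\n",
"After training, we evaluate the model on the test set with the same `predict` function used for the baseline. Note that `SFTTrainer` does not merge the LoRA weights into the base model: PEFT injects the adapter layers into `model` in place, so `model` now runs with the trained adapters (merging, if wanted, must be done explicitly with `merge_and_unload()`). Before predicting, we disable gradient checkpointing and re-enable the key-value cache for faster generation. We expect a clear improvement over the baseline."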
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1d30bf5",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T11:19:19.187355Z",
"iopub.status.busy": "2025-09-11T11:19:19.186599Z",
"iopub.status.idle": "2025-09-11T11:19:37.276048Z",
"shell.execute_reply": "2025-09-11T11:19:37.275170Z"
},
"papermill": {
"duration": 18.135627,
"end_time": "2025-09-11T11:19:37.277253",
"exception": false,
"start_time": "2025-09-11T11:19:19.141626",
"status": "completed"
},
"tags": [],
"id": "b1d30bf5"
},
"outputs": [],
"source": [
"# Set model configuration for inference\n",
"model.gradient_checkpointing_disable()\n",
"model.config.use_cache = True\n",
"\n",
"y_pred = predict(X_test, model, tokenizer)\n",
"evaluate(y_true, y_pred)"
]
},
{
"cell_type": "markdown",
"id": "f357966d",
"metadata": {
"papermill": {
"duration": 0.110528,
"end_time": "2025-09-11T11:19:37.435856",
"exception": false,
"start_time": "2025-09-11T11:19:37.325328",
"status": "completed"
},
"tags": [],
"id": "f357966d"
},
"source": [
"Compared with the baseline, the fine-tuned model should reach noticeably higher accuracy, precision, and recall. This illustrates how fine-tuning adapts a small general-purpose model to a specific domain and task."
]
},
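{
"cell_type": "markdown",
"id": "merge-lora-md",
"metadata": {},
"source": [
"Optionally, the trained adapter can be folded into the base weights to obtain a standalone checkpoint that loads without PEFT. The next cell is a minimal sketch that is not part of the original workflow: it assumes `trainer.model` is the PEFT-wrapped model created by `SFTTrainer`, and `merged-model` is a hypothetical output directory. Since PEFT injected the adapter layers into `model`, merging also modifies `model` in place."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "merge-lora-code",
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: merge the LoRA adapter into the base weights.\n",
"# merge_and_unload() adds the LoRA updates to the original weights,\n",
"# removes the adapter layers, and returns the underlying transformers model.\n",
"merged_model = trainer.model.merge_and_unload()\n",
"\n",
"# \"merged-model\" is a hypothetical directory name\n",
"merged_model.save_pretrained(\"merged-model\")\n",
"tokenizer.save_pretrained(\"merged-model\")"
]
},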
{
"cell_type": "markdown",
"id": "a7ece766",
"metadata": {
"papermill": {
"duration": 0.046844,
"end_time": "2025-09-11T11:19:37.530209",
"exception": false,
"start_time": "2025-09-11T11:19:37.483365",
"status": "completed"
},
"tags": [],
"id": "a7ece766"
},
"source": [
"### 9. Analyze Predictions\n",
"\n",
"Finally, let's save a CSV file containing the original text, the true labels, and the model's predictions. This supports error analysis: examining the cases where the model fails can suggest further improvements, such as refining the prompt or adding more diverse training data."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51898fc1",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-11T11:19:37.628017Z",
"iopub.status.busy": "2025-09-11T11:19:37.627291Z",
"iopub.status.idle": "2025-09-11T11:19:37.661665Z",
"shell.execute_reply": "2025-09-11T11:19:37.661021Z"
},
"papermill": {
"duration": 0.083908,
"end_time": "2025-09-11T11:19:37.662889",
"exception": false,
"start_time": "2025-09-11T11:19:37.578981",
"status": "completed"
},
"tags": [],
"id": "51898fc1"
},
"outputs": [],
"source": [
"evaluation_df = pd.DataFrame({\n",
" \"text\": X_test[\"text\"],\n",
" \"y_true\": y_true,\n",
" \"y_pred\": y_pred,\n",
"})\n",
"evaluation_df.to_csv(\"test_predictions.csv\", index=False)\n",
"\n",
"print(\"Predictions saved to test_predictions.csv\")\n",
"evaluation_df.head()"
]
}
],
"metadata": {
"kaggle": {
"accelerator": "gpu",
"dataSources": [
{
"datasetId": 622510,
"sourceId": 1192499,
"sourceType": "datasetVersion"
},
{
"isSourceIdPinned": true,
"modelId": 222398,
"modelInstanceId": 410134,
"sourceId": 521642,
"sourceType": "modelInstanceVersion"
}
],
"isGpuEnabled": true,
"isInternetEnabled": true,
"language": "python",
"sourceType": "notebook"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
},
"papermill": {
"default_parameters": {},
"duration": 1426.543154,
"end_time": "2025-09-11T11:19:40.533795",
"environment_variables": {},
"exception": null,
"input_path": "__notebook__.ipynb",
"output_path": "__notebook__.ipynb",
"parameters": {},
"start_time": "2025-09-11T10:55:53.990641",
"version": "2.4.0"
},
"colab": {
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}