CLIP-style finetuning of jina-clip-v2 on a Finnish captions dataset
import os
import torch
import accelerate
import requests
import numpy as np
from PIL import Image as PILImage
from tqdm.notebook import tqdm
from datasets import load_dataset
from google.colab import userdata
from huggingface_hub import login, create_repo
from transformers import (
    AutoModel,
    AutoProcessor,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model
from IPython.display import display
# --- Auth ---
HF_USER = "mrbesher"
HF_TOKEN = userdata.get("hf_write")
login(token=HF_TOKEN)
# --- Config ---
MODEL_CKPT = "jinaai/jina-clip-v2"
FINNISH_DATASET = "mrbesher/flickr-fi"
ORIGINAL_DATASET = "nlphuji/flickr30k"
OUTPUT_DIR = "./jina-clip-v2-flickr-fi-finetuned"
MODEL_ID = f"{HF_USER}/jina-clip-v2-fi-flickr-lora"
# --- Load Model & Processor ---
print("Loading model and processor...")
processor = AutoProcessor.from_pretrained(MODEL_CKPT, trust_remote_code=True)
model = AutoModel.from_pretrained(
    MODEL_CKPT,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    use_text_flash_attn=False,
    trust_remote_code=True,
)
# --- Load Data ---
print("Loading datasets...")
fi_data = load_dataset(FINNISH_DATASET)
orig_test_data = load_dataset(ORIGINAL_DATASET, split="test")
print(f"Finnish dataset keys: {fi_data['train'][0].keys()}")
print(f"Original dataset keys: {orig_test_data[0].keys()}")
# Map img_id to index for lookup
img_index_map = {img_id: idx for idx, img_id in enumerate(orig_test_data["img_id"])}  # column access avoids decoding images
print(f"Image index map size: {len(img_index_map)}")
def attach_images(example):
    """Attach the matching Flickr30k image to a Finnish-caption example."""
    try:
        idx = img_index_map.get(example["img_id"])
        if idx is None:
            example["image"] = None
            return example
        orig_item = orig_test_data[idx]
        image_data = orig_item["image"]
        if isinstance(image_data, str):
            image_data = PILImage.open(image_data).convert("RGB")
        elif not isinstance(image_data, PILImage.Image):
            raise ValueError(f"Unsupported image type: {type(image_data)}")
        example["image"] = image_data
        # Keep a single caption per example
        if isinstance(example["finnish_caption"], list):
            example["finnish_caption"] = example["finnish_caption"][0]
    except Exception as e:
        print(f"Error with img_id {example.get('img_id')}: {e}")
        example["image"] = None
    return example
print("Attaching images to Finnish dataset...")
fi_data_images = fi_data["train"].map(attach_images, num_proc=4)
fi_data_images = fi_data_images.filter(lambda x: x["image"] is not None)
print(f"Final dataset size: {len(fi_data_images)}")
# Train/Val split
split_data = fi_data_images.train_test_split(test_size=0.1)
train_data, val_data = split_data["train"], split_data["test"]
# --- Collate Function ---
def collate_batch(examples):
    texts = [e["finnish_caption"] for e in examples]
    images = [e["image"] for e in examples]
    inputs = processor(text=texts, images=images, return_tensors="pt",
                       padding=True, truncation=True, max_length=64)
    inputs = inputs.to(next(model.parameters()).device)
    inputs["return_loss"] = True  # have the model return the contrastive loss
    return inputs
# --- Eval ---
def eval_text_to_image(model, proc, queries, image_set, top_k=5, batch_size=48):
    model.eval()
    device = next(model.parameters()).device
    # Query embeddings
    query_embeds = []
    with torch.no_grad():
        for i in range(0, len(queries), batch_size):
            batch = queries[i:i + batch_size]
            inputs = proc(text=batch, return_tensors="pt", padding=True, truncation=True).to(device)
            query_embeds.append(model.get_text_features(**inputs).cpu())
    query_embeds = torch.cat(query_embeds, dim=0)
    # Image embeddings
    img_embeds, img_captions = [], []

    def image_collate(batch):
        imgs = [b["image"] for b in batch if b["image"] is not None]
        caps = [
            (b["caption"][0] if isinstance(b["caption"], list) else b["caption"])
            if b.get("caption") else f"Image {b.get('img_id')}"
            for b in batch if b["image"] is not None
        ]
        if not imgs:
            return None, None
        return proc(images=imgs, return_tensors="pt").pixel_values, caps

    loader = torch.utils.data.DataLoader(image_set, batch_size=batch_size, collate_fn=image_collate)
    with torch.no_grad():
        for pixels, caps in tqdm(loader, desc="Embedding Images"):
            if pixels is None:
                continue
            pixels = pixels.to(device)
            img_embeds.append(model.get_image_features(pixel_values=pixels).cpu())
            img_captions.extend(caps)
    if not img_embeds:
        print("No images embedded.")
        return
    img_embeds = torch.cat(img_embeds, dim=0)
    # Cosine similarity between normalized query and image embeddings
    query_embeds = query_embeds.to(device) / query_embeds.norm(dim=-1, keepdim=True)
    img_embeds = img_embeds.to(device) / img_embeds.norm(dim=-1, keepdim=True)
    sims = query_embeds @ img_embeds.T
    scores, indices = torch.topk(sims, k=top_k, dim=1)
    for i, q in enumerate(queries):
        print(f"\nQuery: {q}")
        for j in range(top_k):
            idx = indices[i][j].item()
            score = scores[i][j].item()
            print(f" {j+1}. {score:.4f} - {img_captions[idx]}")
            display(image_set[idx]["image"].resize((150, 150)))
    model.train()
# --- Prepare Eval Data ---
eval_count = 750
if len(orig_test_data) < eval_count:
    eval_images = orig_test_data
else:
    eval_images = orig_test_data.select(range(len(orig_test_data) - eval_count, len(orig_test_data)))
queries = [
    "Mies hyppää ilmaan rannalla",
    "Nainen istuu kahvilassa ja lukee kirjaa",
    "Koira juoksee nurmikolla pallon kanssa",
    "Lapset leikkivät lumessa talvipäivänä",
    "Aurinko laskee vuorten taakse järven yllä",
    "A man jumps in the air at the beach",
    "A woman sits in a cafe and reads a book",
    "A dog runs on the grass with a ball",
    "Children play in the snow on a winter day",
    "The sun sets behind the mountains over the lake",
]
def prep_eval_images(dataset):
    out = []
    for i, item in enumerate(dataset):
        if item.get("image"):
            caps = item["caption"] if isinstance(item["caption"], list) else [item["caption"]]
            out.append({"image": item["image"], "caption": caps, "img_id": item.get("img_id", i)})
    return out
eval_images = prep_eval_images(eval_images)
# --- Eval Before Training ---
print("\n--- Eval Before Training ---")
eval_text_to_image(model, processor, queries, eval_images, top_k=5)
# --- Create Repo ---
create_repo(MODEL_ID, private=False, token=HF_TOKEN, exist_ok=True)
# --- Training ---
train_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    logging_steps=20,
    save_strategy="epoch",
    eval_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    report_to="tensorboard",
    remove_unused_columns=False,
    dataloader_pin_memory=False,
    push_to_hub=True,
    hub_model_id=MODEL_ID,
    hub_token=HF_TOKEN,
    weight_decay=0.01,
    max_grad_norm=1.0,
    save_total_limit=2,
    optim="adamw_torch",
)
trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=collate_batch,
)
print("Training...")
trainer.train()
trainer.push_to_hub()
# --- Eval After Training ---
print("\n--- Eval After Training ---")
eval_text_to_image(trainer.model, processor, queries, eval_images, top_k=5)
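# --- Reloading the pushed model (sketch) ---
# A minimal sketch of reusing the pushed checkpoint for inference. It assumes
# trainer.push_to_hub() uploaded the full finetuned weights under MODEL_ID;
# the processor is reloaded from the base checkpoint, and the query string is
# only an illustrative example.
reloaded_processor = AutoProcessor.from_pretrained(MODEL_CKPT, trust_remote_code=True)
reloaded_model = AutoModel.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
with torch.no_grad():
    text_inputs = reloaded_processor(
        text=["Koira juoksee nurmikolla pallon kanssa"],
        return_tensors="pt", padding=True, truncation=True,
    ).to(next(reloaded_model.parameters()).device)
    text_embedding = reloaded_model.get_text_features(**text_inputs)
print(f"Reloaded text embedding shape: {tuple(text_embedding.shape)}")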