CLIP-style finetuning of Jina V2 on a Finnish captions dataset

import os
import torch
import accelerate
import requests
import numpy as np
from PIL import Image as PILImage
from tqdm.notebook import tqdm
from datasets import load_dataset
from google.colab import userdata
from huggingface_hub import login, create_repo
from IPython.display import display
from transformers import (
    AutoModel,
    AutoProcessor,
    TrainingArguments,
    Trainer
)
from peft import LoraConfig, get_peft_model

# --- Auth ---
HF_USER = "mrbesher"
HF_TOKEN = userdata.get("hf_write")
login(token=HF_TOKEN)

# --- Config ---
MODEL_CKPT = "jinaai/jina-clip-v2"
FINNISH_DATASET = "mrbesher/flickr-fi"
ORIGINAL_DATASET = "nlphuji/flickr30k"
OUTPUT_DIR = "./jina-clip-v2-flickr-fi-finetuned"
MODEL_ID = f"{HF_USER}/jina-clip-v2-fi-flickr-lora"

# --- Load Model & Processor ---
print("Loading model and processor...")
processor = AutoProcessor.from_pretrained(MODEL_CKPT, trust_remote_code=True)
model = AutoModel.from_pretrained(
    MODEL_CKPT,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    use_text_flash_attn=False,
    trust_remote_code=True
)

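# --- (Optional) LoRA adapter ---
# LoraConfig/get_peft_model are imported above and MODEL_ID is named "-lora",
# but the script as written updates all model weights. Below is a minimal,
# disabled-by-default sketch of wrapping the model with a LoRA adapter; the
# target_modules names are assumptions, not verified against jina-clip-v2's
# custom towers -- check model.named_modules() before enabling.
USE_LORA = False
if USE_LORA:
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        # Assumed attention projection names (text tower / vision tower); verify first.
        target_modules=["Wqkv", "qkv"],
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
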
# --- Load Data ---
print("Loading datasets...")
fi_data = load_dataset(FINNISH_DATASET)
orig_test_data = load_dataset(ORIGINAL_DATASET, split="test")
print(f"Finnish dataset keys: {fi_data['train'][0].keys()}")
print(f"Original dataset keys: {orig_test_data[0].keys()}")

# Map img_id to index for lookup (column access skips decoding every image)
img_index_map = {img_id: idx for idx, img_id in enumerate(orig_test_data["img_id"])}
print(f"Image index map size: {len(img_index_map)}")

def attach_images(example):
    try:
        idx = img_index_map.get(example["img_id"])
        if idx is None:
            return example
        orig_item = orig_test_data[idx]
        image_data = orig_item["image"]
        if isinstance(image_data, str):
            image_data = PILImage.open(image_data).convert("RGB")
        elif not isinstance(image_data, PILImage.Image):
            raise ValueError(f"Unsupported image type: {type(image_data)}")
        example["image"] = image_data
        if isinstance(example["finnish_caption"], list):
            example["finnish_caption"] = example["finnish_caption"][0]
    except Exception as e:
        print(f"Error with img_id {example.get('img_id')}: {e}")
        example["image"] = None
    return example

print("Attaching images to Finnish dataset...")
fi_data_images = fi_data["train"].map(attach_images, num_proc=4)
fi_data_images = fi_data_images.filter(lambda x: x["image"] is not None)
print(f"Final dataset size: {len(fi_data_images)}")

# Train/Val split
split_data = fi_data_images.train_test_split(test_size=0.1)
train_data, val_data = split_data["train"], split_data["test"]

# --- Collate Function ---
def collate_batch(examples):
    texts = [e["finnish_caption"] for e in examples]
    images = [e["image"] for e in examples]
    inputs = processor(text=texts, images=images, return_tensors="pt",
                       padding=True, truncation=True, max_length=64)
    inputs = inputs.to(next(model.parameters()).device)
    # Defensive cast: the model was loaded in bfloat16, so make sure pixel
    # tensors match its dtype (a no-op if the remote code already handles this).
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(next(model.parameters()).dtype)
    inputs["return_loss"] = True
    return inputs

# --- Eval ---
def eval_text_to_image(model, proc, queries, image_set, top_k=5, batch_size=48):
    model.eval()
    device = next(model.parameters()).device

    # Query embeddings
    query_embeds = []
    with torch.no_grad():
        for i in range(0, len(queries), batch_size):
            batch = queries[i:i+batch_size]
            inputs = proc(text=batch, return_tensors="pt", padding=True, truncation=True).to(device)
            query_embeds.append(model.get_text_features(**inputs).cpu())
    query_embeds = torch.cat(query_embeds, dim=0)

    # Image embeddings
    img_embeds, img_captions = [], []

    def image_collate(batch):
        imgs, caps = [], []
        for b in batch:
            if b["image"] is None:
                continue
            imgs.append(b["image"])
            if b.get("caption"):
                cap = b["caption"][0] if isinstance(b["caption"], list) else b["caption"]
            else:
                cap = f"Image {b.get('img_id')}"
            caps.append(cap)
        if not imgs:
            return None, None
        return proc(images=imgs, return_tensors="pt").pixel_values, caps

    loader = torch.utils.data.DataLoader(image_set, batch_size=batch_size, collate_fn=image_collate)
    with torch.no_grad():
        for pixels, caps in tqdm(loader, desc="Embedding Images"):
            if pixels is None:
                continue
            pixels = pixels.to(device)
            img_embeds.append(model.get_image_features(pixel_values=pixels).cpu())
            img_captions.extend(caps)

    if not img_embeds:
        print("No images embedded.")
        return
    img_embeds = torch.cat(img_embeds, dim=0)

    # Cosine similarity between L2-normalized query and image embeddings
    query_embeds = query_embeds.to(device) / query_embeds.norm(dim=-1, keepdim=True)
    img_embeds = img_embeds.to(device) / img_embeds.norm(dim=-1, keepdim=True)
    sims = query_embeds @ img_embeds.T
    scores, indices = torch.topk(sims, k=top_k, dim=1)

    for i, q in enumerate(queries):
        print(f"\nQuery: {q}")
        for j in range(top_k):
            idx = indices[i][j].item()
            score = scores[i][j].item()
            print(f" {j+1}. {score:.4f} - {img_captions[idx]}")
            display(image_set[idx]["image"].resize((150, 150)))

    model.train()

# --- Prepare Eval Data ---
eval_count = 750
if len(orig_test_data) < eval_count:
    eval_images = orig_test_data
else:
    eval_images = orig_test_data.select(range(len(orig_test_data) - eval_count, len(orig_test_data)))

queries = [
    "Mies hyppää ilmaan rannalla",
    "Nainen istuu kahvilassa ja lukee kirjaa",
    "Koira juoksee nurmikolla pallon kanssa",
    "Lapset leikkivät lumessa talvipäivänä",
    "Aurinko laskee vuorten taakse järven yllä",
    "A man jumps in the air at the beach",
    "A woman sits in a cafe and reads a book",
    "A dog runs on the grass with a ball",
    "Children play in the snow on a winter day",
    "The sun sets behind the mountains over the lake"
]

def prep_eval_images(dataset):
    out = []
    for i, item in enumerate(dataset):
        if item.get("image"):
            caps = item["caption"] if isinstance(item["caption"], list) else [item["caption"]]
            out.append({"image": item["image"], "caption": caps, "img_id": item.get("img_id", i)})
    return out

eval_images = prep_eval_images(eval_images)

# --- Eval Before Training ---
print("\n--- Eval Before Training ---")
eval_text_to_image(model, processor, queries, eval_images, top_k=5)

# --- Create Repo ---
create_repo(MODEL_ID, private=False, token=HF_TOKEN, exist_ok=True)

# --- Training ---
train_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    logging_steps=20,
    save_strategy="epoch",
    eval_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    report_to="tensorboard",
    remove_unused_columns=False,
    dataloader_pin_memory=False,
    push_to_hub=True,
    hub_model_id=MODEL_ID,
    hub_token=HF_TOKEN,
    weight_decay=0.01,
    max_grad_norm=1.0,
    save_total_limit=2,
    optim="adamw_torch",
)

trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=collate_batch,
)

print("Training...")
trainer.train()
trainer.push_to_hub()

# --- Eval After Training ---
print("\n--- Eval After Training ---")
eval_text_to_image(trainer.model, processor, queries, eval_images, top_k=5)
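
# --- (Optional) Reload check ---
# A minimal sketch, not part of the original run: reload the pushed checkpoint
# from the Hub and embed one Finnish query. It assumes trainer.push_to_hub()
# succeeded, MODEL_ID now exists on the Hub, and the uploaded config still
# resolves the jina-clip remote code.
reloaded = AutoModel.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
query_inputs = processor(text=["Koira juoksee rannalla"], return_tensors="pt",
                         padding=True, truncation=True).to(next(reloaded.parameters()).device)
with torch.no_grad():
    query_embedding = reloaded.get_text_features(**query_inputs)
print(f"Reloaded model text embedding shape: {tuple(query_embedding.shape)}")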