Skip to content

Instantly share code, notes, and snippets.

@Abe404
Created February 14, 2026 08:05
Show Gist options
  • Select an option

  • Save Abe404/17c3b702308a787b1f7622b8b418bd64 to your computer and use it in GitHub Desktop.

Select an option

Save Abe404/17c3b702308a787b1f7622b8b418bd64 to your computer and use it in GitHub Desktop.
PyTorch binary image classification with early stopping (ResNet-18)
# train.py
import os
import glob
import random
import csv
import json
import hashlib
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils import (
get_label_from_filename,
compute_confusion_stats,
prf1_from_confusion,
subset,
SeedDataset,
get_transforms,
make_model,
run_epoch,
)
# -------------------------
# Config
# -------------------------
# Architecture key understood by utils.make_model ("resnet18" | "resnet50" | "vit_b_16").
MODEL_NAME = "resnet18"
DATA_DIR = "../../data/labeled_images_rgb" # folder with your PNGs
BATCH_SIZE = 126  # NOTE(review): 126 is unusual (128 is typical) — confirm intended
NUM_EPOCHS = 250  # hard upper bound; early stopping normally ends the run sooner
LR = 1e-6  # NOTE(review): very small for Adam fine-tuning — confirm intended
TRAIN_FRAC = 0.7
VAL_FRAC = 0.15 # rest is test (held out)
SEED = 42
PATIENCE = 15  # epochs without val-accuracy improvement before early stopping
OUT_DIR = "output" # keep outputs local to this experiment folder
CSV_PATH = os.path.join(OUT_DIR, "metrics_log.csv")  # per-epoch metrics log
SPLIT_PATH = os.path.join(OUT_DIR, "splits_seed42.json")  # persisted train/val/test indices
CFG_PATH = os.path.join(OUT_DIR, "run_config.json")  # config snapshot used for the cache check
MODELS_DIR = os.path.join(OUT_DIR, "models")
MPATH = os.path.join(MODELS_DIR, f"{MODEL_NAME}_rgb_best.pt")  # best-val-accuracy checkpoint
# -------------------------
# Reproducibility
# -------------------------
# Seed the Python and torch (CPU + all CUDA devices) RNGs and force
# deterministic cuDNN kernels (benchmark off) so repeated runs with the
# same SEED produce the same results.
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def make_run_config():
    """Build the dict of hyperparameters that identifies this run.

    train() compares this dict against the previously saved config file to
    decide whether training can be skipped (cache hit).
    """
    config = {}
    config["experiment"] = "rgb_full"
    config["wavelength"] = "rgb"
    config["MODEL_NAME"] = MODEL_NAME
    config["DATA_DIR"] = DATA_DIR
    config["lr"] = LR
    config["batch_size"] = BATCH_SIZE
    config["seed"] = SEED
    config["train_frac"] = TRAIN_FRAC
    config["val_frac"] = VAL_FRAC
    config["patience"] = PATIENCE
    config["num_epochs"] = NUM_EPOCHS
    return config
def train():
    """Train the binary seed-emergence classifier with early stopping.

    Side effects: creates MODELS_DIR (and thus OUT_DIR), writes the split
    indices (SPLIT_PATH), a per-epoch metrics CSV (CSV_PATH), the
    best-model checkpoint (MPATH), and — only after the run finishes —
    the run config (CFG_PATH).

    Returns a dict of output paths / run metadata (run_info).  If
    CFG_PATH, CSV_PATH and MPATH all exist and the saved config equals
    the current one, training is skipped and run_info is returned as-is.
    """
    os.makedirs(MODELS_DIR, exist_ok=True)  # also creates OUT_DIR (parent dir)
    cfg = make_run_config()
    # -------------------------
    # Collect image paths + labels
    # -------------------------
    all_paths = sorted(glob.glob(os.path.join(DATA_DIR, "*.png")))
    if not all_paths:
        raise RuntimeError(f"No PNG files found in {DATA_DIR}")
    # dataset signature (stable): hash of the sorted basenames, so the
    # cache check also invalidates when the set of input files changes
    cfg["files_sig"] = hashlib.sha1(
        "\n".join(os.path.basename(p) for p in all_paths).encode("utf-8")
    ).hexdigest()
    run_info = {
        "wavelength": "rgb",
        "out_dir": OUT_DIR,
        "csv_path": CSV_PATH,
        "split_path": SPLIT_PATH,
        "config_path": CFG_PATH,
        "model_path": MPATH,
        "model_name": MODEL_NAME,
        "data_dir": DATA_DIR
    }
    # cache: skip if config unchanged + outputs exist
    # (CFG_PATH is only written at the very end of this function, so an
    # interrupted run will not be treated as a cache hit)
    if os.path.isfile(CFG_PATH) and os.path.isfile(CSV_PATH) and os.path.isfile(MPATH):
        with open(CFG_PATH) as f:
            old = json.load(f)
        if old == cfg:
            print("[cache] Skipping training (config unchanged)")
            return run_info
    print(f"Found {len(all_paths)} images in {DATA_DIR}")
    all_labels = [get_label_from_filename(p) for p in all_paths]
    # -------------------------
    # Train / val / test split
    # -------------------------
    random.seed(SEED)  # re-seed immediately before the shuffle for a stable split
    indices = list(range(len(all_paths)))
    random.shuffle(indices)
    n = len(indices)
    n_train = int(TRAIN_FRAC * n)
    n_val = int(VAL_FRAC * n)
    n_test = n - n_train - n_val
    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]
    train_paths, train_labels = subset(all_paths, all_labels, train_idx)
    val_paths, val_labels = subset(all_paths, all_labels, val_idx)
    # the test split is never materialized here — only its indices are saved
    print(f"Split sizes: train={len(train_paths)}, val={len(val_paths)}, test={n_test} (held out)")
    splits = {
        "seed": SEED,
        "train_idx": train_idx,
        "val_idx": val_idx,
        "test_idx": test_idx,
    }
    with open(SPLIT_PATH, "w") as f:
        json.dump(splits, f)
    # -------------------------
    # Datasets & loaders
    # -------------------------
    train_transform, val_test_transform = get_transforms()
    train_ds = SeedDataset(train_paths, train_labels, transform=train_transform)
    val_ds = SeedDataset(val_paths, val_labels, transform=val_test_transform)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    # -------------------------
    # Model, loss, optimizer
    # -------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    print(f"Model: {MODEL_NAME}")
    model = make_model(MODEL_NAME, num_classes=2, pretrained=True).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    # -------------------------
    # CSV logging
    # -------------------------
    # write the header once; per-epoch rows are appended inside the loop
    with open(CSV_PATH, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            "epoch",
            "split",
            "accuracy",
            "precision_pos",
            "recall_pos",
            "f1_pos",
            "TP",
            "FP",
            "TN",
            "FN",
        ])
    # -------------------------
    # Training loop with early stopping on val accuracy
    # -------------------------
    best_val_acc = -1.0
    epochs_no_improve = 0
    for epoch in range(1, NUM_EPOCHS + 1):
        train_loss, train_acc, train_y, train_p = run_epoch(
            train_loader, model, criterion, optimizer=optimizer, device=device
        )
        # optimizer=None -> evaluation pass (no gradient updates)
        val_loss, val_acc, val_y, val_p = run_epoch(
            val_loader, model, criterion, optimizer=None, device=device
        )
        train_TP, train_FP, train_TN, train_FN = compute_confusion_stats(train_y, train_p)
        val_TP, val_FP, val_TN, val_FN = compute_confusion_stats(val_y, val_p)
        train_prec, train_rec, train_f1 = prf1_from_confusion(train_TP, train_FP, train_FN)
        val_prec, val_rec, val_f1 = prf1_from_confusion(val_TP, val_FP, val_FN)
        print(
            f"Epoch {epoch:03d}/{NUM_EPOCHS} | "
            f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} "
            f"train_f1={train_f1:.4f} | "
            f"val_loss={val_loss:.4f} val_acc={val_acc:.4f} "
            f"val_f1={val_f1:.4f}"
        )
        # the CSV is re-opened in append mode every epoch — presumably so
        # rows are flushed to disk even if the run dies mid-training
        with open(CSV_PATH, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([epoch, "train", train_acc, train_prec, train_rec, train_f1, train_TP, train_FP, train_TN, train_FN])
            writer.writerow([epoch, "val", val_acc, val_prec, val_rec, val_f1, val_TP, val_FP, val_TN, val_FN])
        if val_acc > best_val_acc:
            # new best: reset patience and checkpoint the model weights
            best_val_acc = val_acc
            epochs_no_improve = 0
            checkpoint = {"model_name": MODEL_NAME, "state_dict": model.state_dict()}
            torch.save(checkpoint, MPATH)
            print(f"Best so far: val_acc={best_val_acc:.4f}. Best model saved to {MPATH}")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"Early stopping at epoch {epoch} (best val_acc={best_val_acc:.4f})")
                break
    # write the config last so the cache check above only succeeds after a
    # completed run
    with open(CFG_PATH, "w") as f:
        json.dump(cfg, f, indent=2, sort_keys=True)
    return run_info
# Script entry point: run a full training session when executed directly.
if __name__ == "__main__":
    train()
# utils.py
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import models
from torchvision.models import (
ResNet18_Weights,
ResNet50_Weights,
ViT_B_16_Weights,
)
from torchvision import transforms
from PIL import Image
import torchvision.transforms.functional as F
# Architecture names accepted by make_model().
MODEL_NAMES = ["resnet18", "resnet50", "vit_b_16"]
def get_label_from_filename(path: str) -> int:
    """Parse the emergence label from the last underscore token of a filename.

    Expects filenames like:
        'FA 17 11 2025_BlobImage1657_rgb_has_emerged_False.png'
        'FP 18 11 2025_BlobImage2007_rgb_has_emerged_True.png'

    Returns 1 for emerged ('True'), 0 for not emerged ('False');
    raises ValueError for anything else.
    """
    flag = Path(path).stem.rsplit("_", 1)[-1]  # token after the last '_', extension dropped
    label_by_flag = {"True": 1, "False": 0}
    try:
        return label_by_flag[flag]
    except KeyError:
        raise ValueError(f"Cannot parse label from filename: {path}") from None
def compute_confusion_stats(labels, preds):
    """Count TP/FP/TN/FN for the positive class 1 ("emerged").

    labels, preds: equal-length lists of 0/1 ints.
    Returns the tuple (TP, FP, TN, FN).
    """
    pairs = list(zip(labels, preds))
    TP = sum(1 for y, p in pairs if y == 1 and p == 1)
    FP = sum(1 for y, p in pairs if y == 0 and p == 1)
    TN = sum(1 for y, p in pairs if y == 0 and p == 0)
    FN = sum(1 for y, p in pairs if y == 1 and p == 0)
    return TP, FP, TN, FN
def _safe_div(a, b):
return a / b if b != 0 else 0.0
def prf1_from_confusion(TP, FP, FN):
    """Precision/recall/F1 for the positive class (emerged = 1).

    All three metrics fall back to 0.0 when their denominator is zero.
    """
    predicted_pos = TP + FP
    actual_pos = TP + FN
    precision = TP / predicted_pos if predicted_pos else 0.0
    recall = TP / actual_pos if actual_pos else 0.0
    denom = precision + recall
    f1 = (2 * precision * recall / denom) if denom else 0.0
    return precision, recall, f1
def subset(paths, labels, idx_list):
    """Select the (path, label) pairs at the given indices, preserving order."""
    picked_paths = []
    picked_labels = []
    for i in idx_list:
        picked_paths.append(paths[i])
        picked_labels.append(labels[i])
    return picked_paths, picked_labels
class SeedDataset(Dataset):
    """Torch Dataset mapping PNG file paths to (image, label) pairs.

    Images are opened with PIL, converted to RGB, and passed through the
    optional transform; labels are returned unchanged.
    """

    def __init__(self, paths, labels, transform=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        image = Image.open(self.paths[idx]).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]
def get_transforms():
    """Return (train_transform, eval_transform) preprocessing pipelines.

    Both pipelines resize to 224x224, convert to tensor, and normalize
    with ImageNet statistics; the training pipeline additionally applies
    a random horizontal flip and a mild brightness/contrast jitter.
    (No vertical flip is applied.)
    """
    def _imagenet_normalize():
        # ImageNet channel statistics, matching the pretrained backbones.
        return transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(
            brightness=0.02,
            contrast=0.02,
        ),
        transforms.ToTensor(),
        _imagenet_normalize(),
    ])
    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        _imagenet_normalize(),
    ])
    return train_transform, eval_transform
def make_model(name: str, num_classes: int = 2, pretrained: bool = True):
    """Build a supported architecture with a fresh num_classes-way head.

    name: one of MODEL_NAMES ("resnet18", "resnet50", "vit_b_16").
    pretrained: load the torchvision DEFAULT weights when True.
    Raises ValueError for an unknown name.
    """
    if name not in ("resnet18", "resnet50", "vit_b_16"):
        raise ValueError(f"Unknown model name: {name}. Expected one of {MODEL_NAMES}")
    if name == "vit_b_16":
        weights = ViT_B_16_Weights.DEFAULT if pretrained else None
        model = models.vit_b_16(weights=weights)
        # ViT's classifier lives at heads.head rather than fc
        model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
        return model
    # Both ResNet variants expose their classifier as model.fc
    if name == "resnet18":
        weights = ResNet18_Weights.DEFAULT if pretrained else None
        model = models.resnet18(weights=weights)
    else:
        weights = ResNet50_Weights.DEFAULT if pretrained else None
        model = models.resnet50(weights=weights)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model
def run_epoch(loader, model, criterion, optimizer=None, device="cpu"):
    """Run one full pass over `loader`.

    When `optimizer` is given the model is put in train mode and updated
    per batch; otherwise it runs in eval mode with gradients disabled.

    Returns (avg_loss, accuracy, labels, preds) where avg_loss is the
    sample-weighted mean loss and labels/preds are flat Python lists of
    ints accumulated over the whole epoch.
    """
    is_training = optimizer is not None
    if is_training:
        model.train()
    else:
        model.eval()
    running_loss = 0.0
    epoch_labels = []
    epoch_preds = []
    for batch_imgs, batch_labels in loader:
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)
        if is_training:
            optimizer.zero_grad()
        with torch.set_grad_enabled(is_training):
            logits = model(batch_imgs)
            loss = criterion(logits, batch_labels)
            if is_training:
                loss.backward()
                optimizer.step()
        # criterion returns a per-batch mean; re-weight by batch size so
        # the final average is exact even with a smaller last batch
        running_loss += loss.item() * batch_imgs.size(0)
        epoch_preds.extend(logits.argmax(dim=1).detach().cpu().tolist())
        epoch_labels.extend(batch_labels.detach().cpu().tolist())
    n_samples = len(epoch_labels)
    avg_loss = running_loss / n_samples
    n_correct = sum(1 for p, y in zip(epoch_preds, epoch_labels) if p == y)
    return avg_loss, n_correct / n_samples, epoch_labels, epoch_preds
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment