Created
February 14, 2026 08:05
-
-
Save Abe404/17c3b702308a787b1f7622b8b418bd64 to your computer and use it in GitHub Desktop.
PyTorch binary image classification with early stopping (ResNet-18)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # train.py | |
| import os | |
| import glob | |
| import random | |
| import csv | |
| import json | |
| import hashlib | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader | |
| from utils import ( | |
| get_label_from_filename, | |
| compute_confusion_stats, | |
| prf1_from_confusion, | |
| subset, | |
| SeedDataset, | |
| get_transforms, | |
| make_model, | |
| run_epoch, | |
| ) | |
# -------------------------
# Config
# -------------------------
MODEL_NAME = "resnet18"
DATA_DIR = "../../data/labeled_images_rgb"  # folder with your PNGs
BATCH_SIZE = 126
NUM_EPOCHS = 250
LR = 1e-6
TRAIN_FRAC = 0.7
VAL_FRAC = 0.15  # rest is test (held out)
SEED = 42
PATIENCE = 15  # early-stopping: epochs without val-accuracy improvement
OUT_DIR = "output"  # keep outputs local to this experiment folder
CSV_PATH = os.path.join(OUT_DIR, "metrics_log.csv")
SPLIT_PATH = os.path.join(OUT_DIR, "splits_seed42.json")
CFG_PATH = os.path.join(OUT_DIR, "run_config.json")
MODELS_DIR = os.path.join(OUT_DIR, "models")
MPATH = os.path.join(MODELS_DIR, f"{MODEL_NAME}_rgb_best.pt")
# -------------------------
# Reproducibility
# -------------------------
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Trade cuDNN autotuning speed for deterministic kernel selection.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def make_run_config():
    """Snapshot the hyperparameters that identify this run.

    The resulting dict is compared against the previously saved config to
    decide whether training can be skipped (see the cache check in train()).
    """
    cfg = dict(
        experiment="rgb_full",
        wavelength="rgb",
        MODEL_NAME=MODEL_NAME,
        DATA_DIR=DATA_DIR,
        lr=LR,
        batch_size=BATCH_SIZE,
        seed=SEED,
        train_frac=TRAIN_FRAC,
        val_frac=VAL_FRAC,
        patience=PATIENCE,
        num_epochs=NUM_EPOCHS,
    )
    return cfg
def train():
    """Train the binary emergence classifier with early stopping on val accuracy.

    Side effects (all under OUT_DIR): per-epoch metrics CSV, the train/val/test
    split indices JSON, the best checkpoint, and — written last — the run-config
    JSON. If that config is unchanged and the CSV + checkpoint already exist,
    training is skipped entirely.

    Returns:
        dict: output paths and run metadata (also returned on a cache hit).

    Raises:
        RuntimeError: if DATA_DIR contains no PNG files.
    """
    os.makedirs(MODELS_DIR, exist_ok=True)  # also creates OUT_DIR (its parent)
    cfg = make_run_config()
    # -------------------------
    # Collect image paths + labels
    # -------------------------
    all_paths = sorted(glob.glob(os.path.join(DATA_DIR, "*.png")))
    if not all_paths:
        raise RuntimeError(f"No PNG files found in {DATA_DIR}")
    # dataset signature (stable): sha1 of the sorted basenames, so the cache
    # check below also invalidates when files are added/removed/renamed
    cfg["files_sig"] = hashlib.sha1(
        "\n".join(os.path.basename(p) for p in all_paths).encode("utf-8")
    ).hexdigest()
    # Paths/metadata handed back to the caller for downstream evaluation.
    run_info = {
        "wavelength": "rgb",
        "out_dir": OUT_DIR,
        "csv_path": CSV_PATH,
        "split_path": SPLIT_PATH,
        "config_path": CFG_PATH,
        "model_path": MPATH,
        "model_name": MODEL_NAME,
        "data_dir": DATA_DIR
    }
    # cache: skip if config unchanged + outputs exist
    # (CFG_PATH is only written after a completed run, so an interrupted run
    # will not be mistaken for a finished one)
    if os.path.isfile(CFG_PATH) and os.path.isfile(CSV_PATH) and os.path.isfile(MPATH):
        with open(CFG_PATH) as f:
            old = json.load(f)
        if old == cfg:
            print("[cache] Skipping training (config unchanged)")
            return run_info
    print(f"Found {len(all_paths)} images in {DATA_DIR}")
    all_labels = [get_label_from_filename(p) for p in all_paths]
    # -------------------------
    # Train / val / test split
    # -------------------------
    random.seed(SEED)  # re-seed right before shuffling so the split is reproducible
    indices = list(range(len(all_paths)))
    random.shuffle(indices)
    n = len(indices)
    n_train = int(TRAIN_FRAC * n)
    n_val = int(VAL_FRAC * n)
    n_test = n - n_train - n_val
    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]
    train_paths, train_labels = subset(all_paths, all_labels, train_idx)
    val_paths, val_labels = subset(all_paths, all_labels, val_idx)
    print(f"Split sizes: train={len(train_paths)}, val={len(val_paths)}, test={n_test} (held out)")
    # Persist the split indices so evaluation can recover the held-out test set.
    splits = {
        "seed": SEED,
        "train_idx": train_idx,
        "val_idx": val_idx,
        "test_idx": test_idx,
    }
    with open(SPLIT_PATH, "w") as f:
        json.dump(splits, f)
    # -------------------------
    # Datasets & loaders
    # -------------------------
    train_transform, val_test_transform = get_transforms()
    train_ds = SeedDataset(train_paths, train_labels, transform=train_transform)
    val_ds = SeedDataset(val_paths, val_labels, transform=val_test_transform)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    # -------------------------
    # Model, loss, optimizer
    # -------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    print(f"Model: {MODEL_NAME}")
    model = make_model(MODEL_NAME, num_classes=2, pretrained=True).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    # -------------------------
    # CSV logging
    # -------------------------
    # "w" truncates any previous log; per-epoch rows are appended below.
    with open(CSV_PATH, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            "epoch",
            "split",
            "accuracy",
            "precision_pos",
            "recall_pos",
            "f1_pos",
            "TP",
            "FP",
            "TN",
            "FN",
        ])
    # -------------------------
    # Training loop with early stopping on val accuracy
    # -------------------------
    best_val_acc = -1.0  # any real accuracy (>= 0) beats this on epoch 1
    epochs_no_improve = 0
    for epoch in range(1, NUM_EPOCHS + 1):
        train_loss, train_acc, train_y, train_p = run_epoch(
            train_loader, model, criterion, optimizer=optimizer, device=device
        )
        # optimizer=None puts run_epoch in evaluation mode (no grad, no updates)
        val_loss, val_acc, val_y, val_p = run_epoch(
            val_loader, model, criterion, optimizer=None, device=device
        )
        train_TP, train_FP, train_TN, train_FN = compute_confusion_stats(train_y, train_p)
        val_TP, val_FP, val_TN, val_FN = compute_confusion_stats(val_y, val_p)
        train_prec, train_rec, train_f1 = prf1_from_confusion(train_TP, train_FP, train_FN)
        val_prec, val_rec, val_f1 = prf1_from_confusion(val_TP, val_FP, val_FN)
        print(
            f"Epoch {epoch:03d}/{NUM_EPOCHS} | "
            f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} "
            f"train_f1={train_f1:.4f} | "
            f"val_loss={val_loss:.4f} val_acc={val_acc:.4f} "
            f"val_f1={val_f1:.4f}"
        )
        # Reopen in append mode each epoch so logged rows survive a crash.
        with open(CSV_PATH, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([epoch, "train", train_acc, train_prec, train_rec, train_f1, train_TP, train_FP, train_TN, train_FN])
            writer.writerow([epoch, "val", val_acc, val_prec, val_rec, val_f1, val_TP, val_FP, val_TN, val_FN])
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_no_improve = 0
            # Checkpoint carries the architecture name so evaluation can rebuild it.
            checkpoint = {"model_name": MODEL_NAME, "state_dict": model.state_dict()}
            torch.save(checkpoint, MPATH)
            print(f"Best so far: val_acc={best_val_acc:.4f}. Best model saved to {MPATH}")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"Early stopping at epoch {epoch} (best val_acc={best_val_acc:.4f})")
                break
    # Written last: marks the run complete and enables the cache-skip above.
    with open(CFG_PATH, "w") as f:
        json.dump(cfg, f, indent=2, sort_keys=True)
    return run_info
# Run training only when executed as a script (not on import).
if __name__ == "__main__":
    train()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # utils.py | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import Dataset | |
| from torchvision import models | |
| from torchvision.models import ( | |
| ResNet18_Weights, | |
| ResNet50_Weights, | |
| ViT_B_16_Weights, | |
| ) | |
| from torchvision import transforms | |
| from PIL import Image | |
| import torchvision.transforms.functional as F | |
| MODEL_NAMES = ["resnet18", "resnet50", "vit_b_16"] | |
def get_label_from_filename(path: str) -> int:
    """
    Parse the emergence label encoded in the filename.

    Expects filenames like:
        'FA 17 11 2025_BlobImage1657_rgb_has_emerged_False.png'
        'FP 18 11 2025_BlobImage2007_rgb_has_emerged_True.png'
    Returns 1 for emerged, 0 for not emerged.

    Raises:
        ValueError: if the trailing token is neither 'True' nor 'False'.
    """
    # The label is the final underscore-separated token of the stem.
    flag = Path(path).stem.rsplit("_", 1)[-1]
    if flag not in ("True", "False"):
        raise ValueError(f"Cannot parse label from filename: {path}")
    return int(flag == "True")
def compute_confusion_stats(labels, preds):
    """
    Count the binary confusion-matrix cells for class 1 ("emerged").

    labels, preds: parallel lists of 0/1 ints.
    Returns (TP, FP, TN, FN).
    """
    pairs = list(zip(labels, preds))
    tp = sum(1 for y, p in pairs if y == 1 and p == 1)
    fp = sum(1 for y, p in pairs if y == 0 and p == 1)
    tn = sum(1 for y, p in pairs if y == 0 and p == 0)
    fn = sum(1 for y, p in pairs if y == 1 and p == 0)
    return tp, fp, tn, fn
| def _safe_div(a, b): | |
| return a / b if b != 0 else 0.0 | |
def prf1_from_confusion(TP, FP, FN):
    """
    Precision/recall/F1 for the positive class (emerged = 1).

    Any metric whose denominator is zero is reported as 0.0 rather than
    raising, so epochs with no positive predictions/labels still log cleanly.
    """
    predicted_pos = TP + FP
    actual_pos = TP + FN
    precision = TP / predicted_pos if predicted_pos else 0.0
    recall = TP / actual_pos if actual_pos else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1
def subset(paths, labels, idx_list):
    """Select paths[i] and labels[i] for every i in idx_list (two parallel lists)."""
    picked_paths = []
    picked_labels = []
    for i in idx_list:
        picked_paths.append(paths[i])
        picked_labels.append(labels[i])
    return picked_paths, picked_labels
class SeedDataset(Dataset):
    """Dataset over explicit, parallel (image path, int label) lists.

    Each item is (image, label); the image passes through `transform`
    when one is supplied, otherwise it is returned as a PIL image.
    """

    def __init__(self, paths, labels, transform=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        target = self.labels[idx]
        # Force 3-channel RGB so grayscale/RGBA PNGs share one tensor layout.
        image = Image.open(self.paths[idx]).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, target
def get_transforms():
    """Return (train_transform, val_test_transform) torchvision pipelines.

    Both resize to 224x224 and normalize with ImageNet statistics; only the
    training pipeline adds light augmentation (horizontal flip + mild jitter).
    """
    resize = transforms.Resize((224, 224))
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    train_transform = transforms.Compose([
        resize,
        transforms.RandomHorizontalFlip(p=0.5),
        #transforms.RandomVerticalFlip(p=0.5),
        transforms.ColorJitter(
            brightness=0.02,
            contrast=0.02,
        ),
        transforms.ToTensor(),
        normalize,
    ])
    val_test_transform = transforms.Compose([
        resize,
        transforms.ToTensor(),
        normalize,
    ])
    return train_transform, val_test_transform
def make_model(name: str, num_classes: int = 2, pretrained: bool = True):
    """
    Factory for known models. Handles architecture-specific classifier heads.

    Raises:
        ValueError: if `name` is not one of MODEL_NAMES.
    """
    if name not in MODEL_NAMES:
        raise ValueError(f"Unknown model name: {name}. Expected one of {MODEL_NAMES}")
    if name == "vit_b_16":
        weights = ViT_B_16_Weights.DEFAULT if pretrained else None
        net = models.vit_b_16(weights=weights)
        # ViT exposes its classifier as heads.head rather than fc.
        net.heads.head = nn.Linear(net.heads.head.in_features, num_classes)
        return net
    # Both ResNet variants expose their classifier as `fc`.
    if name == "resnet18":
        weights = ResNet18_Weights.DEFAULT if pretrained else None
        net = models.resnet18(weights=weights)
    else:  # "resnet50"
        weights = ResNet50_Weights.DEFAULT if pretrained else None
        net = models.resnet50(weights=weights)
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net
def run_epoch(loader, model, criterion, optimizer=None, device="cpu"):
    """
    Run one full pass over `loader`.

    If `optimizer` is given the model is updated (training mode); otherwise
    the pass is evaluation-only with gradients disabled.

    Returns:
        (avg_loss, accuracy, labels, preds) — loss is a per-sample average,
        labels/preds are flat Python lists accumulated over the epoch.
    """
    is_train = optimizer is not None
    if is_train:
        model.train()
    else:
        model.eval()
    loss_sum = 0.0
    preds_out = []
    labels_out = []
    for batch_imgs, batch_labels in loader:
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)
        if is_train:
            optimizer.zero_grad()
        # Gradients are only needed when we are going to step the optimizer.
        with torch.set_grad_enabled(is_train):
            logits = model(batch_imgs)
            loss = criterion(logits, batch_labels)
            if is_train:
                loss.backward()
                optimizer.step()
        # Weight by batch size so the final average is per-sample, not per-batch.
        loss_sum += loss.item() * batch_imgs.size(0)
        batch_preds = logits.argmax(dim=1)
        preds_out.extend(batch_preds.detach().cpu().tolist())
        labels_out.extend(batch_labels.detach().cpu().tolist())
    n_seen = len(labels_out)
    avg_loss = loss_sum / n_seen
    accuracy = sum(p == y for p, y in zip(preds_out, labels_out)) / n_seen
    return avg_loss, accuracy, labels_out, preds_out
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment