"""
Quantization-Aware Training (QAT) for DocCornerNet.
This script fine-tunes a pre-trained model with fake quantization to simulate
INT8 quantization effects during training, resulting in a model that maintains
accuracy after INT8 conversion.
Usage:
python train_qat.py --checkpoint checkpoints/doccornernet_v2/best.pth \
--output_dir checkpoints/doccornernet_qat \
--num_epochs 10
"""
import argparse
import json
import sys
import copy
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.ao.quantization import (
    get_default_qat_qconfig,
    prepare_qat,
    convert,
    QuantStub,
    DeQuantStub,
)
from tqdm import tqdm
from model import create_model
from dataset import create_dataloaders
from metrics import ValidationMetrics


class QuantizedDocCornerNet(nn.Module):
    """DocCornerNet with quantization stubs for QAT."""

    def __init__(self, model):
        super().__init__()
        self.quant = QuantStub()
        self.model = model
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        coords, score = self.model(x)
        coords = self.dequant(coords)
        score = self.dequant(score)
        return coords, score
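

# NOTE: Eager-mode quantization requires marking the float<->INT8 boundaries
# by hand: after convert(), QuantStub becomes a real quantize op at the input
# and DeQuantStub a dequantize op, so both heads return float tensors again.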


def train_epoch(model, dataloader, criterion_coords, criterion_score, optimizer,
                device, lambda_coords, lambda_score, grad_clip):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    num_batches = 0
    pbar = tqdm(dataloader, desc="Training", leave=False)
    for batch in pbar:
        images = batch["image"].to(device)
        gt_coords = batch["coords"].to(device)
        gt_has_doc = batch["has_label"].to(device).float()
        optimizer.zero_grad()
        pred_coords, pred_score = model(images)
        # Coordinate loss (only for images with documents)
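        # Masking both prediction and target zeroes this term for negatives,
        # so only the score head receives gradient from images without a doc.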
        mask = gt_has_doc.unsqueeze(1)
        coords_loss = criterion_coords(pred_coords * mask, gt_coords * mask)
        # Score loss
        score_loss = criterion_score(pred_score.squeeze(), gt_has_doc)
        loss = lambda_coords * coords_loss + lambda_score * score_loss
        loss.backward()
        if grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
        pbar.set_postfix(loss=loss.item())
    return total_loss / num_batches


def validate(model, dataloader, device):
    """Validate model."""
    model.eval()
    metrics = ValidationMetrics()
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating", leave=False):
            images = batch["image"].to(device)
            gt_coords = batch["coords"].to(device)
            gt_has_doc = batch["has_label"].to(device)
            gt_score = batch["score"].to(device)
            pred_coords, pred_score = model(images)
            metrics.update(
                pred_coords=pred_coords,
                gt_coords=gt_coords,
                pred_scores=pred_score.squeeze(),
                gt_scores=gt_score,
                has_gt=gt_has_doc,
            )
    return metrics.compute()


def main():
    parser = argparse.ArgumentParser(description="QAT for DocCornerNet")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to pre-trained checkpoint")
    parser.add_argument("--data_root", type=str, default="../doc-scanner-dataset-labeled")
    parser.add_argument("--output_dir", type=str, default="./checkpoints/doccornernet_qat")
    parser.add_argument("--num_epochs", type=int, default=10, help="QAT fine-tuning epochs")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate for QAT (lower than normal)")
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--lambda_coords", type=float, default=1.0)
    parser.add_argument("--lambda_score", type=float, default=2.0)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--img_size", type=int, default=224)
    parser.add_argument("--cache_images", action="store_true", help="Cache images in RAM for faster training")
    parser.add_argument("--cache_dir", type=str, default=None, help="Directory for disk cache")
    args = parser.parse_args()

    # Device
    if args.device == "auto":
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            # QAT fake quant ops are not implemented on MPS; force CPU.
            device = torch.device("cpu")
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}")

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load pre-trained model
    print(f"\nLoading checkpoint: {args.checkpoint}")
    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    model = create_model(
        pretrained=False,
        width_mult=checkpoint.get("width_mult", 1.0),
        reduced_tail=checkpoint.get("reduced_tail", True),
        dropout=checkpoint.get("dropout", 0.2),
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    # NOTE: XNNPACK replacements are done AFTER QAT training, during export.
    # Doing them before QAT breaks FX Graph Mode tracing.
    print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} parameters")

    # Create dataloaders
    print(f"\nLoading dataset from: {args.data_root}")
    train_loader, val_loader = create_dataloaders(
        data_root=args.data_root,
        img_size=args.img_size,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        cache_images=args.cache_images,
        cache_dir=args.cache_dir,
    )
    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

    # Prepare QAT model using Eager Mode (FX Graph Mode breaks MobileNetV3)
    print("\nPreparing model for Quantization-Aware Training (Eager Mode)...")
    model_cpu = copy.deepcopy(model).cpu()
    model_qat = QuantizedDocCornerNet(model_cpu)
    model_qat.qconfig = get_default_qat_qconfig("x86")
    model_qat = prepare_qat(model_qat, inplace=False)
    print("Using Eager Mode QAT")

    # Move to device
    model_qat = model_qat.to(device)

    # Loss functions
    criterion_coords = nn.SmoothL1Loss()
    criterion_score = nn.BCEWithLogitsLoss()

    # Optimizer (use lower LR for QAT)
    optimizer = optim.AdamW(model_qat.parameters(), lr=args.lr, weight_decay=args.weight_decay)
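    # A small LR keeps the weights near the pre-trained solution while the
    # fake-quant observers calibrate; large steps tend to destabilize QAT.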

    # Learning rate scheduler - reduce LR when validation plateaus
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=2, min_lr=1e-6
    )

    # Initial validation (fake quant already inserted, before any fine-tuning)
    print("\nInitial validation (before QAT):")
    val_metrics = validate(model_qat, val_loader, device)
    print(f"  IoU: {val_metrics['mean_iou']*100:.2f}%")
    print(f"  Recall@90: {val_metrics['recall_90']*100:.2f}%")
    best_iou = val_metrics["mean_iou"]
    best_epoch = 0

    # QAT Training
    print(f"\nStarting QAT for {args.num_epochs} epochs...")
    for epoch in range(args.num_epochs):
        print(f"\nEpoch {epoch+1}/{args.num_epochs}")

        # Train
        train_loss = train_epoch(
            model_qat, train_loader, criterion_coords, criterion_score,
            optimizer, device, args.lambda_coords, args.lambda_score, args.grad_clip
        )
        print(f"  Train Loss: {train_loss:.4f}")

        # Validate
        val_metrics = validate(model_qat, val_loader, device)
        current_iou = val_metrics['mean_iou']
        print(f"  Val IoU: {current_iou*100:.2f}%")
        print(f"  Recall@90: {val_metrics['recall_90']*100:.2f}%")

        # Step scheduler based on validation IoU
        scheduler.step(current_iou)
        print(f"  LR: {optimizer.param_groups[0]['lr']:.2e}")

        # Save best model (still a fake-quant QAT checkpoint; INT8 conversion
        # happens after training)
        if current_iou > best_iou:
            best_iou = current_iou
            best_epoch = epoch + 1
            model_qat.eval()
            torch.save({
                "model_state_dict": model_qat.state_dict(),
                "epoch": epoch + 1,
                "best_iou": best_iou,
                "val_metrics": val_metrics,
            }, output_dir / "best_qat.pth")
            print(f"  >> New best model saved! IoU: {best_iou*100:.2f}%")
print(f"\nQAT completed! Best IoU: {best_iou*100:.2f}% at epoch {best_epoch}")
# Convert final model
print("\nConverting QAT model to quantized model...")
model_qat.eval()
# Reload best checkpoint if available
best_path = output_dir / "best_qat.pth"
if best_path.exists():
state = torch.load(best_path, map_location="cpu")
model_qat.load_state_dict(state["model_state_dict"])
print(f"Loaded best QAT weights from epoch {state.get('epoch', '?')}")

    # The model was prepared with eager-mode prepare_qat(), so it must also be
    # converted with the eager-mode convert(); convert_fx does not apply here
    # (and was never imported).
    model_qat_cpu = model_qat.cpu()
    model_quantized = convert(model_qat_cpu, inplace=False)
    print("Eager mode conversion successful")

    # Save quantized model
    torch.save(model_quantized.state_dict(), output_dir / "model_quantized.pth")
    print(f"Quantized model saved to: {output_dir / 'model_quantized.pth'}")

    # Export to TorchScript for potential deployment
    try:
        scripted = torch.jit.script(model_quantized)
        scripted.save(str(output_dir / "model_quantized.pt"))
        print(f"TorchScript model saved to: {output_dir / 'model_quantized.pt'}")
    except Exception as e:
        print(f"TorchScript export failed: {e}")


if __name__ == "__main__":
    main()