Last active
December 9, 2025 11:58
-
-
Save mapo80/1c91e5e1c1701f5f7efae345fb6c6610 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Quantization-Aware Training (QAT) for DocCornerNet. | |
| This script fine-tunes a pre-trained model with fake quantization to simulate | |
| INT8 quantization effects during training, resulting in a model that maintains | |
| accuracy after INT8 conversion. | |
| Usage: | |
| python train_qat.py --checkpoint checkpoints/doccornernet_v2/best.pth \ | |
| --output_dir checkpoints/doccornernet_qat \ | |
| --num_epochs 10 | |
| """ | |
| import argparse | |
| import json | |
| import sys | |
| import copy | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.ao.quantization import ( | |
| get_default_qat_qconfig, | |
| prepare_qat, | |
| convert, | |
| QuantStub, | |
| DeQuantStub, | |
| ) | |
| from tqdm import tqdm | |
| from model import create_model | |
| from dataset import create_dataloaders | |
| from metrics import ValidationMetrics | |
class QuantizedDocCornerNet(nn.Module):
    """Wrap a DocCornerNet so eager-mode QAT knows where quantization starts and ends.

    The input tensor passes through a ``QuantStub`` before reaching the wrapped
    model. Each of the model's two outputs (coordinates and score) passes
    through the same ``DeQuantStub`` on the way out.
    """

    def __init__(self, model):
        super().__init__()
        # Registration order (quant, model, dequant) is kept as-is so that
        # child-module and state_dict key ordering do not change.
        self.quant = QuantStub()
        self.model = model
        self.dequant = DeQuantStub()

    def forward(self, x):
        coords, score = self.model(self.quant(x))
        # Tuple elements are evaluated left to right: coords first, then score.
        return self.dequant(coords), self.dequant(score)
def train_epoch(model, dataloader, criterion_coords, criterion_score, optimizer, device, lambda_coords, lambda_score, grad_clip):
    """Train ``model`` for one epoch and return the mean per-batch loss.

    Each batch must provide ``"image"``, ``"coords"`` and ``"has_label"``
    tensors. The total loss is
    ``lambda_coords * coords_loss + lambda_score * score_loss``. Coordinates
    are masked by ``has_label`` so that images without a document contribute
    zeros to the coordinate loss.

    Args:
        model: Module returning ``(pred_coords, pred_score)``.
        dataloader: Iterable of batch dicts.
        criterion_coords: Loss applied to masked coordinates.
        criterion_score: Loss applied to the flattened score logits vs. ``has_label``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        lambda_coords: Weight of the coordinate loss.
        lambda_score: Weight of the score loss.
        grad_clip: Max gradient norm; clipping is skipped when ``<= 0``.

    Returns:
        The mean of the per-batch losses, or ``0.0`` if ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    pbar = tqdm(dataloader, desc="Training", leave=False)
    for batch in pbar:
        images = batch["image"].to(device)
        gt_coords = batch["coords"].to(device)
        # NOTE(review): assumed to be shape (B,) given the unsqueeze(1) below — confirm.
        gt_has_doc = batch["has_label"].to(device).float()

        optimizer.zero_grad()
        pred_coords, pred_score = model(images)

        # Coordinate loss (only for images with documents): zero out both the
        # prediction and the target where no document is labelled.
        mask = gt_has_doc.unsqueeze(1)
        coords_loss = criterion_coords(pred_coords * mask, gt_coords * mask)

        # Score loss. reshape(-1) instead of squeeze(): for a batch of size 1,
        # squeeze() yields a 0-d tensor whose shape no longer matches the
        # (1,)-shaped target, and BCEWithLogitsLoss rejects the mismatch.
        score_loss = criterion_score(pred_score.reshape(-1), gt_has_doc)

        loss = lambda_coords * coords_loss + lambda_score * score_loss
        loss.backward()
        if grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        # Read the scalar once; each .item() forces a device sync.
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1
        pbar.set_postfix(loss=loss_value)

    return total_loss / num_batches if num_batches else 0.0
def validate(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and return ``ValidationMetrics.compute()``.

    Runs in eval mode under ``torch.no_grad()``. Each batch must provide
    ``"image"``, ``"coords"``, ``"has_label"`` and ``"score"`` tensors.
    Predictions and ground truth are accumulated into a ``ValidationMetrics``
    instance.

    Args:
        model: Module returning ``(pred_coords, pred_score)``.
        dataloader: Iterable of batch dicts.
        device: Device the batch tensors are moved to.

    Returns:
        Whatever ``ValidationMetrics.compute()`` returns. Callers in this file
        read its ``"mean_iou"`` and ``"recall_90"`` keys.
    """
    model.eval()
    metrics = ValidationMetrics()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating", leave=False):
            images = batch["image"].to(device)
            gt_coords = batch["coords"].to(device)
            gt_has_doc = batch["has_label"].to(device)
            gt_score = batch["score"].to(device)

            pred_coords, pred_score = model(images)

            metrics.update(
                pred_coords=pred_coords,
                gt_coords=gt_coords,
                # reshape(-1) keeps a 1-D (B,) tensor even when B == 1, where
                # squeeze() would collapse it to a 0-d scalar.
                pred_scores=pred_score.reshape(-1),
                gt_scores=gt_score,
                has_gt=gt_has_doc,
            )

    return metrics.compute()
def _resolve_device(requested):
    """Return the torch device for ``--device``; ``"auto"`` picks CUDA if available, else CPU."""
    if requested != "auto":
        return torch.device(requested)
    if torch.cuda.is_available():
        return torch.device("cuda")
    # QAT fake quant ops are not implemented on MPS; force CPU.
    return torch.device("cpu")


def _load_pretrained_model(checkpoint_path):
    """Build a DocCornerNet from a training checkpoint and load its weights on CPU."""
    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model = create_model(
        pretrained=False,
        width_mult=checkpoint.get("width_mult", 1.0),
        reduced_tail=checkpoint.get("reduced_tail", True),
        dropout=checkpoint.get("dropout", 0.2),
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    return model


def _convert_and_export(model_qat, output_dir):
    """Convert an eager-mode QAT model to INT8, then save its state_dict and (best-effort) TorchScript."""
    model_qat.eval()
    # The model was prepared with eager-mode ``prepare_qat``, so it must be
    # converted with the matching eager-mode ``convert`` (not ``convert_fx``).
    model_quantized = convert(model_qat.cpu(), inplace=False)
    print("Eager mode conversion successful")

    # Save quantized model
    torch.save(model_quantized.state_dict(), output_dir / "model_quantized.pth")
    print(f"Quantized model saved to: {output_dir / 'model_quantized.pth'}")

    # Export to TorchScript for potential deployment. This is deliberately
    # best-effort: failure is reported but does not abort the script.
    try:
        scripted = torch.jit.script(model_quantized)
        scripted.save(str(output_dir / "model_quantized.pt"))
        print(f"TorchScript model saved to: {output_dir / 'model_quantized.pt'}")
    except Exception as e:
        print(f"TorchScript export failed: {e}")
    return model_quantized


def main():
    """Fine-tune a pre-trained DocCornerNet with eager-mode QAT and export an INT8 model.

    Steps:
        1. Parse CLI arguments and load the pre-trained checkpoint.
        2. Wrap the model with quant stubs and run ``prepare_qat``.
        3. Fine-tune, saving the best-validation-IoU weights to
           ``<output_dir>/best_qat.pth``.
        4. Convert those weights (or the final-epoch weights if no epoch
           improved on the initial validation) with eager-mode ``convert``.
        5. Save the quantized state_dict and, if scripting works, a
           TorchScript module.
    """
    parser = argparse.ArgumentParser(description="QAT for DocCornerNet")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to pre-trained checkpoint")
    parser.add_argument("--data_root", type=str, default="../doc-scanner-dataset-labeled")
    parser.add_argument("--output_dir", type=str, default="./checkpoints/doccornernet_qat")
    parser.add_argument("--num_epochs", type=int, default=10, help="QAT fine-tuning epochs")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate for QAT (lower than normal)")
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--lambda_coords", type=float, default=1.0)
    parser.add_argument("--lambda_score", type=float, default=2.0)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--img_size", type=int, default=224)
    parser.add_argument("--cache_images", action="store_true", help="Cache images in RAM for faster training")
    parser.add_argument("--cache_dir", type=str, default=None, help="Directory for disk cache")
    args = parser.parse_args()

    device = _resolve_device(args.device)
    print(f"Using device: {device}")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    best_path = output_dir / "best_qat.pth"

    # Load pre-trained model (kept on CPU until QAT preparation is done).
    print(f"\nLoading checkpoint: {args.checkpoint}")
    model = _load_pretrained_model(args.checkpoint)
    # NOTE: XNNPACK replacements are done AFTER QAT training, during export.
    # Doing them before QAT breaks FX Graph Mode tracing.
    print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} parameters")

    print(f"\nLoading dataset from: {args.data_root}")
    train_loader, val_loader = create_dataloaders(
        data_root=args.data_root,
        img_size=args.img_size,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        cache_images=args.cache_images,
        cache_dir=args.cache_dir,
    )
    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

    # Prepare QAT model using Eager Mode (FX Graph Mode breaks MobileNetV3).
    # prepare_qat(inplace=False) already works on a copy, so the float model
    # needs no separate deepcopy.
    # NOTE(review): no fuse_modules() step runs before prepare_qat; eager-mode
    # QAT usually fuses conv/bn/relu first — confirm this is intended.
    print("\nPreparing model for Quantization-Aware Training (Eager Mode)...")
    model_qat = QuantizedDocCornerNet(model)
    model_qat.qconfig = get_default_qat_qconfig("x86")
    model_qat = prepare_qat(model_qat, inplace=False)
    print("Using Eager Mode QAT")
    model_qat = model_qat.to(device)

    criterion_coords = nn.SmoothL1Loss()
    criterion_score = nn.BCEWithLogitsLoss()

    # Optimizer (use lower LR for QAT)
    optimizer = optim.AdamW(model_qat.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Halve the LR when validation IoU (mode='max') stops improving.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=2, min_lr=1e-6
    )

    print("\nInitial validation (before QAT):")
    val_metrics = validate(model_qat, val_loader, device)
    print(f"  IoU: {val_metrics['mean_iou']*100:.2f}%")
    print(f"  Recall@90: {val_metrics['recall_90']*100:.2f}%")
    best_iou = val_metrics["mean_iou"]
    # 0 means "no checkpoint written during this run".
    best_epoch = 0

    print(f"\nStarting QAT for {args.num_epochs} epochs...")
    for epoch in range(args.num_epochs):
        print(f"\nEpoch {epoch+1}/{args.num_epochs}")

        train_loss = train_epoch(
            model_qat, train_loader, criterion_coords, criterion_score,
            optimizer, device, args.lambda_coords, args.lambda_score, args.grad_clip
        )
        print(f"  Train Loss: {train_loss:.4f}")

        val_metrics = validate(model_qat, val_loader, device)
        current_iou = val_metrics['mean_iou']
        print(f"  Val IoU: {current_iou*100:.2f}%")
        print(f"  Recall@90: {val_metrics['recall_90']*100:.2f}%")

        scheduler.step(current_iou)
        print(f"  LR: {optimizer.param_groups[0]['lr']:.2e}")

        if current_iou > best_iou:
            best_iou = current_iou
            best_epoch = epoch + 1
            torch.save({
                "model_state_dict": model_qat.state_dict(),
                "epoch": epoch + 1,
                "best_iou": best_iou,
                "val_metrics": val_metrics,
            }, best_path)
            print(f"  >> New best model saved! IoU: {best_iou*100:.2f}%")

    print(f"\nQAT completed! Best IoU: {best_iou*100:.2f}% at epoch {best_epoch}")

    print("\nConverting QAT model to quantized model...")
    # Reload only a checkpoint written by THIS run; checking best_path.exists()
    # alone would pick up a stale file left in output_dir by a previous run.
    if best_epoch > 0:
        # weights_only=False matches the checkpoint load above; this file was
        # written by this script a few lines earlier.
        state = torch.load(best_path, map_location="cpu", weights_only=False)
        model_qat.load_state_dict(state["model_state_dict"])
        print(f"Loaded best QAT weights from epoch {state.get('epoch', '?')}")
    else:
        print("No epoch improved on the initial validation IoU; converting final-epoch weights")

    _convert_and_export(model_qat, output_dir)
if __name__ == "__main__":
    # Run the QAT pipeline only when executed as a script, not when imported.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment