#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Consolidated Anomaly Detector with Batched TTA
This module provides a modern, consolidated implementation of anomaly detection
with DeCo-Diff using efficient batched Test-Time Augmentation (TTA).
Key Features:
- Combined Batch TTA: Processes B images × N shifts together (~6.6× speedup)
- Configurable shifts: pad_px, stride, directions parameters
- Clean architecture without legacy code
- Designed to replace DecodiffEvaluator and DecodiffEvaluateProcessor
"""
import torch
import numpy as np
import pandas as pd
from skimage.transform import resize
from skimage import measure
from sklearn.metrics import average_precision_score, auc, roc_auc_score, f1_score
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn.functional as F
from glob import glob
from pathlib import Path
import json
import matplotlib
matplotlib.use('Agg') # Force non-GUI backend for thread safety (prevents TkAgg threading issues on Windows)
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image as PILImage
import os
import cv2
import sys
import logging
import time
import copy
from typing import Any, Dict, List, Optional, Set, Tuple
# Configure UTF-8 encoding for output (Windows compatibility)
if sys.platform.startswith('win'):
    import io
    # Only wrap if not already a TextIOWrapper to avoid double-wrapping
    if not isinstance(sys.stdout, io.TextIOWrapper):
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
    if not isinstance(sys.stderr, io.TextIOWrapper):
        sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
from dioodmi.models import UNET_models
from dioodmi.algorithms.diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from dioodmi.data.MVTECDataLoader import MVTECDataset
from dioodmi.data.VISADataLoader import VISADataset
from scipy.ndimage import gaussian_filter
from dioodmi.utils.plot_results import PlotResults as PR
from dioodmi.utils.utils import setup_signal_handlers, setup_logging
from dioodmi.data.temporal_dataset_factory import (
    create_temporal_dataset_manager,
    create_temporal_dataloader
)
from dioodmi.tasks.anomaly_detection.full_image_tiled_evaluator import (
    process_full_image_evaluation_loop,
    stack_anomaly_maps,
    setup_resume_mode
)
# ============================================================================
# Image Processing Utility Functions (Pure Functions)
# ============================================================================
def path_to_safe_filename(image_path: str) -> str:
    """Convert a file path to a safe filename for saving."""
    safe_name = os.path.basename(image_path)
    # Remove extension and replace problematic characters
    safe_name = os.path.splitext(safe_name)[0]
    safe_name = safe_name.replace(' ', '_').replace('/', '_').replace('\\', '_')
    return safe_name

def draw_patch_rectangles_on_image(base_img: np.ndarray, predicted_defective_set: set,
                                   ground_truth_defective: set, overlapping: set,
                                   patch_size: int = 256, grid_thickness: int = 1) -> np.ndarray:
    """Draw patch rectangles (TP/FP/FN) on top of an image."""
    img_copy = base_img.copy()
    h, w = img_copy.shape[:2]
    # Color definitions for different patch types
    colors = {
        'TP': (0, 255, 0),  # Green for True Positives (overlapping)
        'FP': (255, 0, 0),  # Red for False Positives (predicted only)
        'FN': (0, 0, 255),  # Blue for False Negatives (ground truth only)
    }

    def draw_rectangle(img, row, col, color, thickness):
        """Draw a single rectangle on the image."""
        x1 = col * patch_size
        y1 = row * patch_size
        x2 = min(x1 + patch_size, w)
        y2 = min(y1 + patch_size, h)
        cv2.rectangle(img, (x1, y1), (x2 - 1, y2 - 1), color, thickness)

    # Draw True Positives (overlapping patches) - Green
    for row, col in overlapping:
        draw_rectangle(img_copy, row, col, colors['TP'], grid_thickness)
    # Draw False Positives (predicted but not in ground truth) - Red
    false_positives = predicted_defective_set - ground_truth_defective
    for row, col in false_positives:
        draw_rectangle(img_copy, row, col, colors['FP'], grid_thickness)
    # Draw False Negatives (ground truth but not predicted) - Blue
    false_negatives = ground_truth_defective - predicted_defective_set
    for row, col in false_negatives:
        draw_rectangle(img_copy, row, col, colors['FN'], grid_thickness)
    return img_copy
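
# Illustrative usage of the function above (values are hypothetical): on a
# 512x512 image with patch_size=256 the patch grid is 2x2. If patches
# {(0, 0), (0, 1)} are predicted defective and the ground truth set is
# {(0, 0), (1, 1)}, then (0, 0) is drawn green (TP), (0, 1) red (FP) and
# (1, 1) blue (FN):
#
#   marked = draw_patch_rectangles_on_image(
#       base_img=img,                              # np.uint8 array, (512, 512, 3)
#       predicted_defective_set={(0, 0), (0, 1)},
#       ground_truth_defective={(0, 0), (1, 1)},
#       overlapping={(0, 0)},                      # intersection of the two sets
#       patch_size=256,
#   )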

def create_anomaly_visualization(anomaly_map: np.ndarray, is_binary: bool = False,
                                 threshold: float = 5.0, add_grid: bool = True,
                                 patch_size: int = 256) -> np.ndarray:
    """Create an anomaly map visualization."""
    if is_binary:
        # Binary visualization: threshold and render as a black/white image.
        # Auto-detect the anomaly map range and normalize the threshold accordingly.
        amap_max = np.max(anomaly_map)
        if amap_max <= 1.0:
            # Anomaly map is in the 0-1 range: normalize threshold from 0-255 to 0-1
            normalized_threshold = threshold / 255.0
        else:
            # Anomaly map is in the 0-255 range: use threshold as-is
            normalized_threshold = threshold
        binary_map = (anomaly_map > normalized_threshold).astype(np.uint8)
        # Create RGB image: black for 0 (normal), white for 1 (anomaly)
        vis_img = np.zeros((binary_map.shape[0], binary_map.shape[1], 3), dtype=np.uint8)
        vis_img[binary_map == 1] = [255, 255, 255]  # Set anomaly pixels to white
        return vis_img
    else:
        # Continuous visualization: normalize and apply a colormap
        normalized = np.clip(anomaly_map, 0, 1)
        # Convert to the 0-255 range and apply the jet colormap
        normalized_255 = (normalized * 255).astype(np.uint8)
        colored = cv2.applyColorMap(normalized_255, cv2.COLORMAP_JET)
        return cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

def create_anomaly_overlay(original_img: np.ndarray, anomaly_map: np.ndarray,
                           alpha: float = 0.8, is_binary: bool = False,
                           threshold: float = 5.0) -> np.ndarray:
    """Blend an anomaly visualization over the original image."""
    # Create the anomaly visualization
    anomaly_vis = create_anomaly_visualization(anomaly_map, is_binary, threshold, add_grid=False)
    # Ensure both images have the same dimensions
    h, w = original_img.shape[:2]
    if anomaly_vis.shape[:2] != (h, w):
        anomaly_vis = resize(anomaly_vis, (h, w), preserve_range=True, anti_aliasing=True).astype(np.uint8)
    # Blend: alpha weights the original image, (1 - alpha) the anomaly map
    overlay = cv2.addWeighted(original_img, alpha, anomaly_vis, 1 - alpha, 0)
    return overlay

def determine_image_status(patch_results: list) -> str:
    """Determine the overall image status from patch results."""
    status_counts = {'TP': 0, 'FN': 0, 'FP': 0, 'TN': 0}
    for result in patch_results:
        status = result.get('status', 'TN')
        if status in status_counts:
            status_counts[status] += 1
    # An image is positive if it has any defective patches (TP or FN)
    has_ground_truth_defects = status_counts['TP'] + status_counts['FN'] > 0
    has_predicted_defects = status_counts['TP'] + status_counts['FP'] > 0
    if has_ground_truth_defects and has_predicted_defects:
        return 'TP'  # Correctly identified as defective
    elif has_ground_truth_defects and not has_predicted_defects:
        return 'FN'  # Missed defective image
    elif not has_ground_truth_defects and has_predicted_defects:
        return 'FP'  # False alarm
    else:
        return 'TN'  # Correctly identified as normal
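
# The image-level decision rule above, worked through on patch statuses
# (illustrative lists; each entry stands for one patch's 'status' field):
#   ['TP', 'TN', 'TN'] -> 'TP'  (a defect exists and something was predicted)
#   ['FN', 'TN']       -> 'FN'  (a defect exists but nothing was predicted)
#   ['FP', 'TN']       -> 'FP'  (no defect, but something was predicted)
#   ['TN', 'TN']       -> 'TN'  (no defect, nothing predicted)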

def _generate_shifts(
    pad_px: int,
    stride: int,
    directions: tuple
) -> List[Tuple[int, int]]:
    """Generate shift coordinates based on parameters.
    Args:
        pad_px: Shift range (±pad_px pixels)
        stride: Shift increment (1=every pixel, 2=skip pixels)
        directions: ("h", "v", "diag") | ("all",) | ("rhombus",)
    Returns:
        Sorted list of (dx, dy) shift coordinates
    Number of shifts generated:
        - ("h", "v", "diag") with pad_px=4, stride=1: 33 shifts
        - ("h", "v", "diag") with pad_px=2, stride=1: 17 shifts
        - ("h", "v") with pad_px=4, stride=1: 17 shifts
        - ("all",) with pad_px=4, stride=1: 81 shifts
        - ("all",) with pad_px=4, stride=2: 25 shifts
    """
    shifts = set()
    if "all" in directions or directions == "all":
        for dx in range(-pad_px, pad_px + 1, stride):
            for dy in range(-pad_px, pad_px + 1, stride):
                shifts.add((dx, dy))
    elif "rhombus" in directions or directions == "rhombus":
        for dx in range(-pad_px, pad_px + 1, stride):
            for dy in range(-pad_px, pad_px + 1, stride):
                if abs(dx) + abs(dy) <= (pad_px + 1):
                    shifts.add((dx, dy))
    else:
        if "h" in directions:
            for dx in range(-pad_px, pad_px + 1, stride):
                shifts.add((dx, 0))
        if "v" in directions:
            for dy in range(-pad_px, pad_px + 1, stride):
                shifts.add((0, dy))
        if "diag" in directions:
            for d in range(-pad_px, pad_px + 1, stride):
                shifts.add((d, d))
                shifts.add((d, -d))
    # Always include the center (no-shift) position
    shifts.add((0, 0))
    return sorted(shifts)
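
# Worked example for the counts in the docstring above: with pad_px=2,
# stride=1, directions=("h", "v"),
#   _generate_shifts(2, 1, ("h", "v"))
#   == [(-2, 0), (-1, 0), (0, -2), (0, -1), (0, 0), (0, 1), (0, 2), (1, 0), (2, 0)]
# i.e. 9 shifts: each axis contributes len(range(-pad_px, pad_px + 1, stride))
# positions, and the set removes duplicates such as the shared (0, 0).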

def _get_shift_type_and_amount(dx: int, dy: int) -> Tuple[str, int]:
    """Determine shift type and amount from (dx, dy) coordinates.
    Args:
        dx: Horizontal shift in pixels
        dy: Vertical shift in pixels
    Returns:
        Tuple[str, int]: (shift_type, shift_amount)
            shift_type: 'none', 'h' (horizontal), 'v' (vertical), 'diag' (diagonal)
            shift_amount: Pixel offset amount
    Examples:
        >>> _get_shift_type_and_amount(0, 0)
        ('none', 0)
        >>> _get_shift_type_and_amount(3, 0)
        ('h', 3)
        >>> _get_shift_type_and_amount(0, -2)
        ('v', -2)
        >>> _get_shift_type_and_amount(2, 2)
        ('diag', 2)
        >>> _get_shift_type_and_amount(-3, 3)
        ('diag', 3)
    """
    if dx == 0 and dy == 0:
        return 'none', 0
    elif dy == 0:  # Pure horizontal
        return 'h', dx
    elif dx == 0:  # Pure vertical
        return 'v', dy
    else:  # Diagonal (dx != 0 and dy != 0)
        # Use the max absolute value as the amount, preserving sign
        max_abs = max(abs(dx), abs(dy))
        # The sign comes from the component with the larger absolute value
        if abs(dx) > abs(dy):
            sign = 1 if dx > 0 else -1
        else:
            sign = 1 if dy > 0 else -1
        return 'diag', sign * max_abs

def create_confusion_matrix_from_patch_results(
    patch_results_by_image: dict,
    output_dir: str,
    patch_size: int = 256
) -> None:
    """
    Create a confusion matrix visualization from patch results.
    Args:
        patch_results_by_image: Dict mapping image_path -> list of patch results.
            Each patch result has: grid_row, grid_col, anomaly_pixels, pred_score, status
        output_dir: Directory to save the confusion matrix plot and JSON
        patch_size: Size of patches for grid coordinate calculation
    """
    # Collect all patch results
    all_patch_results = []
    for image_path, patches in patch_results_by_image.items():
        all_patch_results.extend(patches)
    if not all_patch_results:
        print("No patch results provided for confusion matrix generation")
        return
    # Initialize confusion matrix counters
    all_TP = all_FP = all_FN = all_TN = 0
    print(f"Creating confusion matrix for {len(patch_results_by_image)} images...")
    # Count status values from patch results
    for patch in all_patch_results:
        status = patch.get('status', 'TN')
        if status == "TP":
            all_TP += 1
        elif status == "FP":
            all_FP += 1
        elif status == "FN":
            all_FN += 1
        elif status == "TN":
            all_TN += 1
    total = all_TP + all_FP + all_FN + all_TN
    accuracy = (all_TP + all_TN) / total if total > 0 else 0
    # Calculate additional metrics (f1_val avoids shadowing sklearn's f1_score)
    precision = all_TP / (all_TP + all_FP) if (all_TP + all_FP) > 0 else 0
    recall = all_TP / (all_TP + all_FN) if (all_TP + all_FN) > 0 else 0
    f1_val = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    print("Confusion Matrix (patch-level):")
    print(f"TP: {all_TP}, FP: {all_FP}, FN: {all_FN}, TN: {all_TN}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1_val:.4f}")
    # Create the confusion matrix visualization
    cm = np.array([[all_TP, all_FN], [all_FP, all_TN]])
    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap='Blues')
    plt.title('Confusion Matrix (Patch-level)', fontsize=16, fontweight='bold')
    plt.colorbar()
    # Add text annotations
    thresh = cm.max() / 2.
    for i in range(2):
        for j in range(2):
            plt.text(j, i, format(cm[i, j], 'd'),
                     ha="center", va="center",
                     color="white" if cm[i, j] > thresh else "black",
                     fontsize=14, fontweight='bold')
    # Set labels
    tick_marks = np.arange(2)
    plt.xticks(tick_marks, ['Defective', 'Normal'], fontsize=12)
    plt.yticks(tick_marks, ['Defective', 'Normal'], fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    # Add metrics text
    metrics_text = f'Accuracy: {accuracy:.4f}\nPrecision: {precision:.4f}\nRecall: {recall:.4f}\nF1-Score: {f1_val:.4f}'
    plt.figtext(0.02, 0.02, metrics_text, fontsize=10, bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))
    plt.tight_layout()
    # Save the confusion matrix plot
    os.makedirs(output_dir, exist_ok=True)
    cm_plot_path = os.path.join(output_dir, "confusion_matrix.png")
    plt.savefig(cm_plot_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Confusion matrix plot saved to: {cm_plot_path}")
    # Save detailed results to file
    result = {
        "TP": all_TP,
        "FP": all_FP,
        "FN": all_FN,
        "TN": all_TN,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_val,
        "total_patches": total
    }
    with open(os.path.join(output_dir, "confusion_matrix.json"), "w") as f:
        json.dump(result, f, indent=2)
    print(f"Confusion matrix results saved to: {os.path.join(output_dir, 'confusion_matrix.json')}")
# ============================================================================
# Checkpoint Manager
# ============================================================================
class CheckpointManager:
    """Manages checkpointing and organized image-saving functionality."""

    def __init__(self, results_dir: str, enable_checkpointing: bool = False):
        self.results_dir = Path(results_dir)
        self.enable_checkpointing = enable_checkpointing
        self.marked_images_dir = self.results_dir / "marked_images"
        self.evaluation_results_dir = self.results_dir / "evaluation_results"
        # Create base directories
        self.marked_images_dir.mkdir(parents=True, exist_ok=True)
        self.evaluation_results_dir.mkdir(parents=True, exist_ok=True)
        # Setup the organized folder structure
        self._setup_directory_structure()

    def _setup_directory_structure(self):
        """Create only the base directory structure; subfolders are created on demand."""
        # Only the base marked_images_dir and evaluation_results_dir are created up front;
        # subfolders are created on demand when files are actually saved.
        pass

    def get_status_folder(self, status: str) -> Path:
        """Get the path for a status-specific folder."""
        return self.marked_images_dir / status

    def get_image_level_folder(self) -> Path:
        """Get the path for the image_level folder."""
        return self.marked_images_dir / 'image_level'
# ============================================================================
# Anomaly Detector (Consolidated Class)
# ============================================================================
class AnomalyDetector:
    """
    Modern Anomaly Detector with Batched TTA.
    This class consolidates functionality from DecodiffEvaluator and
    DecodiffEvaluateProcessor into a single, efficient implementation.
    Key Features:
    - Batched TTA for ~6.6× speedup (Option B: Combined Batch)
    - Configurable shift parameters (pad_px, stride, directions)
    - Clean architecture without legacy code
    """

    def __init__(self, args):
        # Setup DDP-compatible attributes (for signal-handler compatibility).
        # AnomalyDetector doesn't use DDP, but we set these for consistency.
        self.ddp = False
        self.local_rank = 0
        self.is_main_process = True  # Always True for single-process evaluation
        # Shutdown flag checked by long-running loops (may be toggled by the
        # signal handlers installed below)
        self._shutdown_requested = False
        # Setup logging early (before other operations that might log)
        self.setup_logging(args)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cpu":
            self.log_message("⚠ GPU not found. Using CPU instead.", level='warning')
        # Set the AMP dtype for mixed precision
        self.amp_dtype = (
            torch.bfloat16
            if (self.device == "cuda" and torch.cuda.is_bf16_supported())
            else torch.float16
        )
        # Store difference-augmentation parameters from args.
        # clip_max is used to clip and then normalize (divide by clip_max),
        # keeping the output range at [0, 1.0].
        self.image_diff_clip_max = getattr(args, 'image_diff_clip_max', 1.0)
        self.latent_diff_clip_max = getattr(args, 'latent_diff_clip_max', 1.0)
        # Store anomaly detection parameters from args
        self.anomaly_threshold = getattr(args, 'anomaly_threshold', 20)
        self.anomaly_min_area = getattr(args, 'anomaly_min_area', 10)
        # Validate clip_max values at initialization
        if self.image_diff_clip_max == 0:
            raise ValueError("image_diff_clip_max cannot be 0 (division by zero)")
        if self.latent_diff_clip_max == 0:
            raise ValueError("latent_diff_clip_max cannot be 0 (division by zero)")
        torch.set_grad_enabled(False)
        # Store filename strategy configuration
        self.regression_test_mode = getattr(args, 'regression_test_mode', False)
        self.object_class = getattr(args, 'object_category', None)
        self.filename_strategy_name = getattr(args, 'filename_strategy', None)
        self.setup_model(args)
        # File saving control (can be disabled for performance testing)
        self._disable_file_saving = False  # Set to True to disable all file saving for performance testing
        self.save_plot_path = None if args.save_plot_path is None else Path(args.save_plot_path)
        if self.save_plot_path and not self._disable_file_saving:
            self.save_plot_path.mkdir(parents=True, exist_ok=True)
        # Initialize the checkpoint manager for image saving if enabled.
        # Note: CheckpointManager is used for saving evaluation images, so we enable it
        # when save_evaluation_images is True, regardless of the checkpoint_enabled setting.
        self.checkpoint_manager = None
        if args.save_evaluation_images and self.save_plot_path and not self._disable_file_saving:
            self.checkpoint_manager = CheckpointManager(
                results_dir=str(self.save_plot_path),
                enable_checkpointing=True  # Always enable for evaluation image saving
            )
        # Setup the save queue and worker thread if save_plot_path is set
        self.save_queue = None
        self.save_worker_thread = None
        if self.save_plot_path and not self._disable_file_saving:
            self._start_save_worker()
        # Setup signal handlers for graceful shutdown (using the utility function)
        setup_signal_handlers(
            self,
            ddp=self.ddp,
            is_main_process=self.is_main_process,
            log_message_func=self.log_message  # Use logging for AnomalyDetector
        )

    def setup_logging(self, args):
        """Setup logging with proper Unicode support (using the utility function)."""
        # Determine the log file path: use save_plot_path if available, otherwise the current directory
        if hasattr(args, 'save_plot_path') and args.save_plot_path:
            log_dir = Path(args.save_plot_path).parent
            log_file_path = log_dir / "log.txt"
        else:
            # Fallback to the current directory
            log_file_path = Path("log.txt")
        # Check for a debug flag in args
        debug_enabled = getattr(args, 'debug', False) or getattr(args, 'debug_enabled', False)
        # Use the utility function for logging setup
        self.logger = setup_logging(
            log_file_path=log_file_path,
            logger_name=__name__,
            is_main_process=self.is_main_process,  # Always True for AnomalyDetector
            debug_enabled=debug_enabled
        )
        # Store the debug flag for conditional timing logs
        self.debug_enabled = debug_enabled
        # Log the initialization message
        self.logger.info(f"Logging initialized - logs saved to: {log_file_path}")

    def log_message(self, message, level='info'):
        """
        Log a message to both console and file.
        Args:
            message: Message to log
            level: Log level ('info', 'warning', 'error', 'debug')
        """
        if level == 'info':
            self.logger.info(message)
        elif level == 'warning':
            self.logger.warning(message)
        elif level == 'error':
            self.logger.error(message)
        elif level == 'debug':
            self.logger.debug(message)
        else:
            self.logger.info(message)  # Default to info

    def setup_model(self, args):
        """Setup the model and VAE for evaluation."""
        # Load VAE
        models_path = "./models"
        models_config = os.path.join(models_path, "config.json")
        current_dir = os.getcwd()
        self.log_message(f"\n{'='*80}")
        self.log_message("[DEBUG] VAE Model Loading")
        self.log_message(f"{'='*80}")
        self.log_message(f"Current working directory: {current_dir}")
        self.log_message(f"Checking for local models at: {os.path.abspath(models_path)}")
        self.log_message(f"Looking for config at: {os.path.abspath(models_config)}")
        self.log_message(f"Config exists: {os.path.exists(models_config)}")
        # Check the current directory first
        if os.path.exists(models_config):
            self.log_message(f"✓ Found local VAE model, loading from: {os.path.abspath(models_path)}")
            vae_model = models_path
            self.vae = AutoencoderKL.from_pretrained(vae_model, local_files_only=True).to(self.device)
        else:
            # Fallback: check the parent directory
            parent_models_path = "../models"
            parent_models_config = os.path.join(parent_models_path, "config.json")
            self.log_message(f"Checking parent directory for models at: {os.path.abspath(parent_models_path)}")
            self.log_message(f"Looking for config at: {os.path.abspath(parent_models_config)}")
            self.log_message(f"Config exists: {os.path.exists(parent_models_config)}")
            if os.path.exists(parent_models_config):
                self.log_message(f"✓ Found local VAE model in parent directory, loading from: {os.path.abspath(parent_models_path)}")
                vae_model = parent_models_path
                self.vae = AutoencoderKL.from_pretrained(vae_model, local_files_only=True).to(self.device)
            else:
                self.log_message("✗ Local VAE not found in current or parent directory, downloading from HuggingFace")
                vae_model = f"stabilityai/sd-vae-ft-{args.vae_type}"
                self.log_message(f"Downloading: {vae_model}")
                self.vae = AutoencoderKL.from_pretrained(vae_model).to(self.device)
        self.log_message(f"{'='*80}\n")
        self.vae.eval()
        # Find the checkpoint
        try:
            if args.model_path != '':
                ckpt = args.model_path
            else:
                path = f"./DeCo-Diff_{args.dataset}_{args.object_category}_{args.model_size}_{args.crop_size}"
                try:
                    ckpt = sorted(glob(f'{path}/last.pt'))[-1]
                except IndexError:
                    ckpt = sorted(glob(f'{path}/*/last.pt'))[-1]
        except IndexError:
            raise Exception("Please provide the trained model's path using --model_path")
        # Setup the model
        latent_size = int(args.crop_size) // 8
        self.model = UNET_models[args.model_size](latent_size=latent_size, ncls=args.num_classes)
        # Load the checkpoint with automatic device mapping (CUDA if available, else CPU)
        map_location = self.device if torch.cuda.is_available() else 'cpu'
        checkpoint = torch.load(ckpt, weights_only=False, map_location=map_location)
        # Check if EMA weights are available and use them if present
        if 'ema_state_dict' in checkpoint:
            # The EMA state dict contains 'ema_model_state_dict' with the actual weights
            ema_state_dict = checkpoint['ema_state_dict']
            if 'ema_model_state_dict' in ema_state_dict:
                state_dict = ema_state_dict['ema_model_state_dict']
                self.log_message('Using EMA model weights for evaluation')
            else:
                # Fallback: assume ema_state_dict is the model state dict directly
                state_dict = ema_state_dict
                self.log_message('Using EMA model weights for evaluation (direct format)')
        else:
            # Fall back to regular model weights
            state_dict = checkpoint['model']
            self.log_message('Using regular model weights (no EMA found in checkpoint)')
        load_result = self.model.load_state_dict(state_dict)
        self.log_message(str(load_result))
        self.model.eval()
        self.model.to(self.device)
        self.log_message('✓ Model loaded')

    def _start_save_worker(self):
        """Start a background worker thread for concurrent image saving."""
        from queue import Queue
        import threading
        self.save_queue = Queue()

        def worker():
            """Background thread that processes image save requests from the queue."""
            while True:
                try:
                    item = self.save_queue.get()
                    if item is None:  # Sentinel value to stop the worker
                        break
                    kind = item.get('kind', 'array')
                    if kind == 'array':
                        filepath = item['filepath']
                        image_array = item['image_array']
                        Path(filepath).parent.mkdir(parents=True, exist_ok=True)
                        PILImage.fromarray(image_array, mode='L').save(filepath)
                    elif kind == 'plot':
                        # Handle plot saving in the background thread.
                        # plot_shift_analysis expects torch tensors for orig_imgs.
                        orig_imgs = item['orig_imgs']
                        if isinstance(orig_imgs, np.ndarray):
                            orig_imgs = torch.from_numpy(orig_imgs).float()
                        # averaged_anomaly_map can stay as numpy (it is used directly
                        # in imshow, not passed to rescale_range)
                        averaged_anomaly_map = item['averaged_anomaly_map']
                        # Convert any numpy arrays in the shift data back to tensors
                        # (for crop_original/crop_reconstruction)
                        individual_shifts = item['individual_shifts']
                        for shift_data in individual_shifts:
                            if 'crop_original' in shift_data and shift_data['crop_original'] is not None:
                                if isinstance(shift_data['crop_original'], np.ndarray):
                                    shift_data['crop_original'] = torch.from_numpy(shift_data['crop_original']).float()
                            if 'crop_reconstruction' in shift_data and shift_data['crop_reconstruction'] is not None:
                                if isinstance(shift_data['crop_reconstruction'], np.ndarray):
                                    shift_data['crop_reconstruction'] = torch.from_numpy(shift_data['crop_reconstruction']).float()
                        PR.plot_shift_analysis(
                            individual_shifts=individual_shifts,
                            averaged_anomaly_map=averaged_anomaly_map,
                            batch_idx=item['batch_idx'],
                            orig_imgs=orig_imgs,
                            save_path=item['save_path'],
                            max_shifts_to_show=item.get('max_shifts_to_show', 16),
                            fuse_method=item.get('fuse_method', 'mean')
                        )
                    else:
                        print(f"Unknown save job kind: {kind}")
                except Exception as e:
                    print(f"Error in save worker: {e}")
                    import traceback
                    traceback.print_exc()
                finally:
                    try:
                        self.save_queue.task_done()
                    except Exception as e:
                        print(f"Error calling task_done: {e}")

        self.save_worker_thread = threading.Thread(target=worker, daemon=True)
        self.save_worker_thread.start()
        self.log_message("Started background image save worker thread")

    def wait_for_saves(self, timeout=300):
        """
        Wait for all queued save operations to complete.
        Args:
            timeout: Maximum time to wait in seconds (default: 5 minutes)
        """
        if not self.save_queue or not self.is_save_worker_alive():
            return
        start_time = time.time()
        queue_size = self.get_save_queue_size()
        if queue_size > 0:
            print(f"Waiting for {queue_size} save operations to complete...")
        # Wait for the queue to be empty
        while not self.save_queue.empty():
            if self._shutdown_requested:
                print("Shutdown requested, stopping wait for saves...")
                break
            if time.time() - start_time > timeout:
                print(f"Warning: Save operations timed out after {timeout} seconds")
                print(f"Remaining queue size: {self.get_save_queue_size()}")
                break
            time.sleep(0.1)  # Small delay to prevent busy waiting
        # Wait for all tasks to complete
        try:
            self.save_queue.join()
            if queue_size > 0:
                print("All save operations completed successfully")
        except Exception as e:
            print(f"Error waiting for saves to complete: {e}")

    def is_save_worker_alive(self):
        """Check whether the save worker thread is still alive."""
        return self.save_worker_thread and self.save_worker_thread.is_alive()

    def get_save_queue_size(self):
        """Get the current size of the save queue."""
        return self.save_queue.qsize() if self.save_queue else 0

    def shutdown_save_worker(self):
        """Stop the save worker thread with an adaptive timeout based on queue size."""
        if self.save_queue and self.save_worker_thread:
            queue_size = self.get_save_queue_size()
            # Adaptive timeout: 5s base + 0.1s per queued item (min 10s).
            # Example: 200 items -> 25s timeout, 500 items -> 55s timeout.
            timeout = max(10, queue_size * 0.1 + 5)
            if queue_size > 0:
                self.log_message(f"Stopping save worker ({queue_size} items remaining, timeout={timeout:.0f}s)...")
            self.save_queue.put(None)  # Send the sentinel to stop the worker
            self.save_worker_thread.join(timeout=timeout)
            if self.save_worker_thread.is_alive():
                remaining = self.get_save_queue_size()
                self.log_message(f"⚠️ Warning: Save worker timeout after {timeout:.0f}s ({remaining} items not saved)", level='warning')
            else:
                if queue_size > 0:
                    self.log_message(f"✓ Save worker stopped ({queue_size} items saved)")
                else:
                    self.log_message("Save worker thread stopped")

    def compute_pro(self, masks, amaps, num_th=200):
        """Compute the area under the per-region overlap (PRO) curve over the 0 to 0.3 FPR range."""
        assert isinstance(amaps, np.ndarray), "type(amaps) must be ndarray"
        assert isinstance(masks, np.ndarray), "type(masks) must be ndarray"
        assert amaps.ndim == 3, "amaps.ndim must be 3 (num_test_data, h, w)"
        assert masks.ndim == 3, "masks.ndim must be 3 (num_test_data, h, w)"
        assert amaps.shape == masks.shape, "amaps.shape and masks.shape must be same"
        unique_values = set(masks.flatten())
        assert unique_values.issubset({0, 1, 0.0, 1.0}), f"set(masks.flatten()) must be subset of {{0, 1}}, got {unique_values}"
        assert isinstance(num_th, int), "type(num_th) must be int"
        df = pd.DataFrame([], columns=["pro", "fpr", "threshold"])
        binary_amaps = np.zeros_like(amaps, dtype=bool)
        min_th = amaps.min()
        max_th = amaps.max()
        # Handle the case where all anomaly map values are the same
        if min_th == max_th:
            # Return 0.0: the PRO integral over the restricted FPR range degenerates here
            return 0.0
        # Use linspace for better numerical stability than arange
        # (num_th thresholds from min_th up to, but excluding, max_th)
        thresholds = np.linspace(min_th, max_th, num_th, endpoint=False)
        for th in thresholds:
            binary_amaps[amaps <= th] = 0
            binary_amaps[amaps > th] = 1
            pros = []
            for binary_amap, mask in zip(binary_amaps, masks):
                for region in measure.regionprops(measure.label(mask)):
                    axes0_ids = region.coords[:, 0]
                    axes1_ids = region.coords[:, 1]
                    tp_pixels = binary_amap[axes0_ids, axes1_ids].sum()
                    pros.append(tp_pixels / region.area)
            inverse_masks = 1 - masks
            fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
            fpr = fp_pixels / inverse_masks.sum()
            df.loc[len(df)] = [np.mean(pros) if len(pros) > 0 else 0, fpr, th]
        # Keep thresholds with FPR below 0.3, then rescale that FPR range to [0, 1]
        df = df[df["fpr"] < 0.3]
        if df.empty:
            return 0.0
        df["fpr"] = df["fpr"] / df["fpr"].max()
        # Calculate the area under the PRO-FPR curve
        pro_auc = auc(df["fpr"], df["pro"])
        return pro_auc
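
    # PRO at a single threshold, by hand (toy numbers): for one ground-truth
    # region of area 4 px, of which the binarized map covers 3 px, that
    # region's PRO is 3/4 = 0.75. The per-threshold PRO is the mean over all
    # regions, and FPR is (false-positive pixels) / (normal pixels).
    # compute_pro sweeps num_th such thresholds and integrates PRO over the
    # FPR range [0, 0.3], rescaled to [0, 1].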

    def compute_pro_parallel(self, masks, amaps, num_th=200, n_jobs=-1):
        """Phase 5: Parallel PRO curve computation (1.3× faster).
        Computes the per-region overlap (PRO) curve using parallel processing across thresholds.
        Args:
            masks: Ground truth masks [N, H, W]
            amaps: Anomaly maps [N, H, W]
            num_th: Number of thresholds (default: 200)
            n_jobs: Number of parallel jobs (-1 = all cores, default: -1)
        Returns:
            float: Area under the PRO-FPR curve (0 to 0.3 FPR range)
        """
        from joblib import Parallel, delayed
        assert isinstance(amaps, np.ndarray), "type(amaps) must be ndarray"
        assert isinstance(masks, np.ndarray), "type(masks) must be ndarray"
        assert amaps.ndim == 3, "amaps.ndim must be 3 (num_test_data, h, w)"
        assert masks.ndim == 3, "masks.ndim must be 3 (num_test_data, h, w)"
        assert amaps.shape == masks.shape, "amaps.shape and masks.shape must be same"
        unique_values = set(masks.flatten())
        assert unique_values.issubset({0, 1, 0.0, 1.0}), f"set(masks.flatten()) must be subset of {{0, 1}}, got {unique_values}"
        assert isinstance(num_th, int), "type(num_th) must be int"
        min_th = amaps.min()
        max_th = amaps.max()
        # Handle the case where all anomaly map values are the same
        if min_th == max_th:
            return 0.0
        thresholds = np.linspace(min_th, max_th, num_th, endpoint=False)

        def _process_threshold(th):
            """Process a single threshold to compute PRO and FPR."""
            # Binarize the anomaly maps
            binary_amaps = (amaps > th).astype(bool)
            # Compute PRO for each region
            pros = []
            for binary_amap, mask in zip(binary_amaps, masks):
                for region in measure.regionprops(measure.label(mask)):
                    axes0_ids = region.coords[:, 0]
                    axes1_ids = region.coords[:, 1]
                    tp_pixels = binary_amap[axes0_ids, axes1_ids].sum()
                    pros.append(tp_pixels / region.area)
            # Compute FPR
            inverse_masks = 1 - masks
            fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
            fpr = fp_pixels / inverse_masks.sum()
            return {
                'pro': np.mean(pros) if len(pros) > 0 else 0,
                'fpr': fpr,
                'threshold': th
            }

        # Parallel processing across thresholds
        results = Parallel(n_jobs=n_jobs, prefer='threads')(
            delayed(_process_threshold)(th) for th in thresholds
        )
        # Convert to a DataFrame
        df = pd.DataFrame(results)
        # Keep thresholds with FPR below 0.3, then rescale that FPR range to [0, 1]
        df = df[df["fpr"] < 0.3]
        if df.empty:
            return 0.0
        df["fpr"] = df["fpr"] / df["fpr"].max()
        # Calculate the area under the PRO-FPR curve
        pro_auc = auc(df["fpr"], df["pro"])
        return pro_auc

    def calculate_metrics(self, masks, amaps, parallel_pro=True):
        """Calculate AUROC, AUPRO, F1, and AP metrics at pixel and sample level.
        Args:
            masks: Ground truth masks
            amaps: Anomaly maps
            parallel_pro: Use parallel PRO computation (default: True)
        """
        # Flatten to 1-D for sklearn metrics
        masks_flat = masks.flatten()
        amaps_flat = amaps.flatten()
        # Pixel-level metrics.
        # Use sklearn roc_auc_score instead of anomalib (API compatibility).
        if len(set(masks_flat)) > 1:  # Need both classes
            auroc_score = roc_auc_score(masks_flat, amaps_flat)
        else:
            auroc_score = 0.0
        # F1 score (pixel-level), using the median threshold
        threshold = np.median(amaps_flat)
        pred_labels = (amaps_flat > threshold).astype(int)
        if len(set(masks_flat)) > 1:
            f1_score_val = f1_score(masks_flat, pred_labels)
        else:
            f1_score_val = 0.0
        ap_score = average_precision_score(masks_flat, amaps_flat) if len(set(masks_flat)) > 1 else 0.0
        # Sample-level (image-level) metrics
        labels = []
        scores = []
        for i in range(len(masks)):
            labels.append(1 if masks[i].max() > 0 else 0)
            scores.append(amaps[i].max())
        if len(set(labels)) > 1:
            auroc_sp = roc_auc_score(labels, scores)
        else:
            auroc_sp = 0.0
        ap_sp = average_precision_score(labels, scores) if len(set(labels)) > 1 else 0.0
        # F1 score (sample-level), using the median threshold
        threshold = np.median(scores)
        pred_labels = [1 if s > threshold else 0 for s in scores]
        if len(set(labels)) > 1:
            f1_sp = f1_score(labels, pred_labels)
        else:
            f1_sp = 0.0
        # AUPRO: use the parallel version if enabled
        if parallel_pro:
            aupro_score = self.compute_pro_parallel(masks, amaps)
        else:
            aupro_score = self.compute_pro(masks, amaps)
        return auroc_score, aupro_score, f1_score_val, ap_score, auroc_sp, ap_sp, f1_sp

    def smooth_mask(self, mask, sigma=3):
        """Apply Gaussian smoothing to a mask."""
        return gaussian_filter(mask, sigma=sigma)

    def calculate_anomaly_maps_batch_gpu(self, x0, encoded, image_samples, latent_samples, crop_size=256,
                                         image_diff_clip_max=1.0, latent_diff_clip_max=1.0):
        """Phase 3: GPU-accelerated anomaly map calculation (1.5-2× faster).
        Keeps all tensors on the GPU and uses torch.nn.functional.interpolate for resizing.
        Returns torch tensors instead of numpy arrays for downstream GPU processing.
        Args:
            x0: Original decoded images [B, C, H, W] (GPU tensor)
            encoded: Encoded latents [B, C, H', W'] (GPU tensor)
            image_samples: Reconstructed images [B, C, H, W] (GPU tensor)
            latent_samples: Reconstructed latents [B, C, H', W'] (GPU tensor)
            crop_size: Target size for latent resizing
            image_diff_clip_max: Clipping threshold for the image difference
            latent_diff_clip_max: Clipping threshold for the latent difference
        Returns:
            dict: Anomaly maps as GPU tensors [B, H, W]
        """
        # Validate clip_max values
        if image_diff_clip_max == 0:
            raise ValueError("image_diff_clip_max cannot be 0")
        if latent_diff_clip_max == 0:
            raise ValueError("latent_diff_clip_max cannot be 0")
        # Image difference (kept on GPU)
        image_difference = torch.abs(image_samples - x0).float().mean(dim=1)  # [B, H, W]
        image_difference = torch.clamp(image_difference, 0.0, image_diff_clip_max) / image_diff_clip_max
        # Latent difference with GPU resize
        latent_difference = torch.abs(latent_samples - encoded).float().mean(dim=1)  # [B, H', W']
        latent_difference = torch.clamp(latent_difference, 0.0, latent_diff_clip_max) / latent_diff_clip_max
        # Resize the latent map to match the image size using GPU interpolation
        latent_difference = F.interpolate(
            latent_difference.unsqueeze(1),  # [B, 1, H', W']
            size=(crop_size, crop_size),
            mode='bilinear',
            align_corners=False
        ).squeeze(1)  # [B, H, W]
        # Compute anomaly maps (all GPU ops)
        pred_geometric = torch.sqrt(image_difference * latent_difference)
        pred_arithmetic = 0.5 * image_difference + 0.5 * latent_difference
        # Return GPU tensors (no .cpu().numpy() conversion!)
        return {
            'anomaly_geometric': pred_geometric,      # [B, H, W] torch tensor on GPU
            'anomaly_arithmetic': pred_arithmetic,    # [B, H, W] torch tensor on GPU
            'latent_discrepancy': latent_difference,  # [B, H, W] torch tensor on GPU
            'image_discrepancy': image_difference,    # [B, H, W] torch tensor on GPU
        }
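
    # The fusion arithmetic above, by hand (toy values): with a normalized
    # image discrepancy of 0.16 and latent discrepancy of 0.04 at some pixel,
    #   geometric  = sqrt(0.16 * 0.04)        = 0.08
    #   arithmetic = 0.5 * 0.16 + 0.5 * 0.04  = 0.10
    # The geometric mean suppresses pixels where only one of the two cues
    # fires; the arithmetic mean keeps them.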

    def calculate_anomaly_maps_batch(self, x0, encoded, image_samples, latent_samples, crop_size=256,
                                     image_diff_clip_max=1.0, latent_diff_clip_max=1.0):
        """Legacy CPU-based anomaly map calculation.
        Converts to numpy arrays immediately. Kept for backward compatibility.
        Args:
            x0: Original decoded images
            encoded: Encoded latents
            image_samples: Reconstructed images from diffusion
            latent_samples: Reconstructed latents from diffusion
            crop_size: Crop size for resizing
            image_diff_clip_max: Maximum value to clip the image difference to (scale = 1.0 / clip_max)
            latent_diff_clip_max: Maximum value to clip the latent difference to (scale = 1.0 / clip_max)
        """
        B = x0.shape[0]
        # Validate clip_max values
        if image_diff_clip_max == 0:
            raise ValueError("image_diff_clip_max cannot be 0 (division by zero)")
        if latent_diff_clip_max == 0:
            raise ValueError("latent_diff_clip_max cannot be 0 (division by zero)")
        # Absolute differences (for anomaly detection)
        image_difference_raw = torch.abs(image_samples - x0).float().mean(dim=1)
        image_difference = image_difference_raw.detach().cpu().numpy()
        # Clip, then normalize by dividing by clip_max (keeps the output range at [0, 1.0]).
        # Example: clip_max=0.4 -> np.clip(image_difference, 0.0, 0.4) / 0.4
        image_differences = np.clip(image_difference, 0.0, image_diff_clip_max) / image_diff_clip_max
        latent_difference = torch.abs(latent_samples - encoded).float().mean(dim=1)
        latent_difference = latent_difference.detach().cpu().numpy()
        # Clip, then normalize by dividing by clip_max (keeps the output range at [0, 1.0]).
        # Example: clip_max=0.2 -> np.clip(latent_difference, 0.0, 0.2) / 0.2
        latent_difference = np.clip(latent_difference, 0.0, latent_diff_clip_max) / latent_diff_clip_max
        latent_differences = []
        for i in range(B):
            resized = resize(latent_difference[i], (crop_size, crop_size))
            latent_differences.append(resized)
        latent_differences = np.stack(latent_differences, axis=0)
        # Signed differences (for grayscale visualization: raw values, clipped to [-1, 1]).
        # Compute on the GPU, then move to the CPU only when needed.
        image_signed_diff = (image_samples - x0).float().mean(dim=1)
        # Clip to [-1, 1] on the GPU (images are in [-1, 1], so the difference can be in [-2, 2])
        image_signed_diff = torch.clamp(image_signed_diff, -1.0, 1.0).detach().cpu().numpy()
        latent_signed_diff = (latent_samples - encoded).float().mean(dim=1)
        # Clip on the GPU, then resize on the CPU
        latent_signed_diff = torch.clamp(latent_signed_diff, -1.0, 1.0).detach().cpu().numpy()
        latent_signed_diffs = []
        for i in range(B):
            resized = resize(latent_signed_diff[i], (crop_size, crop_size))
            latent_signed_diffs.append(resized)
        latent_signed_diffs = np.stack(latent_signed_diffs, axis=0)
        pred_geometric = np.sqrt(image_differences * latent_differences)
        pred_arithmetic = 0.5 * image_differences + 0.5 * latent_differences
        return {
            'anomaly_geometric': pred_geometric,
            'anomaly_arithmetic': pred_arithmetic,
            'latent_discrepancy': latent_differences,
            'image_discrepancy': image_differences,
            'image_signed_diff': image_signed_diff,    # For grayscale: -1 to 1 range
            'latent_signed_diff': latent_signed_diffs  # For grayscale: -1 to 1 range
        }
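
    # Clip-and-normalize, by hand (toy values): with image_diff_clip_max=0.4,
    # raw mean differences [0.1, 0.4, 0.6] become
    #   np.clip([0.1, 0.4, 0.6], 0.0, 0.4) / 0.4 == [0.25, 1.0, 1.0]
    # so every discrepancy map stays in [0, 1] regardless of the chosen clip_max.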

    def _anomaly_shift_avg_combined_batch(
        self,
        x_batch: torch.Tensor,
        *,
        model_kwargs: dict,
        diffusion,
        reverse_steps: int,
        pad_px: int = 4,
        img_enlarge_px: int = 4,
        stride: int = 1,
        directions: tuple = ("h", "v", "diag"),
        shift_method: str = 'interpolate',
        fuse_method: str = 'mean',
        tta_batch_size: int = 8,
        eta: float = 0.0,
        crop_size: int = 256,
    ) -> Tuple[Dict[str, np.ndarray], List[List[Dict]]]:
        """Option B: Combined Batch TTA for maximum speedup.
        Processes B DataLoader images × N shifts in combined batches of size tta_batch_size.
        This batches BOTH the image dimension AND the shift dimension together.
        Args:
            x_batch: Input images [B, 3, H, W] from the DataLoader
            model_kwargs: Model parameters (context, mask)
            diffusion: Diffusion model
            reverse_steps: Number of DDIM steps
            pad_px: Shift range (±pad_px pixels)
            img_enlarge_px: Image enlargement for interpolation
            stride: Shift increment
            directions: Shift directions to use
            shift_method: 'interpolate' or 'mirror'
            fuse_method: How to aggregate shifts ('mean', 'median', 'lowest', 'pct25', 'pct75')
            tta_batch_size: Number of shifts to process together per batch
            eta: DDIM eta parameter (0=deterministic)
            crop_size: Final crop size
        Returns:
            Tuple of:
            - Dictionary with keys:
                - 'anomaly_geometric': [B, H, W]
                - 'anomaly_arithmetic': [B, H, W]
                - 'latent_discrepancy': [B, H, W]
                - 'image_discrepancy': [B, H, W]
            - List[List[Dict]]: individual_shifts[img_idx] = list of shift dicts for plotting
        """
        device = self.device
        B, C, H, W = x_batch.shape
        img_size = H
        # Validate pad_px
        if pad_px > img_enlarge_px:
            pad_px = img_enlarge_px
            print(f"⚠ Warning: pad_px clamped to img_enlarge_px={img_enlarge_px}")
        # Generate all shifts
        shifts = _generate_shifts(pad_px, stride, directions)
        num_shifts = len(shifts)
        # Use print() for high-frequency progress messages to avoid file I/O overhead
        print(f"Processing {B} images × {num_shifts} shifts in batches of {tta_batch_size}")
        # Prepare the upsampled/padded input
        if shift_method == 'interpolate':
            aug_size = img_size + 2 * img_enlarge_px
            offset = img_enlarge_px
            x_up = F.interpolate(
                x_batch, size=(aug_size, aug_size),
                mode='bilinear', align_corners=False
            ).to(device)
        elif shift_method == 'mirror':
            aug_size = img_size + 2 * pad_px
            offset = pad_px
            x_up = F.pad(
                x_batch, (pad_px, pad_px, pad_px, pad_px),
                mode='reflect'
            ).to(device)
        else:
            raise ValueError(f"Unknown shift_method: {shift_method}. Choose 'interpolate' or 'mirror'")
        # Storage for all images × all shifts
        keys = ['anomaly_geometric', 'anomaly_arithmetic', 'latent_discrepancy', 'image_discrepancy']
        signed_keys = ['image_signed_diff', 'latent_signed_diff']
        # Structure: all_maps[key][img_idx][shift_idx] = canvas
        all_maps = {k: [[None for _ in range(num_shifts)] for _ in range(B)] for k in keys}
        all_signed_maps = {k: [[None for _ in range(num_shifts)] for _ in range(B)] for k in signed_keys}
        # Storage for individual shifts (for plot_shift_analysis).
        # Structure: individual_shifts[img_idx] = list of shift dicts
        individual_shifts = [[] for _ in range(B)]
        # Process shifts in mini-batches
        num_batches = (num_shifts + tta_batch_size - 1) // tta_batch_size
        for batch_idx in range(num_batches):
            batch_start_time = time.time()
            # Check for a shutdown request
            if self._shutdown_requested:
                print(f"\n⚠️ TTA processing interrupted at batch {batch_idx}/{num_batches}. Shutting down...")
                break
            batch_start = batch_idx * tta_batch_size
            batch_end = min(batch_start + tta_batch_size, num_shifts)
            batch_shifts = shifts[batch_start:batch_end]
            print(f"  Batch {batch_idx+1}/{num_batches}: processing shifts {batch_start}-{batch_end-1}")
            # Create the combined batch: [B × len(batch_shifts), 3, H, W]
            prep_start = time.time()
            combined_crops = []
            metadata = []  # Tracks (img_idx, shift_idx_in_full_list, shift, position)
            for img_idx in range(B):
                for shift_local_idx, (dx, dy) in enumerate(batch_shifts):
                    shift_global_idx = batch_start + shift_local_idx
                    y0 = offset + dy
                    x0 = offset + dx
                    y1 = y0 + img_size
                    x1 = x0 + img_size
                    # Skip invalid crops
                    if not (0 <= y0 < y1 <= aug_size and 0 <= x0 < x1 <= aug_size):
                        continue
                    crop = x_up[img_idx:img_idx+1, :, y0:y1, x0:x1]
                    combined_crops.append(crop)
                    metadata.append({
                        'img_idx': img_idx,
                        'shift_idx': shift_global_idx,
                        'shift': (dx, dy),
                        'position': (x0, y0, x1, y1)
                    })
            if not combined_crops:
                continue
            # Stack: [B*tta_batch_size, 3, H, W]
            combined_tensor = torch.cat(combined_crops, dim=0)
            prep_time = time.time() - prep_start
            if self.debug_enabled:
                self.log_message(f"  [TIMING] Batch prep: {prep_time:.3f}s", level='debug')
            # Expand the context tensor to match the combined batch size:
            # each crop needs its corresponding category index.
            context_start = time.time()
            combined_batch_size = combined_tensor.shape[0]
            if 'context' in model_kwargs and model_kwargs['context'] is not None:
                # Get the original context, shape [B, 1]
                original_context = model_kwargs['context']
                if not isinstance(original_context, torch.Tensor):
                    raise TypeError(f"original_context must be Tensor, got {type(original_context)}")
                # Build the expanded context from metadata so skipped crops are handled
                # correctly; metadata contains the img_idx of each crop in the combined batch.
                expanded_context_list = []
                for meta in metadata:
                    img_idx = meta['img_idx']
                    # Validate that img_idx is within bounds
                    if img_idx >= original_context.shape[0]:
                        raise IndexError(f"img_idx {img_idx} >= original_context.shape[0] {original_context.shape[0]}")
                    # Get the context for this image
                    img_context = original_context[img_idx:img_idx+1]  # [1, 1]
                    expanded_context_list.append(img_context)
                if expanded_context_list:
                    expanded_context = torch.cat(expanded_context_list, dim=0)  # [combined_batch_size, 1]
                else:
                    # Fallback if there is no metadata (shouldn't happen, but be safe)
                    num_shifts_in_batch = len(batch_shifts)
                    expanded_context = original_context.repeat_interleave(num_shifts_in_batch, dim=0)
                    if expanded_context.shape[0] != combined_batch_size:
                        expanded_context = expanded_context[:combined_batch_size]
                # Validate the expanded context
                if expanded_context.shape[0] != combined_batch_size:
                    raise ValueError(f"expanded_context.shape[0] {expanded_context.shape[0]} != combined_batch_size {combined_batch_size}")
                model_kwargs_batch = {
                    'context': expanded_context,
                    'mask': model_kwargs.get('mask', None)
                }
            else:
                # A missing context is an error for class-conditional models
                raise ValueError(f"model_kwargs['context'] is missing or None in batch {batch_idx} - required for class-conditional model")
            context_time = time.time() - context_start
            if self.debug_enabled:
                self.log_message(f"  [TIMING] Context expansion: {context_time:.3f}s", level='debug')
            # Process the entire combined batch at once (FAST!)
            gpu_start = time.time()
            with torch.no_grad():
                encode_start = time.time()
                encoded = self.vae.encode(combined_tensor).latent_dist.mean.mul_(0.18215)
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                encode_time = time.time() - encode_start
                if self.debug_enabled:
                    self.log_message(f"  [TIMING] VAE encode: {encode_time:.3f}s", level='debug')
                diffusion_start = time.time()
                latent_samples = diffusion.ddim_deviation_sample_loop(
                    self.model, encoded.shape, noise=encoded,
                    clip_denoised=False, start_t=reverse_steps,
                    model_kwargs=model_kwargs_batch, progress=False,
                    device=device, eta=eta
                )
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                diffusion_time = time.time() - diffusion_start
                if self.debug_enabled:
                    self.log_message(f"  [TIMING] Diffusion sampling: {diffusion_time:.3f}s", level='debug')
                decode_start = time.time()
                image_samples = self.vae.decode(latent_samples / 0.18215).sample
                x0_decoded = self.vae.decode(encoded / 0.18215).sample
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                decode_time = time.time() - decode_start
                if self.debug_enabled:
                    self.log_message(f"  [TIMING] VAE decode: {decode_time:.3f}s", level='debug')
            gpu_time = time.time() - gpu_start
            if self.debug_enabled:
                self.log_message(f"  [TIMING] Total GPU: {gpu_time:.3f}s", level='debug')
            # Calculate anomaly maps for the entire batch
            anomaly_start = time.time()
            anomaly_batch = self.calculate_anomaly_maps_batch(
                x0_decoded, encoded, image_samples, latent_samples,
                crop_size=img_size,
                image_diff_clip_max=self.image_diff_clip_max,
                latent_diff_clip_max=self.latent_diff_clip_max
            )
            anomaly_time = time.time() - anomaly_start
            if self.debug_enabled:
                self.log_message(f"  [TIMING] Anomaly map calc: {anomaly_time:.3f}s", level='debug')
            # Distribute results back to each image
            distribute_start = time.time()
            for i, meta in enumerate(metadata):
                img_idx = meta['img_idx']
                shift_idx = meta['shift_idx']
                x0, y0, x1, y1 = meta['position']
                # Place results on a canvas for this image/shift
                for key in keys:
                    canvas = np.full((aug_size, aug_size), np.nan, dtype=np.float32)
                    canvas[y0:y1, x0:x1] = anomaly_batch[key][i]
                    all_maps[key][img_idx][shift_idx] = canvas
                # Place signed differences on a canvas
                for key in signed_keys:
                    if key in anomaly_batch:
                        canvas = np.full((aug_size, aug_size), np.nan, dtype=np.float32)
                        canvas[y0:y1, x0:x1] = anomaly_batch[key][i]
                        all_signed_maps[key][img_idx][shift_idx] = canvas
                # Store individual shift data for plot_shift_analysis.
                # Note: combined_crops[i] is the crop at index i in the combined batch.
                if i < len(combined_crops):
                    crop_original = combined_crops[i].detach().cpu()
                else:
                    crop_original = None
                if i < image_samples.shape[0]:
                    crop_reconstruction = image_samples[i:i+1].detach().cpu()
                else:
                    crop_reconstruction = None
                individual_shifts[img_idx].append({
                    'shift': meta['shift'],  # Use the shift from metadata, not the loop-local (dx, dy)
                    'position': (x0, y0, x1, y1),
                    'arithmetic_combined': anomaly_batch['anomaly_arithmetic'][i].copy(),
                    'crop_original': crop_original,
                    'crop_reconstruction': crop_reconstruction,
                    'shift_method': shift_method
                })
            distribute_time = time.time() - distribute_start
            if self.debug_enabled:
                self.log_message(f"  [TIMING] Result distribution: {distribute_time:.3f}s", level='debug')
            batch_total_time = time.time() - batch_start_time
            if self.debug_enabled:
                self.log_message(f"  [TIMING] Batch {batch_idx+1} total: {batch_total_time:.3f}s", level='debug')
        # Fuse the shifts for each image
        fuse_start = time.time()
        final_anomaly_maps = {}
        # Optimization: batch the resize operations on the GPU when using 'interpolate'
        if shift_method == 'interpolate':
            # Collect all averaged maps first, then batch-resize on the GPU
            all_averaged_maps = {}  # key -> list of [H, W] arrays
            for key in keys:
                all_averaged_maps[key] = []
                for img_idx in range(B):
                    # Stack all shifts for this image (filter out None entries)
                    valid_maps = [m for m in all_maps[key][img_idx] if m is not None]
                    if not valid_maps:
                        # No valid shifts (shouldn't happen)
                        all_averaged_maps[key].append(np.zeros((aug_size, aug_size), dtype=np.float32))
                        continue
                    stack = np.stack(valid_maps, axis=-1)  # [H, W, num_valid_shifts]
                    # Fuse across shifts using the optimized partition-based method
                    averaged = self._fuse_array(stack, fuse_method, axis=-1)
                    all_averaged_maps[key].append(averaged)
            # Batch-resize all maps on the GPU at once (much faster)
            for key in keys:
                # Stack all averaged maps: [B, H, W]
                stacked_averaged = np.stack(all_averaged_maps[key], axis=0)
                # Convert to a tensor and resize on the GPU
                averaged_tensor = torch.from_numpy(stacked_averaged).float().unsqueeze(1).to(self.device)  # [B, 1, H, W]
                resized_tensor = F.interpolate(
                    averaged_tensor, size=(img_size, img_size),
                    mode='bilinear', align_corners=False
                )
                # Move back to the CPU and drop the channel dimension
                final_anomaly_maps[key] = resized_tensor.squeeze(1).cpu().numpy()  # [B, H, W]
        else:
            # Mirror method: just slice, no resize needed (fast path)
            for key in keys:
                final_maps = []
                for img_idx in range(B):
                    # Stack all shifts for this image (filter out None entries)
                    valid_maps = [m for m in all_maps[key][img_idx] if m is not None]
                    if not valid_maps:
                        # No valid shifts (shouldn't happen)
                        final_maps.append(np.zeros((img_size, img_size), dtype=np.float32))
                        continue
                    stack = np.stack(valid_maps, axis=-1)  # [H, W, num_valid_shifts]
                    # Fuse across shifts using the optimized partition-based method
                    averaged = self._fuse_array(stack, fuse_method, axis=-1)
                    # Just slice the valid region (no resize needed for mirror)
                    resized = averaged[pad_px:pad_px+img_size, pad_px:pad_px+img_size]
                    final_maps.append(resized)
                final_anomaly_maps[key] = np.stack(final_maps, axis=0)
        fuse_time = time.time() - fuse_start
        if self.debug_enabled:
            self.log_message(f"  [TIMING] Shift fusion: {fuse_time:.3f}s", level='debug')
        # Average the signed differences for each image
        for key in signed_keys:
            final_maps = []
            for img_idx in range(B):
                # Stack all shifts for this image (filter out None entries)
                valid_maps = [m for m in all_signed_maps[key][img_idx] if m is not None]
                if not valid_maps:
                    # No valid shifts (shouldn't happen)
                    final_maps.append(np.zeros((img_size, img_size), dtype=np.float32))
                    continue
                stack = np.stack(valid_maps, axis=-1)  # [H, W, num_valid_shifts]
                # Fuse across shifts (use the mean for signed differences)
                averaged = np.nanmean(stack, axis=-1)
                # Resize back to the original size
                if shift_method == 'interpolate':
                    averaged_tensor = torch.from_numpy(averaged).float().unsqueeze(0).unsqueeze(0)
                    resized_tensor = F.interpolate(
                        averaged_tensor, size=(img_size, img_size),
                        mode='bilinear', align_corners=False
                    )
                    resized = resized_tensor.squeeze().numpy()
                else:
                    resized = averaged[pad_px:pad_px+img_size, pad_px:pad_px+img_size]
                final_maps.append(resized)
            final_anomaly_maps[key] = np.stack(final_maps, axis=0)
        return final_anomaly_maps, individual_shifts
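
    # Batch-size arithmetic for the combined-batch TTA above (illustrative):
    # with B=4 DataLoader images, pad_px=4, stride=1 and directions
    # ("h", "v", "diag") there are 33 shifts, i.e. 4 × 33 = 132 crops in total.
    # With tta_batch_size=8 the shift axis is split into
    # ceil(33 / 8) = 5 mini-batches, whose combined batches hold up to
    # 4 × 8 = 32 crops each: one VAE encode/decode and one DDIM loop per
    # mini-batch instead of one per crop, which is where the speedup comes from.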
def _compute_tta_per_shift_fundamentals(
self,
x_batch: torch.Tensor,
*,
model_kwargs: dict,
diffusion,
reverse_steps: int,
pad_px: int = 4,
img_enlarge_px: int = 4,
stride: int = 1,
directions: tuple = ("h", "v", "diag"),
shift_method: str = 'interpolate',
tta_batch_size: int = 8,
eta: float = 0.0,
crop_size: int = 256,
) -> Dict[str, Any]:
"""Compute per-shift fundamentals for TTA (Test-Time Augmentation).
Similar to _anomaly_shift_avg_combined_batch but stores fundamentals per-shift
instead of computing and fusing anomaly maps.
Args:
x_batch: Input images [B, 3, H, W] from DataLoader
model_kwargs: Model parameters (context, mask)
diffusion: Diffusion model
reverse_steps: Number of DDIM steps
pad_px: Shift range (±pad_px pixels)
img_enlarge_px: Image enlargement for interpolation
stride: Shift increment
directions: Shift directions to use
shift_method: 'interpolate' or 'mirror'
tta_batch_size: Number of shifts to process together per batch
eta: DDIM eta parameter (0=deterministic)
crop_size: Final crop size
Returns:
Dictionary with keys:
- 'fundamentals': Dict[str, List[List[np.ndarray]]]
Format: fundamentals[key][img_idx][shift_idx] = [1, C, H, W]
Keys: 'x0', 'encoded', 'image_samples', 'latent_samples'
- 'shifts': List[(dx, dy)] - shift coordinates
- 'num_shifts': int - total number of shifts
- 'directions': tuple - shift directions used
- 'pad_px': int - shift range used
- 'stride': int - stride used
"""
device = self.device
B, C, H, W = x_batch.shape
img_size = H
# Validate pad_px
if pad_px > img_enlarge_px:
pad_px = img_enlarge_px
print(f"⚠ Warning: pad_px clamped to img_enlarge_px={img_enlarge_px}")
# Generate all shifts
shifts = _generate_shifts(pad_px, stride, directions)
num_shifts = len(shifts)
self.log_message(f"Computing per-shift fundamentals for {B} images × {num_shifts} shifts")
# Prepare upsampled/padded input
if shift_method == 'interpolate':
aug_size = img_size + 2 * img_enlarge_px
offset = img_enlarge_px
x_up = F.interpolate(
x_batch, size=(aug_size, aug_size),
mode='bilinear', align_corners=False
).to(device)
elif shift_method == 'mirror':
aug_size = img_size + 2 * pad_px
offset = pad_px
x_up = F.pad(
x_batch, (pad_px, pad_px, pad_px, pad_px),
mode='reflect'
).to(device)
else:
raise ValueError(f"Unknown shift_method: {shift_method}. Choose 'interpolate' or 'mirror'")
# Storage for all images × all shifts
# Structure: all_fundamentals[key][img_idx][shift_idx] = np.ndarray [1, C, H, W]
all_fundamentals = {
'x0': [[None for _ in range(num_shifts)] for _ in range(B)],
'encoded': [[None for _ in range(num_shifts)] for _ in range(B)],
'image_samples': [[None for _ in range(num_shifts)] for _ in range(B)],
'latent_samples': [[None for _ in range(num_shifts)] for _ in range(B)],
}
# Process shifts in mini-batches
num_batches = (num_shifts + tta_batch_size - 1) // tta_batch_size
for batch_idx in range(num_batches):
# Check for shutdown request
if self._shutdown_requested:
print(f"\n⚠️ TTA fundamentals computation interrupted at batch {batch_idx}/{num_batches}. Shutting down...")
break
batch_start = batch_idx * tta_batch_size
batch_end = min(batch_start + tta_batch_size, num_shifts)
batch_shifts = shifts[batch_start:batch_end]
print(f" Batch {batch_idx+1}/{num_batches}: processing shifts {batch_start}-{batch_end-1}")
# Create combined batch: [B × len(batch_shifts), 3, H, W]
combined_crops = []
metadata = [] # Track (img_idx, shift_idx_in_full_list, shift, position)
for img_idx in range(B):
for shift_local_idx, (dx, dy) in enumerate(batch_shifts):
shift_global_idx = batch_start + shift_local_idx
y0 = offset + dy
x0 = offset + dx
y1 = y0 + img_size
x1 = x0 + img_size
# Skip invalid crops
if not (0 <= y0 < y1 <= aug_size and 0 <= x0 < x1 <= aug_size):
continue
crop = x_up[img_idx:img_idx+1, :, y0:y1, x0:x1]
combined_crops.append(crop)
metadata.append({
'img_idx': img_idx,
'shift_idx': shift_global_idx,
'shift': (dx, dy),
'position': (x0, y0, x1, y1)
})
if not combined_crops:
continue
# Stack: [B*tta_batch_size, 3, H, W]
combined_tensor = torch.cat(combined_crops, dim=0)
# Expand context tensor to match combined batch size
combined_batch_size = combined_tensor.shape[0]
if 'context' in model_kwargs and model_kwargs['context'] is not None:
original_context = model_kwargs['context']
# The outer condition already guarantees context is not None; validate the type only
if not isinstance(original_context, torch.Tensor):
raise TypeError(f"original_context must be Tensor, got {type(original_context)}")
# Build expanded context based on metadata
expanded_context_list = []
for meta in metadata:
img_idx = meta['img_idx']
if img_idx >= original_context.shape[0]:
raise IndexError(f"img_idx {img_idx} >= original_context.shape[0] {original_context.shape[0]}")
img_context = original_context[img_idx:img_idx+1] # [1, 1]
expanded_context_list.append(img_context)
# metadata is built in lockstep with combined_crops, so the list is never empty here
expanded_context = torch.cat(expanded_context_list, dim=0)
# Sanity check: exactly one context row per (image, shift) crop in this mini-batch
if expanded_context.shape[0] != combined_batch_size:
raise ValueError(f"expanded_context.shape[0] {expanded_context.shape[0]} != combined_batch_size {combined_batch_size}")
model_kwargs_batch = {
'context': expanded_context,
'mask': model_kwargs.get('mask', None)
}
else:
raise ValueError(f"model_kwargs['context'] is missing or None in batch {batch_idx}")
# Process entire combined batch at once
with torch.no_grad():
encoded = self.vae.encode(combined_tensor).latent_dist.mean.mul_(0.18215)
latent_samples = diffusion.ddim_deviation_sample_loop(
self.model, encoded.shape, noise=encoded,
clip_denoised=False, start_t=reverse_steps,
model_kwargs=model_kwargs_batch, progress=False,
device=device, eta=eta
)
image_samples = self.vae.decode(latent_samples / 0.18215).sample
x0_decoded = self.vae.decode(encoded / 0.18215).sample
# Store fundamentals per shift (NEW - this is the key difference!)
for i, meta in enumerate(metadata):
img_idx = meta['img_idx']
shift_idx = meta['shift_idx']
# Store each fundamental as [1, C, H, W] (single image)
all_fundamentals['x0'][img_idx][shift_idx] = x0_decoded[i:i+1].cpu().numpy()
all_fundamentals['encoded'][img_idx][shift_idx] = encoded[i:i+1].cpu().numpy()
all_fundamentals['image_samples'][img_idx][shift_idx] = image_samples[i:i+1].cpu().numpy()
all_fundamentals['latent_samples'][img_idx][shift_idx] = latent_samples[i:i+1].cpu().numpy()
# Return per-shift fundamentals with metadata
return {
'fundamentals': all_fundamentals,
'shifts': shifts,
'num_shifts': num_shifts,
'directions': directions,
'pad_px': pad_px,
'stride': stride
}
def _evaluate_full_image_tiled(self, args, enable_resume=False):
"""
Evaluate full images using tiled approach.
This method handles large images by tiling, reconstructing, and stitching.
"""
# Initialize diffusion
diffusion = create_diffusion(str(args.reverse_steps))
# Get categories (handle both single category and list)
categories = getattr(args, 'categories', [args.object_category])
if isinstance(categories, str):
categories = [categories]
# Storage for all results across categories
all_category_anomaly_maps = {}
all_category_masks = {}
all_category_labels = {}
# Simple plot save function (can be extended)
def queue_plot_save_batch_or_individual(results_list, img_paths, base_filepath,
anomaly_threshold=None, min_area=None,
use_batch_plot=True):
# Placeholder - full plot saving can be implemented if needed
# Fall back to instance defaults when thresholds are not provided
if anomaly_threshold is None:
anomaly_threshold = self.anomaly_threshold
if min_area is None:
min_area = self.anomaly_min_area
queue_plot_save_func = queue_plot_save_batch_or_individual if self.save_plot_path else None
for category in categories:
# Check for shutdown request
if self._shutdown_requested:
print(f"\n⚠️ Full image evaluation interrupted at category {category}. Shutting down...")
break
self.log_message(f"\n{'='*80}")
self.log_message(f"Evaluating category: {category}")
self.log_message(f"{'='*80}\n")
# Setup resume mode
json_output_path = None
existing_json_data = None
processed_image_paths = set()
json_output_path, existing_json_data, processed_image_paths, self.save_plot_path = \
setup_resume_mode(
enable_resume=enable_resume,
category=category,
save_plot_path=self.save_plot_path,
data_dir=args.data_dir,
json_output_path=None
)
# Create temporal dataset manager
test_dataset_manager = create_temporal_dataset_manager(
dataset_name=args.dataset,
mode='test',
object_class=category,
rootdir=args.data_dir,
image_size=getattr(args, 'dataset_image_size', args.image_size),
crop_size=args.crop_size,
anomaly_class=getattr(args, 'anomaly_class', None),
csv_split_file=getattr(args, 'csv_split_file', None)
)
# Create temporal dataloader
test_loader = create_temporal_dataloader(
dataset_manager=test_dataset_manager,
epoch=0,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
drop_last=False
)
# Initialize storage for this category
anomaly_maps_padded = {
'anomaly_geometric': [],
'anomaly_arithmetic': [],
'latent_discrepancy': [],
'image_discrepancy': []
}
all_json_outputs = []
# Process full image evaluation loop (save_queue and save_thread are shared across categories)
anomaly_maps_padded, all_json_outputs, existing_json_data, processed_image_paths = \
process_full_image_evaluation_loop(
test_loader=test_loader,
test_dataset_manager=test_dataset_manager,
device=self.device,
vae=self.vae,
model=self.model,
amp_dtype=self.amp_dtype,
diffusion=diffusion,
args=args,
category=category,
data_dir=args.data_dir,
calculate_anomaly_maps_batch_func=lambda x0, encoded, image_samples, latent_samples, crop_size: \
self.calculate_anomaly_maps_batch(
x0, encoded, image_samples, latent_samples, crop_size,
image_diff_clip_max=self.image_diff_clip_max,
latent_diff_clip_max=self.latent_diff_clip_max
),
save_plot_path=self.save_plot_path,
save_queue=self.save_queue,
queue_plot_save_batch_or_individual_func=queue_plot_save_func,
enable_resume=enable_resume,
json_output_path=json_output_path,
existing_json_data=existing_json_data,
processed_image_paths=processed_image_paths,
all_json_outputs=all_json_outputs,
anomaly_maps_padded=anomaly_maps_padded
)
# Stack anomaly maps
anomaly_maps = stack_anomaly_maps(anomaly_maps_padded)
# Store results for this category
all_category_anomaly_maps[category] = anomaly_maps
# Get masks if available (for metrics calculation)
# Note: In full_image_eval mode, masks might not be available
# This is a simplified version - can be extended if needed
all_category_masks[category] = None
all_category_labels[category] = None
self.log_message(f"Completed evaluation for category: {category}")
self.log_message(f"{'='*80}\n")
# Wait for all images to be saved (after all categories)
if self.save_plot_path and self.save_queue and not self._disable_file_saving:
self.wait_for_saves()
# Return results in a format compatible with standard evaluation
# Combine all categories if multiple
if len(categories) == 1:
return all_category_anomaly_maps[categories[0]], all_category_masks[categories[0]], all_category_labels[categories[0]]
else:
# Combine all categories
combined_anomaly_maps = {}
for key in ['anomaly_geometric', 'anomaly_arithmetic', 'latent_discrepancy', 'image_discrepancy']:
combined_anomaly_maps[key] = np.concatenate([
all_category_anomaly_maps[cat][key] for cat in categories
if key in all_category_anomaly_maps[cat]
], axis=0)
return combined_anomaly_maps, None, None
def _setup_evaluation(self, args):
"""Setup dataset, dataloader, and diffusion for evaluation.
Extracted from evaluate() lines 970-1012 to make it composable and testable.
Args:
args: Configuration arguments
Returns:
Tuple[Dataset, DataLoader, Diffusion]: test_dataset, test_loader, diffusion
"""
test_dataset_manager = create_temporal_dataset_manager(
dataset_name=args.dataset,
mode='test',
object_class=args.object_category,
rootdir=args.data_dir,
image_size=args.image_size,
crop_size=args.crop_size,
anomaly_class=getattr(args, 'anomaly_class', None),
csv_split_file=getattr(args, 'csv_split_file', None)
)
test_loader = create_temporal_dataloader(
dataset_manager=test_dataset_manager,
epoch=0,
batch_size=args.batch_size,
shuffle=False, # No shuffle for validation
num_workers=args.num_workers,
drop_last=False
)
# Initialize diffusion
diffusion = create_diffusion(
f'ddim{args.reverse_steps}',
predict_deviation=True,
sigma_small=False,
predict_xstart=False,
diffusion_steps=10
)
return test_dataset_manager.current_dataset, test_loader, diffusion
def _run_inference_iter(self, args, test_dataset, test_loader, diffusion):
"""Iterator yielding fundamentals from inference.
Yields:
- Non-TTA: (x, mask, object_cls, x0, encoded, image_samples, latent_samples, paths)
- TTA (args.pad_px set): (x, mask, object_cls, fundamentals_dict, paths, tta_metadata)
"""
from tqdm import tqdm
import time
for batch_idx, (x, mask, object_cls) in enumerate(tqdm(test_loader, desc="Detection")):
batch_iter_start = time.time()
# Check for shutdown request
if self._shutdown_requested:
print(f"\n⚠️ Inference interrupted at batch {batch_idx}. Shutting down...")
break
data_load_time = time.time() - batch_iter_start
if self.debug_enabled:
self.log_message(f" [TIMING] Batch {batch_idx} data load: {data_load_time:.3f}s", level='debug')
x = x.to(self.device)
# Get image paths
batch_size = x.shape[0]
batch_image_paths = []
for b in range(batch_size):
idx = batch_idx * args.batch_size + b
if idx < len(test_dataset):
if hasattr(test_dataset, 'image_paths'):
batch_image_paths.append(test_dataset.image_paths[idx])
elif hasattr(test_dataset, 'data_list'):
batch_image_paths.append(test_dataset.data_list[idx])
elif hasattr(test_dataset, 'data_df'):
# Extract path from DataFrame (column 'image' contains the relative path)
batch_image_paths.append(test_dataset.data_df.iloc[idx]['image'])
else:
batch_image_paths.append(f"image_{idx:04d}")
# Prepare model kwargs with proper object_cls tensor formatting
# Model expects context tensor with shape [batch_size, 1]
if isinstance(object_cls, torch.Tensor):
object_cls_tensor = object_cls.to(self.device)
# Ensure shape is [batch_size, 1] for embedding
if len(object_cls_tensor.shape) == 1:
object_cls_tensor = object_cls_tensor.unsqueeze(1)
elif len(object_cls_tensor.shape) == 0:
# Scalar tensor, expand to [batch_size, 1]
object_cls_tensor = object_cls_tensor.unsqueeze(0).unsqueeze(0).expand(x.shape[0], 1)
else:
# If it's a single int or list, convert to tensor [batch_size, 1]
if isinstance(object_cls, (list, tuple)):
object_cls_tensor = torch.tensor(object_cls, device=self.device, dtype=torch.long)
else:
object_cls_tensor = torch.tensor([object_cls] * x.shape[0], device=self.device, dtype=torch.long)
# Ensure shape is [batch_size, 1]
if len(object_cls_tensor.shape) == 1:
object_cls_tensor = object_cls_tensor.unsqueeze(1)
# Ensure dtype is long for embedding indices
if object_cls_tensor.dtype != torch.long:
object_cls_tensor = object_cls_tensor.long()
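# e.g. object_cls=3 with a batch of 4 images ends up as tensor([[3], [3], [3], [3]])
# (shape [4, 1], dtype long), ready to index the class-embedding table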
# Prepare model kwargs (needed for both TTA and non-TTA)
model_kwargs = {'context': object_cls_tensor, 'mask': None}
# Check TTA mode - compute per-shift fundamentals for saving
if args.pad_px is not None:
# Compute TTA fundamentals per shift
tta_result = self._compute_tta_per_shift_fundamentals(
x,
model_kwargs=model_kwargs,
diffusion=diffusion,
reverse_steps=args.reverse_steps,
pad_px=args.pad_px,
img_enlarge_px=getattr(args, 'img_enlarge_px', 4),
stride=getattr(args, 'stride', 1),
directions=getattr(args, 'directions', ('h', 'v', 'diag')),
shift_method=getattr(args, 'shift_method', 'interpolate'),
tta_batch_size=getattr(args, 'tta_batch_size', 8),
eta=0.0,
crop_size=args.crop_size,
)
# Extract fundamentals and metadata
fundamentals_dict = tta_result['fundamentals']
tta_metadata = {
'num_shifts': tta_result['num_shifts'],
'shifts': tta_result['shifts'],
'directions': tta_result['directions'],
'pad_px': tta_result['pad_px'],
'stride': tta_result['stride'],
}
# Yield 6-tuple for TTA mode (will be saved by _save_fundamentals_iter)
yield (x, mask, object_cls, fundamentals_dict, batch_image_paths, tta_metadata)
continue
# Standard inference (no TTA)
with torch.no_grad():
encoded = self.vae.encode(x).latent_dist.mean.mul_(0.18215)
latent_samples = diffusion.ddim_deviation_sample_loop(
self.model,
x.shape,
noise=encoded,
clip_denoised=False,
denoised_fn=None,
model_kwargs=model_kwargs,
device=self.device,
progress=False,
eta=0.0,
)
image_samples = self.vae.decode(latent_samples / 0.18215).sample
x0 = self.vae.decode(encoded / 0.18215).sample
# Yield fundamentals (8-tuple)
yield (x, mask, object_cls, x0, encoded, image_samples, latent_samples, batch_image_paths)
def _save_fundamentals_iter(self, fundamental_iter, save_dir, args):
"""Save fundamentals as NPY files with single CSV manifest and pass through.
Args:
fundamental_iter: Iterator yielding either:
- TTA: (x, mask, object_cls, fundamentals_dict, paths, tta_metadata)
- Non-TTA: (x, mask, object_cls, x0, encoded, image_samples, latent_samples, paths)
save_dir: Directory to save NPY files and CSV manifest
args: Configuration arguments
Yields: Same as input (pass-through)
"""
import os
import numpy as np
import csv
os.makedirs(save_dir, exist_ok=True)
# Check if consolidated NPZ format is requested
use_consolidated_npz = getattr(args, 'use_consolidated_npz', False)
# Check for existing CSV to resume from checkpoint
csv_path = os.path.join(save_dir, 'fundamentals_manifest.csv')
resume_from_batch = -1
global_img_idx = 0 # Global image counter across all batches
if os.path.exists(csv_path):
# Read last batch_idx from existing CSV and count processed images
try:
with open(csv_path, 'r', newline='', encoding='utf-8') as f:
reader = csv.DictReader(f)
last_row = None
seen_images = set() # Track unique (batch_idx, image_idx_in_batch) pairs
for row in reader:
last_row = row
batch_idx = int(row['batch_idx'])
img_idx = int(row['image_idx_in_batch'])
seen_images.add((batch_idx, img_idx))
if last_row is not None:
resume_from_batch = int(last_row['batch_idx'])
global_img_idx = len(seen_images) # Count unique images processed
print(f"Resuming from batch {resume_from_batch + 1} (checkpoint found, {global_img_idx} images processed)")
except Exception as e:
print(f"Warning: Could not read checkpoint from {csv_path}: {e}")
print("Starting from scratch")
resume_from_batch = -1
global_img_idx = 0
# Open CSV manifest in appropriate mode
if resume_from_batch >= 0:
# Append mode - resume from checkpoint
csv_file = open(csv_path, 'a', newline='', encoding='utf-8')
csv_writer = csv.writer(csv_file)
else:
# Write mode - start fresh
csv_file = open(csv_path, 'w', newline='', encoding='utf-8')
csv_writer = csv.writer(csv_file)
# Write header
csv_writer.writerow([
'batch_idx', 'image_idx_in_batch', 'original_file_path', 'full_file_path',
'fundamental_type', 'is_tta', 'total_shifts', 'shift_idx',
'shift_type', 'shift_amount', 'npy_filename'
])
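# Example manifest row (TTA, multi-file format; values purely illustrative):
#   0,1,000,wood/test/color/000.png,x0,1,8,4,h,-2,000001_shift_04_x0.npy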
try:
for batch_idx, batch_data in enumerate(fundamental_iter):
# Check for shutdown request
if self._shutdown_requested:
print(f"\n⚠️ Saving fundamentals interrupted at batch {batch_idx}. Shutting down...")
break
# Skip batches that are already processed (checkpoint resume)
if batch_idx <= resume_from_batch:
# Pass through without processing (identical for 6-tuple TTA and 8-tuple non-TTA)
yield batch_data
continue
# Detect TTA vs non-TTA by tuple length
if len(batch_data) == 6: # TTA mode
x, mask, object_cls, fundamentals_dict, paths, tta_metadata = batch_data
is_tta = True
else: # Non-TTA mode (8-tuple)
x, mask, object_cls, x0, encoded, image_samples, latent_samples, paths = batch_data
is_tta = False
if is_tta:
# Save per-shift fundamentals
num_shifts = tta_metadata['num_shifts']
shifts = tta_metadata['shifts']
for img_idx_in_batch, path in enumerate(paths):
if use_consolidated_npz:
# CONSOLIDATED NPZ FORMAT: Save all fundamentals in one .npz file
original_image = x[img_idx_in_batch].cpu().numpy()
# Collect all shift data
consolidated_data = {
'original_image': original_image,
'num_shifts': num_shifts,
'shifts': np.array(shifts) # List of (dx, dy) tuples
}
# Add each fundamental type for all shifts
for fund_type in ['x0', 'encoded', 'image_samples', 'latent_samples']:
# Stack all shifts into single array [num_shifts, ...]
shift_arrays = []
for shift_idx in range(num_shifts):
fund_array = fundamentals_dict[fund_type][img_idx_in_batch][shift_idx]
if fund_array is not None:
shift_arrays.append(fund_array)
if shift_arrays:
consolidated_data[fund_type] = np.stack(shift_arrays, axis=0)
# Save single compressed .npz file
npz_filename = f"{global_img_idx:06d}_fundamentals.npz"
np.savez_compressed(
os.path.join(save_dir, npz_filename),
**consolidated_data
)
# Write single CSV row for consolidated file
csv_writer.writerow([
batch_idx,
img_idx_in_batch,
os.path.splitext(os.path.basename(path))[0], # original_file_path
path, # full_file_path
'consolidated', # fundamental_type
1, # is_tta
num_shifts,
-1, # shift_idx (-1 for consolidated)
'consolidated', # shift_type
0, # shift_amount
npz_filename
])
else:
# ORIGINAL MULTI-FILE FORMAT: Separate .npy per fundamental
# Save original image (once per image, not per shift)
original_image = x[img_idx_in_batch].cpu().numpy()
original_npy_filename = f"{global_img_idx:06d}_original_image.npy"
np.save(os.path.join(save_dir, original_npy_filename), original_image)
# Write CSV row for original image
csv_writer.writerow([
batch_idx,
img_idx_in_batch,
os.path.splitext(os.path.basename(path))[0], # original_file_path (without extension)
path, # full_file_path (with extension)
'original_image',
1, # is_tta
num_shifts,
-1, # shift_idx (-1 means original, not shifted)
'original', # shift_type
0, # shift_amount
original_npy_filename
])
for shift_idx, (dx, dy) in enumerate(shifts):
# Determine shift type and amount
shift_type, shift_amount = _get_shift_type_and_amount(dx, dy)
for fund_type in ['x0', 'encoded', 'image_samples', 'latent_samples']:
fund_array = fundamentals_dict[fund_type][img_idx_in_batch][shift_idx]
if fund_array is None:
continue # Skip invalid shifts
# Save NPY file
npy_filename = f"{global_img_idx:06d}_shift_{shift_idx:02d}_{fund_type}.npy"
np.save(os.path.join(save_dir, npy_filename), fund_array)
# Write CSV row
csv_writer.writerow([
batch_idx,
img_idx_in_batch,
os.path.splitext(os.path.basename(path))[0], # original_file_path (without extension)
path, # full_file_path (with extension)
fund_type,
1, # is_tta
num_shifts,
shift_idx,
shift_type,
shift_amount,
npy_filename
])
# Increment global image counter after processing each image
global_img_idx += 1
# Flush CSV to ensure incremental writes
csv_file.flush()
# Pass through (6-tuple)
yield (x, mask, object_cls, fundamentals_dict, paths, tta_metadata)
else:
# Non-TTA mode - save per-image files
# Check if fundamentals are None (happens when TTA is enabled but yields 8-tuple with None values)
if x0 is None:
# TTA mode with dummy 8-tuple - skip saving and just pass through
yield (x, mask, object_cls, x0, encoded, image_samples, latent_samples, paths)
continue
# Save per-image fundamentals for non-TTA mode
for img_idx_in_batch, path in enumerate(paths):
# Extract single image from batch
img_x = x[img_idx_in_batch:img_idx_in_batch+1].cpu().numpy()
img_x0 = x0[img_idx_in_batch:img_idx_in_batch+1].cpu().numpy()
img_encoded = encoded[img_idx_in_batch:img_idx_in_batch+1].cpu().numpy()
img_samples = image_samples[img_idx_in_batch:img_idx_in_batch+1].cpu().numpy()
img_latents = latent_samples[img_idx_in_batch:img_idx_in_batch+1].cpu().numpy()
if use_consolidated_npz:
# CONSOLIDATED NPZ FORMAT: Save all fundamentals in one .npz file
consolidated_data = {
'original_image': img_x,
'x0': img_x0,
'encoded': img_encoded,
'image_samples': img_samples,
'latent_samples': img_latents
}
# Save single compressed .npz file
npz_filename = f"{global_img_idx:06d}_fundamentals.npz"
np.savez_compressed(
os.path.join(save_dir, npz_filename),
**consolidated_data
)
# Write single CSV row for consolidated file
csv_writer.writerow([
batch_idx,
img_idx_in_batch,
os.path.splitext(os.path.basename(path))[0], # original_file_path
path, # full_file_path
'consolidated', # fundamental_type
0, # is_tta
0, # total_shifts
-1, # shift_idx (-1 for consolidated)
'consolidated', # shift_type
0, # shift_amount
npz_filename
])
else:
# ORIGINAL MULTI-FILE FORMAT: Separate .npy per fundamental
# Save NPY files with global image index
np.save(os.path.join(save_dir, f"{global_img_idx:06d}_original_image.npy"), img_x)
np.save(os.path.join(save_dir, f"{global_img_idx:06d}_x0.npy"), img_x0)
np.save(os.path.join(save_dir, f"{global_img_idx:06d}_encoded.npy"), img_encoded)
np.save(os.path.join(save_dir, f"{global_img_idx:06d}_image_samples.npy"), img_samples)
np.save(os.path.join(save_dir, f"{global_img_idx:06d}_latent_samples.npy"), img_latents)
# Write CSV rows for non-TTA (one row per fundamental type per image)
for fund_type in ['original_image', 'x0', 'encoded', 'image_samples', 'latent_samples']:
npy_filename = f"{global_img_idx:06d}_{fund_type}.npy"
csv_writer.writerow([
batch_idx,
img_idx_in_batch,
os.path.splitext(os.path.basename(path))[0], # original_file_path (without extension)
path, # full_file_path (with extension)
fund_type,
0, # is_tta
0, # total_shifts
0, # shift_idx
'none', # shift_type
0, # shift_amount
npy_filename
])
# Increment global image counter after processing each image
global_img_idx += 1
# Flush CSV to ensure incremental writes
csv_file.flush()
# Pass through (8-tuple)
yield (x, mask, object_cls, x0, encoded, image_samples, latent_samples, paths)
finally:
# Always close CSV file
csv_file.close()
print(f"Saved fundamentals manifest: {csv_path}")
def _load_fundamentals_iter(self, save_dir, test_dataset, args):
"""Load fundamentals from NPY files, handling both TTA and non-TTA.
Args:
save_dir: Directory containing NPY files and CSV manifest
test_dataset: Dataset (used for getting masks)
args: Command-line arguments
Yields: Either:
- TTA: (None, mask, None, fundamentals_dict, paths, tta_metadata)
- Non-TTA: (None, mask, None, x0, encoded, image_samples, latent_samples, paths)
"""
import os
import csv
import numpy as np
import torch
from tqdm import tqdm
from collections import defaultdict
# Read CSV manifest
csv_path = os.path.join(save_dir, 'fundamentals_manifest.csv')
if not os.path.exists(csv_path):
raise FileNotFoundError(f"CSV manifest not found: {csv_path}")
# Parse CSV and group by batch_idx
batches = defaultdict(lambda: {
'is_tta': None,
'paths': [],
'rows': [] # Store all CSV rows for this batch
})
with open(csv_path, 'r', encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
batch_idx = int(row['batch_idx'])
is_tta = int(row['is_tta']) == 1
# Set is_tta for this batch (should be consistent)
if batches[batch_idx]['is_tta'] is None:
batches[batch_idx]['is_tta'] = is_tta
# Track unique paths - FIX: use full_file_path instead of original_file_path
# to avoid collisions when different subdirectories have same filename
# (e.g., wood/test/color/000.png and wood/test/combined/000.png both map to "000")
path = row['full_file_path']
if path not in batches[batch_idx]['paths']:
batches[batch_idx]['paths'].append(path)
# Store row for later processing
batches[batch_idx]['rows'].append(row)
# Load fundamentals for each batch
for batch_idx in tqdm(sorted(batches.keys()), desc="Loading fundamentals"):
# Check for shutdown request
if self._shutdown_requested:
print(f"\n⚠️ Loading fundamentals interrupted at batch {batch_idx}. Shutting down...")
break
batch_info = batches[batch_idx]
paths = batch_info['paths']
is_tta = batch_info['is_tta']
# Load ground truth masks from dataset
batch_size = len(paths)
masks_batch = []
# FIX: Use the fixed batch size from args for index calculation;
# same fix as in _load_fundamentals_iter_optimized (see its comment for the failure case)
fixed_batch_size = args.batch_size
for b in range(batch_size):
dataset_idx = batch_idx * fixed_batch_size + b
if dataset_idx < len(test_dataset):
_, mask, _ = test_dataset[dataset_idx]
masks_batch.append(mask)
# Skip this batch if we don't have ground truth masks
if not masks_batch:
if self.debug_enabled:
self.log_message(f" [SKIP] Batch {batch_idx}: no ground truth data available", level='debug')
continue
mask = torch.stack(masks_batch) if len(masks_batch) > 1 else masks_batch[0].unsqueeze(0)
# Filter paths to match available masks
num_masks = len(masks_batch)
if num_masks < len(paths):
paths = paths[:num_masks]
if is_tta:
# Load per-shift fundamentals from CSV rows
rows = batch_info['rows']
# Determine batch size and num_shifts (use num_masks not len(paths))
num_images = num_masks
num_shifts = max(int(row['total_shifts']) for row in rows)
# Initialize fundamentals storage
# Note: original_image is saved in manifest but not loaded here (only used for visualization)
fundamentals_dict = {
'x0': [[None for _ in range(num_shifts)] for _ in range(num_images)],
'encoded': [[None for _ in range(num_shifts)] for _ in range(num_images)],
'image_samples': [[None for _ in range(num_shifts)] for _ in range(num_images)],
'latent_samples': [[None for _ in range(num_shifts)] for _ in range(num_images)],
}
shifts = [] # Reconstruct shift list
shift_coords_seen = {} # Track (dx, dy) -> shift_idx mapping
# Load each file
for row in rows:
img_idx = int(row['image_idx_in_batch'])
shift_idx = int(row['shift_idx'])
fund_type = row['fundamental_type']
npy_filename = row['npy_filename']
# Skip images beyond available masks
if img_idx >= num_masks:
continue
# Skip original_image during loading (it's only used for visualization, not evaluation)
if fund_type == 'original_image':
continue
# Load NPY file
npy_path = os.path.join(save_dir, npy_filename)
fund_array = np.load(npy_path)
# Store in fundamentals dict
fundamentals_dict[fund_type][img_idx][shift_idx] = fund_array
# Reconstruct shifts list (only once per shift_idx)
if fund_type == 'x0' and img_idx == 0: # Use first image, first fundamental type
shift_type = row['shift_type']
shift_amount = int(row['shift_amount'])
# Convert back to (dx, dy)
if shift_type == 'none':
dx, dy = 0, 0
elif shift_type == 'h':
dx, dy = shift_amount, 0
elif shift_type == 'v':
dx, dy = 0, shift_amount
elif shift_type == 'diag':
# For diagonal, both dx and dy have same sign
# Positive shift_amount -> both positive, negative -> both negative
dx = dy = shift_amount
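# Note: anti-diagonal shifts like (2, -2) are not representable in this scheme;
# _get_shift_type_and_amount (defined elsewhere) is assumed to emit only
# 'none'/'h'/'v'/'diag' rows, matching _generate_shifts' directions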
# Only add if we haven't seen this shift_idx yet
if shift_idx not in shift_coords_seen:
shift_coords_seen[shift_idx] = (dx, dy)
shifts.append((dx, dy))
# Sort shifts by their index to maintain order
shifts_sorted = [shift_coords_seen[i] for i in sorted(shift_coords_seen.keys())]
# Create TTA metadata
tta_metadata = {
'shifts': shifts_sorted,
'num_shifts': num_shifts,
'directions': None, # Not stored in CSV
'pad_px': None, # Not stored in CSV
'stride': None # Not stored in CSV
}
# Yield TTA format (6-tuple)
yield (None, mask, None, fundamentals_dict, paths, tta_metadata)
else:
# Non-TTA mode - the saver writes one NPY per image ({global_img_idx:06d}_<type>.npy),
# so rebuild the batch from the per-image manifest rows rather than per-batch files
per_type = {'x0': {}, 'encoded': {}, 'image_samples': {}, 'latent_samples': {}}
for row in batch_info['rows']:
fund_type = row['fundamental_type']
if fund_type in per_type:
img_idx = int(row['image_idx_in_batch'])
if img_idx < num_masks:
per_type[fund_type][img_idx] = np.load(os.path.join(save_dir, row['npy_filename']))
# Each per-image array is [1, C, H, W]; concatenate in image order
stacked = {k: torch.from_numpy(np.concatenate([v[i] for i in sorted(v)], axis=0))
for k, v in per_type.items()}
# Yield non-TTA format (8-tuple)
yield (None, mask, None, stacked['x0'], stacked['encoded'],
stacked['image_samples'], stacked['latent_samples'], paths)
# ============================================================================
# PHASE 1: I/O PARALLELIZATION - ParallelNPYLoader
# ============================================================================
class ParallelNPYLoader:
"""Parallel NPY file loader with memory mapping and prefetch.
Features:
- ThreadPoolExecutor for parallel file I/O
- Memory-mapped loading for large files
- LRU caching for frequently accessed files
- Rolling prefetch queue
"""
def __init__(self, save_dir, num_workers=8, use_mmap=True, cache_size_mb=200):
"""Initialize parallel NPY loader.
Args:
save_dir: Directory containing NPY files
num_workers: Number of parallel I/O workers
use_mmap: Use memory-mapped mode for large files
cache_size_mb: Maximum cache size in MB
"""
from concurrent.futures import ThreadPoolExecutor
self.save_dir = save_dir
self.num_workers = num_workers
self.use_mmap = use_mmap
self.cache_size_bytes = cache_size_mb * 1024 * 1024
self.executor = ThreadPoolExecutor(max_workers=num_workers)
self.cache = {} # filepath -> numpy array
self.cache_order = [] # LRU tracking
self.cache_memory = 0 # Current cache size in bytes
self.format_logged = False # Track if format has been logged
def load_npy_file(self, filepath):
"""Load single NPY file with memory mapping and caching.
Args:
filepath: Absolute path to NPY file
Returns:
numpy.ndarray: Loaded array
"""
import numpy as np
import os
# Check cache first
if filepath in self.cache:
# Move to end (most recently used)
self.cache_order.remove(filepath)
self.cache_order.append(filepath)
return self.cache[filepath]
# Load from disk
if not os.path.exists(filepath):
raise FileNotFoundError(f"NPY file not found: {filepath}")
if self.use_mmap:
arr = np.load(filepath, mmap_mode='r')
arr_size = arr.nbytes
# Cache small files (< 50MB) in memory
if arr_size < 50 * 1024 * 1024:
arr = np.array(arr) # Force load into memory
self._add_to_cache(filepath, arr, arr_size)
else:
arr = np.load(filepath)
self._add_to_cache(filepath, arr, arr.nbytes)
return arr
def load_npz_consolidated(self, filepath):
"""Load consolidated .npz file containing all fundamentals.
Args:
filepath: Absolute path to .npz file
Returns:
dict: Dictionary with keys like 'x0', 'encoded', 'image_samples', 'latent_samples', etc.
"""
import numpy as np
import os
# Check cache first
if filepath in self.cache:
self.cache_order.remove(filepath)
self.cache_order.append(filepath)
return self.cache[filepath]
# Load from disk
if not os.path.exists(filepath):
raise FileNotFoundError(f"NPZ file not found: {filepath}")
# Load NPZ archive (note: numpy ignores mmap_mode for compressed .npz files,
# so both branches read the arrays into memory on access)
if self.use_mmap:
data = np.load(filepath, mmap_mode='r')
else:
data = np.load(filepath)
# Extract all arrays into dictionary
consolidated_data = {key: np.array(data[key]) for key in data.files}
# Calculate total size
total_size = sum(arr.nbytes if isinstance(arr, np.ndarray) else 0
for arr in consolidated_data.values())
# Cache if small enough (< 50MB)
if total_size < 50 * 1024 * 1024:
self._add_to_cache(filepath, consolidated_data, total_size)
return consolidated_data
def _add_to_cache(self, filepath, arr, arr_size):
"""Add array or dict of arrays to cache with LRU eviction.
Args:
filepath: File path (cache key)
arr: Numpy array or dict of arrays to cache
arr_size: Size in bytes (pre-calculated)
"""
import gc
import numpy as np
# Evict LRU items if needed
evicted_count = 0
while (self.cache_memory + arr_size > self.cache_size_bytes and
len(self.cache) > 0):
lru_path = self.cache_order.pop(0)
evicted_item = self.cache.pop(lru_path)
# Calculate size based on type (array or dict of arrays)
if isinstance(evicted_item, dict):
evicted_size = sum(v.nbytes if isinstance(v, np.ndarray) else 0
for v in evicted_item.values())
else:
evicted_size = evicted_item.nbytes
self.cache_memory -= evicted_size
del evicted_item
evicted_count += 1
# Force garbage collection if we evicted items
if evicted_count > 0:
gc.collect()
# Add to cache
self.cache[filepath] = arr
self.cache_order.append(filepath)
self.cache_memory += arr_size
def load_batch_fundamentals_parallel(self, batch_info, save_dir):
"""Load all fundamentals for a batch in parallel.
Args:
batch_info: Dict with 'rows' key containing CSV rows
save_dir: Directory containing NPY files
Returns:
dict: Fundamentals dict with loaded arrays
"""
import os
import numpy as np
from concurrent.futures import as_completed
rows = batch_info['rows']
is_tta = batch_info['is_tta']
if is_tta:
# Check if using consolidated NPZ format
first_row_fund_type = rows[0]['fundamental_type']
is_consolidated = (first_row_fund_type == 'consolidated')
# Log format detection (once, on first batch)
if not self.format_logged:
self.format_logged = True
format_str = "Consolidated NPZ (1 file/image)" if is_consolidated else "Multi-file NPY (33 files/image)"
print(f" Detected format: {format_str}")
# Determine dimensions
num_images = len(batch_info['paths'])
num_shifts = max(int(row['total_shifts']) for row in rows)
# Initialize fundamentals storage
fundamentals_dict = {
'x0': [[None for _ in range(num_shifts)] for _ in range(num_images)],
'encoded': [[None for _ in range(num_shifts)] for _ in range(num_images)],
'image_samples': [[None for _ in range(num_shifts)] for _ in range(num_images)],
'latent_samples': [[None for _ in range(num_shifts)] for _ in range(num_images)],
}
# For consolidated format, store shifts from NPZ file
loaded_shifts = None
if is_consolidated:
# CONSOLIDATED NPZ FORMAT: Load one .npz per image
futures = {}
for row in rows:
if row['fundamental_type'] != 'consolidated':
continue
npz_filename = row['npy_filename'] # Actually .npz
npz_path = os.path.join(save_dir, npz_filename)
# Submit load job for consolidated file
future = self.executor.submit(self.load_npz_consolidated, npz_path)
futures[future] = row
# Collect results as they complete
for future in as_completed(futures):
row = futures[future]
try:
consolidated_data = future.result(timeout=30)
# Extract image index
img_idx = int(row['image_idx_in_batch'])
# FIX: Extract shifts from first NPZ file (all images have same shifts)
if loaded_shifts is None and 'shifts' in consolidated_data:
loaded_shifts = consolidated_data['shifts'] # shape: (num_shifts, 2)
# Populate fundamentals_dict from consolidated data
for fund_type in ['x0', 'encoded', 'image_samples', 'latent_samples']:
if fund_type in consolidated_data:
fund_arrays = consolidated_data[fund_type] # [num_shifts, ...]
# Assign each shift
for shift_idx in range(len(fund_arrays)):
if img_idx < len(fundamentals_dict[fund_type]):
fundamentals_dict[fund_type][img_idx][shift_idx] = fund_arrays[shift_idx]
except Exception as e:
print(f"Warning: Failed to load consolidated {row['npy_filename']}: {e}")
continue
# FIX: Store loaded_shifts in fundamentals_dict for later use
if loaded_shifts is not None:
fundamentals_dict['_shifts'] = loaded_shifts
else:
# ORIGINAL MULTI-FILE FORMAT: Load individual .npy files
# Submit all loads to thread pool
futures = {}
for row in rows:
fund_type = row['fundamental_type']
# Skip original_image (not used for evaluation)
if fund_type == 'original_image':
continue
npy_filename = row['npy_filename']
npy_path = os.path.join(save_dir, npy_filename)
# Submit load job
future = self.executor.submit(self.load_npy_file, npy_path)
futures[future] = row
# Collect results as they complete
for future in as_completed(futures):
row = futures[future]
try:
fund_array = future.result(timeout=30) # 30s timeout per file
# Store in fundamentals dict
img_idx = int(row['image_idx_in_batch'])
shift_idx = int(row['shift_idx'])
fund_type = row['fundamental_type']
# Bounds check: skip if img_idx is out of range
# (happens when cache was created with different dataset filters)
if img_idx >= len(fundamentals_dict[fund_type]):
print(f"Warning: Skipping {row['npy_filename']}: img_idx {img_idx} out of range (batch has {len(fundamentals_dict[fund_type])} images)")
continue
fundamentals_dict[fund_type][img_idx][shift_idx] = fund_array
except Exception as e:
# Log error but continue
print(f"Warning: Failed to load {row['npy_filename']}: {e}")
continue
return fundamentals_dict
else:
# Non-TTA mode
# Check if using consolidated NPZ format
first_row_fund_type = rows[0]['fundamental_type']
is_consolidated = (first_row_fund_type == 'consolidated')
# Log format detection (once, on first batch)
if not self.format_logged:
self.format_logged = True
format_str = "Consolidated NPZ (1 file/image)" if is_consolidated else "Multi-file NPY (4 files/image)"
print(f" Detected format: {format_str}")
if is_consolidated:
# CONSOLIDATED NPZ FORMAT: Load single .npz file
npz_filename = rows[0]['npy_filename']
npz_path = os.path.join(save_dir, npz_filename)
try:
consolidated_data = self.load_npz_consolidated(npz_path)
# Extract fundamentals
results = {
'x0': consolidated_data.get('x0'),
'encoded': consolidated_data.get('encoded'),
'image_samples': consolidated_data.get('image_samples'),
'latent_samples': consolidated_data.get('latent_samples')
}
return results
except Exception as e:
raise RuntimeError(f"Failed to load consolidated NPZ {npz_filename}: {e}")
else:
# ORIGINAL MULTI-FILE FORMAT: the saver writes one .npy per image
# ({global_img_idx:06d}_<type>.npy), so load every per-image file listed in
# the manifest in parallel and stack them back into [B, ...] batch arrays
file_types = ['x0', 'encoded', 'image_samples', 'latent_samples']
# Submit loads
futures = {}
for row in rows:
ftype = row['fundamental_type']
if ftype not in file_types:
continue  # skip 'original_image' rows
filepath = os.path.join(save_dir, row['npy_filename'])
future = self.executor.submit(self.load_npy_file, filepath)
futures[future] = (ftype, int(row['image_idx_in_batch']))
# Collect results, keyed by (type, image index in batch)
per_type = {ftype: {} for ftype in file_types}
for future in as_completed(futures):
ftype, img_idx = futures[future]
try:
per_type[ftype][img_idx] = future.result(timeout=30)
except Exception as e:
raise RuntimeError(f"Failed to load {ftype} for image {img_idx}: {e}")
# Each per-image array is [1, C, H, W]; concatenate in image order
results = {ftype: np.concatenate([d[i] for i in sorted(d)], axis=0)
for ftype, d in per_type.items()}
return results
def shutdown(self):
"""Shutdown thread pool executor."""
self.executor.shutdown(wait=True)
self.cache.clear()
self.cache_order.clear()
self.cache_memory = 0
def _load_fundamentals_iter_optimized(self, save_dir, test_dataset, args):
"""Optimized version with parallel I/O and prefetching.
This is the Phase 1 optimization that uses ParallelNPYLoader for 2-3× speedup.
Args:
save_dir: Directory containing NPY files and CSV manifest
test_dataset: Dataset (used for getting masks)
args: Command-line arguments with optimization flags
Yields: Either:
- TTA: (None, mask, None, fundamentals_dict, paths, tta_metadata)
- Non-TTA: (None, mask, None, x0, encoded, image_samples, latent_samples, paths)
"""
import os
import csv
import numpy as np
import torch
from tqdm import tqdm
from collections import defaultdict
from concurrent.futures import as_completed
# Get I/O workers and prefetch settings from args
io_workers = getattr(args, 'io_workers', 16) # Increased default for SSDs
use_mmap = getattr(args, 'use_mmap', True)
prefetch_window = getattr(args, 'prefetch_window', 4) # Configurable prefetch
# Log I/O optimization settings for verification
print(f"\n{'='*70}")
print(f" NPY Loader Configuration (evaluate_only mode)")
print(f"{'='*70}")
print(f" I/O Workers: {io_workers} threads")
print(f" Prefetch Window: {prefetch_window} batches")
print(f" Memory Mapping: {'Enabled' if use_mmap else 'Disabled'}")
print(f" Cache Size: 200 MB")
print(f"{'='*70}\n")
# Initialize parallel loader
parallel_loader = self.ParallelNPYLoader(
save_dir=save_dir,
num_workers=io_workers,
use_mmap=use_mmap,
cache_size_mb=200
)
try:
# Read CSV manifest (same as original)
csv_path = os.path.join(save_dir, 'fundamentals_manifest.csv')
if not os.path.exists(csv_path):
raise FileNotFoundError(f"CSV manifest not found: {csv_path}")
# Parse CSV and group by batch_idx
batches = defaultdict(lambda: {
'is_tta': None,
'paths': [],
'rows': []
})
with open(csv_path, 'r', encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
batch_idx = int(row['batch_idx'])
is_tta = int(row['is_tta']) == 1
if batches[batch_idx]['is_tta'] is None:
batches[batch_idx]['is_tta'] = is_tta
# FIX: use full_file_path instead of original_file_path
# to avoid collisions when different subdirectories have same filename
path = row['full_file_path']
if path not in batches[batch_idx]['paths']:
batches[batch_idx]['paths'].append(path)
batches[batch_idx]['rows'].append(row)
sorted_batch_ids = sorted(batches.keys())
# Prefetch system: Load batches in rolling window (configurable via --prefetch_window)
# prefetch_window already set from args above
prefetch_futures = {}
# Prefetch first batch(es)
for i in range(min(prefetch_window, len(sorted_batch_ids))):
batch_idx = sorted_batch_ids[i]
future = parallel_loader.executor.submit(
parallel_loader.load_batch_fundamentals_parallel,
batches[batch_idx],
save_dir
)
prefetch_futures[batch_idx] = future
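# Note: these batch-level futures and the per-file loads they spawn share one
# ThreadPoolExecutor, so keep prefetch_window well below io_workers or the pool
# can stall waiting on its own sub-tasks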
# Process batches with rolling prefetch
for i, batch_idx in enumerate(tqdm(sorted_batch_ids, desc="Loading fundamentals (optimized)")):
# Check for shutdown
if self._shutdown_requested:
print(f"\n⚠️ Loading fundamentals interrupted at batch {batch_idx}. Shutting down...")
break
batch_info = batches[batch_idx]
paths = batch_info['paths']
is_tta = batch_info['is_tta']
# Get fundamentals (from prefetch or load now)
if batch_idx in prefetch_futures:
fundamentals_or_results = prefetch_futures[batch_idx].result(timeout=120)
del prefetch_futures[batch_idx]
else:
# Not prefetched, load now
fundamentals_or_results = parallel_loader.load_batch_fundamentals_parallel(
batch_info,
save_dir
)
# Prefetch next batch (rolling window)
next_idx = i + prefetch_window
if next_idx < len(sorted_batch_ids):
next_batch_idx = sorted_batch_ids[next_idx]
future = parallel_loader.executor.submit(
parallel_loader.load_batch_fundamentals_parallel,
batches[next_batch_idx],
save_dir
)
prefetch_futures[next_batch_idx] = future
# Load ground truth masks (same as original)
batch_size = len(paths)
masks_batch = []
# FIX: Use fixed batch size from args for index calculation
# The bug was using variable batch_size (len(paths)) which fails for partial batches
# E.g., if last batch has 9 items: dataset_idx = 7 * 9 = 63 (wrong!)
# Should be: dataset_idx = 7 * 10 = 70 (correct, using fixed batch_size=10)
fixed_batch_size = args.batch_size
for b in range(batch_size):
dataset_idx = batch_idx * fixed_batch_size + b
if dataset_idx < len(test_dataset):
_, mask, _ = test_dataset[dataset_idx]
masks_batch.append(mask)
# Skip this batch if we don't have ground truth masks
if not masks_batch:
if self.debug_enabled:
self.log_message(f" [SKIP] Batch {batch_idx}: no ground truth data available", level='debug')
continue
mask = torch.stack(masks_batch) if len(masks_batch) > 1 else masks_batch[0].unsqueeze(0)
# Filter fundamentals and paths to match available masks
num_masks = len(masks_batch)
if num_masks < len(paths):
paths = paths[:num_masks]
if is_tta:
# Fundamentals already loaded by parallel loader
fundamentals_dict = fundamentals_or_results
# DEBUG: Check fundamentals structure and verify shifts are different
if batch_idx == 0:
print(f"\n[DEBUG] Batch {batch_idx} fundamentals_dict from CACHE:")
for key in ['x0']:
print(f" {key}: type={type(fundamentals_dict[key])}, len={len(fundamentals_dict[key])}")
if len(fundamentals_dict[key]) > 0:
first = fundamentals_dict[key][0]
print(f" first element: type={type(first)}, len={len(first) if isinstance(first, list) else 'N/A'}")
if isinstance(first, list) and len(first) > 1:
if first[0] is not None and first[1] is not None:
import numpy as np
arr0 = first[0] if isinstance(first[0], np.ndarray) else first[0].cpu().numpy()
arr1 = first[1] if isinstance(first[1], np.ndarray) else first[1].cpu().numpy()
diff = np.abs(arr0 - arr1).sum()
print(f" Shift 0 vs Shift 1 |diff|: {diff:.6f} (CRITICAL: must be >0!)")
print(f" arr0: shape={arr0.shape}, mean={arr0.mean():.4f}")
print(f" arr1: shape={arr1.shape}, mean={arr1.mean():.4f}")
# Filter fundamentals to match available masks
if num_masks < len(fundamentals_dict['x0']):
for key in fundamentals_dict.keys():
fundamentals_dict[key] = fundamentals_dict[key][:num_masks]
# Reconstruct shifts list
rows = batch_info['rows']
num_shifts = max(int(row['total_shifts']) for row in rows)
# FIX: For consolidated format, use shifts loaded from NPZ file
if '_shifts' in fundamentals_dict:
# Consolidated format: shifts are in NPZ file
import numpy as np
shifts_array = fundamentals_dict['_shifts'] # shape: (num_shifts, 2)
shifts_sorted = [(int(shifts_array[i, 0]), int(shifts_array[i, 1])) for i in range(len(shifts_array))]
else:
# Multi-file format: reconstruct shifts from manifest
shifts = []
shift_coords_seen = {}
for row in rows:
img_idx = int(row['image_idx_in_batch'])
shift_idx = int(row['shift_idx'])
fund_type = row['fundamental_type']
# Reconstruct shifts list (only once per shift_idx)
if fund_type == 'x0' and img_idx == 0:
shift_type = row['shift_type']
shift_amount = int(row['shift_amount'])
# Convert back to (dx, dy)
if shift_type == 'none':
dx, dy = 0, 0
elif shift_type == 'h':
dx, dy = shift_amount, 0
elif shift_type == 'v':
dx, dy = 0, shift_amount
elif shift_type == 'diag':
dx = dy = shift_amount
if shift_idx not in shift_coords_seen:
shift_coords_seen[shift_idx] = (dx, dy)
shifts.append((dx, dy))
# Sort shifts by index
shifts_sorted = [shift_coords_seen[i] for i in sorted(shift_coords_seen.keys())]
# DEBUG: Check if shifts were reconstructed
if batch_idx == 0:
print(f"[DEBUG] Reconstructed shifts: {shifts_sorted}")
print(f"[DEBUG] num_shifts: {num_shifts}")
# Create TTA metadata
tta_metadata = {
'shifts': shifts_sorted,
'num_shifts': num_shifts,
'directions': None,
'pad_px': None,
'stride': None
}
# Yield TTA format
yield (None, mask, None, fundamentals_dict, paths, tta_metadata)
else:
# Non-TTA: Convert to tensors
results = fundamentals_or_results
x0 = torch.from_numpy(results['x0'])
encoded = torch.from_numpy(results['encoded'])
image_samples = torch.from_numpy(results['image_samples'])
latent_samples = torch.from_numpy(results['latent_samples'])
# Yield non-TTA format
yield (None, mask, None, x0, encoded, image_samples, latent_samples, paths)
finally:
# Cleanup
parallel_loader.shutdown()
def _load_fundamentals_iter_dispatcher(self, save_dir, test_dataset, args):
"""Dispatcher for _load_fundamentals_iter with legacy mode support.
Args:
save_dir: Directory containing NPY files
test_dataset: Dataset for masks
args: Command-line arguments
Yields:
Batch data (format depends on TTA mode)
"""
# Check if parallel I/O is enabled
legacy_mode = getattr(args, 'legacy_mode', False)
parallel_io_enabled = getattr(args, 'parallel_io', True)
if legacy_mode or not parallel_io_enabled:
# Use original sequential implementation
yield from self._load_fundamentals_iter(save_dir, test_dataset, args)
else:
# Use optimized parallel implementation (Phase 1)
yield from self._load_fundamentals_iter_optimized(save_dir, test_dataset, args)
def _fuse_array(self, stack, fuse_method, axis=-1):
"""Efficiently fuse array using partition-based methods (handles NaN).
Args:
stack: np.ndarray to fuse
fuse_method: 'mean', 'median', 'lowest', 'pct25', 'pct75'
axis: Axis to fuse along (default: -1, last dimension)
Returns:
Fused array
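Example (illustrative):
stack = np.stack([map_a, map_b, map_c], axis=-1)   # [H, W, 3]
fused = self._fuse_array(stack, 'median')          # [H, W] per-pixel median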
"""
if fuse_method == 'mean':
return np.nanmean(stack, axis=axis)
elif fuse_method == 'lowest':
return np.nanmin(stack, axis=axis)
# For partition-based methods, move target axis to the end for easier indexing
# This allows us to use [..., k] pattern like the existing optimized code
if axis != -1 and axis != stack.ndim - 1:
# Move axis to end
stack = np.moveaxis(stack, axis, -1)
moved_axis = True
else:
moved_axis = False
# Handle NaN: replace with inf for partition, then restore
nan_mask = np.isnan(stack)
stack_clean = np.where(nan_mask, np.inf, stack)
n_valid = stack.shape[-1]
if fuse_method == 'median':
# Optimize: use partition for median (much faster than full sort)
if n_valid % 2 == 1:
# Odd number: median is at index (n-1)//2
k = (n_valid - 1) // 2
averaged = np.partition(stack_clean, k, axis=-1)[..., k]
else:
# Even number: median is average of two middle elements
k1 = n_valid // 2 - 1
k2 = n_valid // 2
part = np.partition(stack_clean, [k1, k2], axis=-1)
averaged = (part[..., k1] + part[..., k2]) / 2.0
elif fuse_method == 'pct25':
# Optimize: use partition for 25th percentile (much faster than full sort)
k = int(np.ceil(n_valid * 0.25)) - 1
k = max(0, min(k, n_valid - 1)) # Clamp to valid range
averaged = np.partition(stack_clean, k, axis=-1)[..., k]
elif fuse_method == 'pct75':
# Optimize: use partition for 75th percentile
k = int(np.ceil(n_valid * 0.75)) - 1
k = max(0, min(k, n_valid - 1)) # Clamp to valid range
averaged = np.partition(stack_clean, k, axis=-1)[..., k]
else:
raise ValueError(f"Unknown fuse_method: {fuse_method}")
# Restore NaN: if all values were NaN, result should be NaN
all_nan_mask = nan_mask.all(axis=-1)
averaged[all_nan_mask] = np.nan
averaged[averaged == np.inf] = np.nan
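# Caveat: the partition index k is computed from the full shift count, so when only
# some entries are NaN the +inf-mapped NaNs occupy the top ranks and partition-based
# median/percentiles skew slightly low relative to np.nanpercentile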
# Move axis back if we moved it
if moved_axis:
averaged = np.moveaxis(averaged, -1, axis)
return averaged
def _fuse_fundamentals(self, fundamentals_list, fuse_method):
"""Fuse per-shift fundamentals using specified method.
Args:
fundamentals_list: List[List[np.ndarray]] - [img_idx][shift_idx] = [1, C, H, W]
fuse_method: 'mean', 'median', 'lowest', 'pct25', 'pct75'
Returns:
np.ndarray: [B, C, H, W] fused fundamentals
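Example (illustrative): if each fundamentals_list[i] holds five [1, 4, 32, 32]
arrays, the result is a [len(fundamentals_list), 4, 32, 32] array holding the
per-pixel fusion (e.g. median) over the five shifts.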
"""
import numpy as np
fused = []
for img_fundamentals in fundamentals_list: # Each image
# Filter out None values and stack: [num_valid_shifts, C, H, W]
valid_shifts = [f.squeeze(0) for f in img_fundamentals if f is not None]
if not valid_shifts:
raise ValueError("No valid shifts found for image")
stack = np.stack(valid_shifts, axis=0) # [num_shifts, C, H, W]
# Fuse across shift dimension (axis=0) using optimized partition method
fused_img = self._fuse_array(stack, fuse_method, axis=0)
fused.append(fused_img)
return np.stack(fused, axis=0) # [B, C, H, W]
def _compute_tta_batched(self, fundamentals_dict, B, num_shifts, shifts, img_size,
aug_size, offset, shift_method, args):
"""Phase 2: GPU-batched TTA processing (5-8× faster than sequential).
Batches all (B × num_shifts) shift operations into a single GPU call.
Returns:
tuple: (anomaly_maps_batch, individual_shifts)
"""
import numpy as np
import torch
import torch.nn.functional as F
keys = ['anomaly_geometric', 'anomaly_arithmetic', 'latent_discrepancy', 'image_discrepancy']
fuse_method = getattr(args, 'fuse_method', 'mean')
# Collect all valid fundamentals
batch_fundamentals = [] # List of (img_idx, shift_idx, position, fundamentals)
for shift_idx, (dx, dy) in enumerate(shifts):
# Calculate position on canvas
y0 = offset + dy
x0_pos = offset + dx
y1 = y0 + img_size
x1 = x0_pos + img_size
# Skip invalid crops
if not (0 <= y0 < y1 <= aug_size and 0 <= x0_pos < x1 <= aug_size):
continue
for img_idx in range(B):
# Get fundamentals from fundamentals_dict
x0_fund = fundamentals_dict['x0'][img_idx][shift_idx]
encoded_fund = fundamentals_dict['encoded'][img_idx][shift_idx]
image_samples_fund = fundamentals_dict['image_samples'][img_idx][shift_idx]
latent_samples_fund = fundamentals_dict['latent_samples'][img_idx][shift_idx]
# Skip if fundamentals are None
if x0_fund is None or encoded_fund is None:
continue
batch_fundamentals.append({
'img_idx': img_idx,
'shift_idx': shift_idx,
'position': (x0_pos, y0, x1, y1),
'shift': (dx, dy),
'x0': x0_fund,
'encoded': encoded_fund,
'image_samples': image_samples_fund,
'latent_samples': latent_samples_fund
})
if not batch_fundamentals:
# No valid fundamentals - return empty results
anomaly_maps_batch = {k: np.zeros((B, img_size, img_size), dtype=np.float32) for k in keys}
individual_shifts = [[] for _ in range(B)]
return anomaly_maps_batch, individual_shifts
# Stack all fundamentals into batched tensors [N, C, H, W]
x0_batch = torch.from_numpy(np.concatenate([f['x0'] for f in batch_fundamentals], axis=0)).to(self.device)
encoded_batch = torch.from_numpy(np.concatenate([f['encoded'] for f in batch_fundamentals], axis=0)).to(self.device)
image_samples_batch = torch.from_numpy(np.concatenate([f['image_samples'] for f in batch_fundamentals], axis=0)).to(self.device)
latent_samples_batch = torch.from_numpy(np.concatenate([f['latent_samples'] for f in batch_fundamentals], axis=0)).to(self.device)
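# Shapes: x0/image_samples are decoded images [N, 3, H, W]; encoded/latent_samples are
# VAE latents (typically [N, 4, H/8, W/8] for an SD-style VAE), where N is the number
# of valid (image, shift) pairs in this batch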
# Check if GPU ops are enabled (Phase 3)
gpu_ops_enabled = getattr(args, 'gpu_ops', True)
if gpu_ops_enabled:
# Phase 3: GPU-accelerated path (1.5-2× faster)
anomaly_batch = self.calculate_anomaly_maps_batch_gpu(
x0_batch, encoded_batch, image_samples_batch, latent_samples_batch,
crop_size=img_size,
image_diff_clip_max=self.image_diff_clip_max,
latent_diff_clip_max=self.latent_diff_clip_max
)
# Keep canvas on GPU as torch tensors
all_maps_gpu = {k: torch.full((B, num_shifts, aug_size, aug_size), float('nan'),
device=self.device, dtype=torch.float32) for k in keys}
individual_shifts = [[] for _ in range(B)]
# Track which (img_idx, shift_idx) combinations are valid
valid_shifts_tracker = [[False] * num_shifts for _ in range(B)]
for idx, fund_info in enumerate(batch_fundamentals):
img_idx = fund_info['img_idx']
shift_idx = fund_info['shift_idx']
x0_pos, y0, x1, y1 = fund_info['position']
dx, dy = fund_info['shift']
# Place on GPU canvas (in-place GPU operation)
for key in keys:
all_maps_gpu[key][img_idx, shift_idx, y0:y1, x0_pos:x1] = anomaly_batch[key][idx]
# Mark this shift as valid
valid_shifts_tracker[img_idx][shift_idx] = True
# Store individual shift data
# Keep crops on CPU (they are only used for visualization later; no GPU round-trip needed)
crop_original = torch.from_numpy(fund_info['x0'])
crop_reconstruction = torch.from_numpy(fund_info['image_samples'])
individual_shifts[img_idx].append({
'shift': (dx, dy),
'position': (x0_pos, y0, x1, y1),
'arithmetic_combined': anomaly_batch['anomaly_arithmetic'][idx].cpu().numpy(),
'crop_original': crop_original.detach().cpu(),
'crop_reconstruction': crop_reconstruction.detach().cpu(),
'shift_method': shift_method
})
# Fuse shifts on GPU
all_averaged_maps_gpu = {}
for key in keys:
fused_per_image = []
for img_idx in range(B):
# Get valid shift indices for this image
shift_maps = all_maps_gpu[key][img_idx] # [num_shifts, H, W]
valid_indices = [i for i, is_valid in enumerate(valid_shifts_tracker[img_idx]) if is_valid]
if not valid_indices:
fused_per_image.append(torch.zeros((aug_size, aug_size), device=self.device))
continue
valid_shifts = shift_maps[valid_indices] # [num_valid, H, W]
# Fuse on GPU using torch operations
if fuse_method == 'mean':
fused = torch.nanmean(valid_shifts, dim=0)
elif fuse_method == 'median':
fused = torch.nanmedian(valid_shifts, dim=0).values
elif fuse_method == 'pct25':
fused = torch.nanquantile(valid_shifts, 0.25, dim=0)
elif fuse_method == 'pct75':
fused = torch.nanquantile(valid_shifts, 0.75, dim=0)
elif fuse_method == 'lowest':
# nanmin equivalent: mask NaNs with +inf before reducing
fused = torch.where(torch.isnan(valid_shifts),
torch.full_like(valid_shifts, float('inf')),
valid_shifts).min(dim=0).values
else:
fused = torch.nanmean(valid_shifts, dim=0)  # fallback for unknown methods
fused_per_image.append(fused)
all_averaged_maps_gpu[key] = torch.stack(fused_per_image) # [B, H, W]
# Resize on GPU and convert to numpy only at the end
anomaly_maps_batch = {}
if shift_method == 'interpolate':
for key in keys:
resized = F.interpolate(
all_averaged_maps_gpu[key].unsqueeze(1), # [B, 1, H, W]
size=(img_size, img_size),
mode='bilinear',
align_corners=False
).squeeze(1) # [B, H, W]
anomaly_maps_batch[key] = resized.cpu().numpy()
else:
for key in keys:
cropped = all_averaged_maps_gpu[key][:, offset:offset+img_size, offset:offset+img_size]
anomaly_maps_batch[key] = cropped.cpu().numpy()
else:
# Legacy CPU path
anomaly_batch = self.calculate_anomaly_maps_batch(
x0_batch, encoded_batch, image_samples_batch, latent_samples_batch,
crop_size=img_size,
image_diff_clip_max=self.image_diff_clip_max,
latent_diff_clip_max=self.latent_diff_clip_max
)
# Place results on CPU canvas
all_maps = {k: [[None for _ in range(num_shifts)] for _ in range(B)] for k in keys}
individual_shifts = [[] for _ in range(B)]
for idx, fund_info in enumerate(batch_fundamentals):
img_idx = fund_info['img_idx']
shift_idx = fund_info['shift_idx']
x0_pos, y0, x1, y1 = fund_info['position']
dx, dy = fund_info['shift']
# Place on canvas
for key in keys:
canvas = np.full((aug_size, aug_size), np.nan, dtype=np.float32)
canvas[y0:y1, x0_pos:x1] = anomaly_batch[key][idx]
all_maps[key][img_idx][shift_idx] = canvas
# Store individual shift data
# Keep crops on CPU (they are only used for visualization later; no GPU round-trip needed)
crop_original = torch.from_numpy(fund_info['x0'])
crop_reconstruction = torch.from_numpy(fund_info['image_samples'])
individual_shifts[img_idx].append({
'shift': (dx, dy),
'position': (x0_pos, y0, x1, y1),
'arithmetic_combined': anomaly_batch['anomaly_arithmetic'][idx].copy(),
'crop_original': crop_original.detach().cpu(),
'crop_reconstruction': crop_reconstruction.detach().cpu(),
'shift_method': shift_method
})
# Fuse shifts
all_averaged_maps = {}
for key in keys:
all_averaged_maps[key] = []
for img_idx in range(B):
valid_maps = [m for m in all_maps[key][img_idx] if m is not None]
if not valid_maps:
all_averaged_maps[key].append(np.zeros((aug_size, aug_size), dtype=np.float32))
continue
stack = np.stack(valid_maps, axis=-1)
averaged = self._fuse_array(stack, fuse_method, axis=-1)
all_averaged_maps[key].append(averaged)
# Resize to original size
if shift_method == 'interpolate':
for key in keys:
stacked = np.stack(all_averaged_maps[key], axis=0)
tensor = torch.from_numpy(stacked).unsqueeze(1).to(self.device)
resized = F.interpolate(tensor, size=(img_size, img_size), mode='bilinear', align_corners=False)
anomaly_maps_batch[key] = resized.squeeze(1).cpu().numpy()
else:
for key in keys:
stacked = np.stack(all_averaged_maps[key], axis=0)
anomaly_maps_batch[key] = stacked[:, offset:offset+img_size, offset:offset+img_size]
return anomaly_maps_batch, individual_shifts
def _compute_tta_sequential(self, fundamentals_dict, B, num_shifts, shifts, img_size,
aug_size, offset, shift_method, args):
"""Legacy sequential TTA processing (kept for backward compatibility)."""
keys = ['anomaly_geometric', 'anomaly_arithmetic', 'latent_discrepancy', 'image_discrepancy']
all_maps = {k: [[None for _ in range(num_shifts)] for _ in range(B)] for k in keys}
individual_shifts = [[] for _ in range(B)]
        # Process each shift sequentially (one GPU call per shift, e.g. 81 calls for a 9x9 shift grid)
for shift_idx, (dx, dy) in enumerate(shifts):
y0 = offset + dy
x0_pos = offset + dx
y1 = y0 + img_size
x1 = x0_pos + img_size
if not (0 <= y0 < y1 <= aug_size and 0 <= x0_pos < x1 <= aug_size):
continue
for img_idx in range(B):
x0_fund = fundamentals_dict['x0'][img_idx][shift_idx]
encoded_fund = fundamentals_dict['encoded'][img_idx][shift_idx]
image_samples_fund = fundamentals_dict['image_samples'][img_idx][shift_idx]
latent_samples_fund = fundamentals_dict['latent_samples'][img_idx][shift_idx]
if x0_fund is None or encoded_fund is None:
continue
x0_tensor = torch.from_numpy(x0_fund).to(self.device)
encoded_tensor = torch.from_numpy(encoded_fund).to(self.device)
image_samples_tensor = torch.from_numpy(image_samples_fund).to(self.device)
latent_samples_tensor = torch.from_numpy(latent_samples_fund).to(self.device)
anomaly_batch = self.calculate_anomaly_maps_batch(
x0_tensor, encoded_tensor, image_samples_tensor, latent_samples_tensor,
crop_size=img_size,
image_diff_clip_max=self.image_diff_clip_max,
latent_diff_clip_max=self.latent_diff_clip_max
)
for key in keys:
canvas = np.full((aug_size, aug_size), np.nan, dtype=np.float32)
canvas[y0:y1, x0_pos:x1] = anomaly_batch[key][0]
all_maps[key][img_idx][shift_idx] = canvas
crop_original = x0_tensor.detach().cpu()
crop_reconstruction = image_samples_tensor.detach().cpu()
individual_shifts[img_idx].append({
'shift': (dx, dy),
'position': (x0_pos, y0, x1, y1),
'arithmetic_combined': anomaly_batch['anomaly_arithmetic'][0].copy(),
'crop_original': crop_original,
'crop_reconstruction': crop_reconstruction,
'shift_method': shift_method
})
# Fuse shifts
fuse_method = getattr(args, 'fuse_method', 'mean')
all_averaged_maps = {}
for key in keys:
all_averaged_maps[key] = []
for img_idx in range(B):
valid_maps = [m for m in all_maps[key][img_idx] if m is not None]
if not valid_maps:
all_averaged_maps[key].append(np.zeros((aug_size, aug_size), dtype=np.float32))
continue
stack = np.stack(valid_maps, axis=-1)
averaged = self._fuse_array(stack, fuse_method, axis=-1)
all_averaged_maps[key].append(averaged)
# Resize
anomaly_maps_batch = {}
if shift_method == 'interpolate':
for key in keys:
stacked = np.stack(all_averaged_maps[key], axis=0)
tensor = torch.from_numpy(stacked).unsqueeze(1).to(self.device)
resized = F.interpolate(tensor, size=(img_size, img_size), mode='bilinear', align_corners=False)
anomaly_maps_batch[key] = resized.squeeze(1).cpu().numpy()
else:
for key in keys:
stacked = np.stack(all_averaged_maps[key], axis=0)
anomaly_maps_batch[key] = stacked[:, offset:offset+img_size, offset:offset+img_size]
return anomaly_maps_batch, individual_shifts
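    # Each shift writes its crop onto an otherwise-NaN canvas, so the fusion
    # step only averages pixels a given shift actually covered. A toy
    # illustration (sizes made up):
    #
    #   import numpy as np
    #   canvas = np.full((6, 6), np.nan, dtype=np.float32)
    #   canvas[1:5, 2:6] = 1.0        # this shift covers a 4x4 window
    #   int(np.isnan(canvas).sum())   # -> 20 pixels remain uncovered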
def _compute_anomaly_maps_iter(self, fundamental_iter, args, diffusion=None):
"""Compute anomaly maps from fundamentals.
Consumes:
- 8-tuple (non-TTA): (x, mask, object_cls, x0, encoded, image_samples, latent_samples, paths)
- 6-tuple (TTA with precomputed fundamentals): (x, mask, object_cls, fundamentals_dict, paths, tta_metadata)
Yields: (anomaly_maps_dict, mask, paths, individual_shifts, orig_imgs)
"""
for batch_data in fundamental_iter:
# Detect tuple format (6-tuple for TTA with fundamentals, 8-tuple for non-TTA or TTA without fundamentals)
if len(batch_data) == 6:
# TTA mode with precomputed fundamentals
x, mask, object_cls, fundamentals_dict, paths, tta_metadata = batch_data
has_precomputed_fundamentals = True
else:
# Non-TTA mode or TTA without precomputed fundamentals
x, mask, object_cls, x0, encoded, image_samples, latent_samples, paths = batch_data
has_precomputed_fundamentals = False
tta_start = time.time()
# Check for shutdown request
if self._shutdown_requested:
print(f"\n⚠️ Anomaly map computation interrupted. Shutting down...")
break
individual_shifts = None
orig_imgs = None
# Check if TTA is enabled
if args.pad_px is not None and has_precomputed_fundamentals:
# TTA with precomputed fundamentals - compute anomaly maps from cached fundamentals
orig_imgs = x # May be None when loading from cache
# Get batch size from fundamentals_dict structure (since x may be None)
B = len(fundamentals_dict['x0'])
# Get image size from first fundamental (all should have same size)
first_fund = None
for img_idx in range(B):
for shift_idx in range(tta_metadata['num_shifts']):
if fundamentals_dict['x0'][img_idx][shift_idx] is not None:
first_fund = fundamentals_dict['x0'][img_idx][shift_idx]
break
if first_fund is not None:
break
if first_fund is None:
raise ValueError("No valid fundamentals found in batch")
# Get image size from fundamental shape [1, C, H, W]
# Fundamentals are saved with batch dimension, so shape[2] is height
img_size = first_fund.shape[2]
# Get metadata from tta_metadata
num_shifts = tta_metadata['num_shifts']
shifts = tta_metadata['shifts']
# Determine shift_method and img_enlarge_px from args
# Use pad_px from args (not tta_metadata which may be None)
shift_method = getattr(args, 'shift_method', 'interpolate')
img_enlarge_px = getattr(args, 'img_enlarge_px', 4)
pad_px_val = args.pad_px
# Calculate aug_size based on shift_method
if shift_method == 'interpolate':
aug_size = img_size + 2 * img_enlarge_px
offset = img_enlarge_px
elif shift_method == 'mirror':
aug_size = img_size + 2 * pad_px_val
offset = pad_px_val
else:
raise ValueError(f"Unknown shift_method: {shift_method}")
# Check if batched TTA is enabled (Phase 2 optimization)
batched_tta_enabled = getattr(args, 'batched_tta', True)
if batched_tta_enabled:
# Phase 2: GPU-batched TTA processing (5-8× faster)
anomaly_maps_batch, individual_shifts = self._compute_tta_batched(
fundamentals_dict, B, num_shifts, shifts, img_size, aug_size,
offset, shift_method, args
)
else:
# Legacy: Sequential TTA processing
anomaly_maps_batch, individual_shifts = self._compute_tta_sequential(
fundamentals_dict, B, num_shifts, shifts, img_size, aug_size,
offset, shift_method, args
)
elif args.pad_px is not None:
# Use TTA path with _anomaly_shift_avg_combined_batch
# Prepare model kwargs
if isinstance(object_cls, torch.Tensor):
object_cls_tensor = object_cls.to(self.device)
if len(object_cls_tensor.shape) == 1:
object_cls_tensor = object_cls_tensor.unsqueeze(1)
else:
if isinstance(object_cls, (list, tuple)):
object_cls_tensor = torch.tensor(object_cls, device=self.device, dtype=torch.long)
else:
object_cls_tensor = torch.tensor([object_cls] * x.shape[0], device=self.device, dtype=torch.long)
if len(object_cls_tensor.shape) == 1:
object_cls_tensor = object_cls_tensor.unsqueeze(1)
if object_cls_tensor.dtype != torch.long:
object_cls_tensor = object_cls_tensor.long()
model_kwargs = {'context': object_cls_tensor, 'mask': None}
# Store original images for plot_shift_analysis
orig_imgs = x
# Use TTA method
tta_compute_start = time.time()
anomaly_maps_batch, individual_shifts = self._anomaly_shift_avg_combined_batch(
x,
model_kwargs=model_kwargs,
diffusion=diffusion,
reverse_steps=args.reverse_steps,
pad_px=args.pad_px,
img_enlarge_px=getattr(args, 'img_enlarge_px', 4),
stride=getattr(args, 'stride', 1),
directions=getattr(args, 'directions', ('h', 'v', 'diag')),
shift_method=getattr(args, 'shift_method', 'interpolate'),
fuse_method=getattr(args, 'fuse_method', 'mean'),
tta_batch_size=getattr(args, 'tta_batch_size', 8),
eta=0.0,
crop_size=args.crop_size,
)
tta_compute_time = time.time() - tta_compute_start
if self.debug_enabled:
self.log_message(f" [TIMING] TTA computation total: {tta_compute_time:.3f}s", level='debug')
else:
# Standard path: compute anomaly maps using existing method
standard_start = time.time()
anomaly_maps_batch = self.calculate_anomaly_maps_batch(
x0, encoded, image_samples, latent_samples, args.crop_size,
image_diff_clip_max=self.image_diff_clip_max,
latent_diff_clip_max=self.latent_diff_clip_max
)
standard_time = time.time() - standard_start
if self.debug_enabled:
self.log_message(f" [TIMING] Standard anomaly calc: {standard_time:.3f}s", level='debug')
tta_total_time = time.time() - tta_start
if self.debug_enabled:
self.log_message(f" [TIMING] _compute_anomaly_maps_iter total: {tta_total_time:.3f}s", level='debug')
# Yield derived outputs
yield (anomaly_maps_batch, mask, paths, individual_shifts, orig_imgs)
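    # Hedged sketch of wiring the pipeline stages together by hand; `evaluator`,
    # `dataset`, `loader`, and `diffusion` are placeholders for the objects
    # built in _setup_evaluation:
    #
    #   inference_iter = evaluator._run_inference_iter(args, dataset, loader, diffusion)
    #   for maps, mask, paths, shifts, origs in evaluator._compute_anomaly_maps_iter(
    #           inference_iter, args, diffusion):
    #       print(paths[0], maps['anomaly_arithmetic'].shape)  # e.g. (B, H, W)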
def _aggregate_and_compute_metrics(self, anomaly_map_iter, args):
"""Aggregate anomaly maps and compute metrics.
Consumes: (anomaly_maps_dict, mask, paths, individual_shifts, orig_imgs)
Returns: (all_anomaly_maps, all_masks, all_labels)
"""
        from collections import defaultdict  # tqdm and numpy come from module-level imports
# Initialize storage
all_anomaly_maps = defaultdict(list)
all_masks = []
all_labels = []
all_image_paths = []
batch_idx = 0
# Accumulate
for (anomaly_maps_batch, mask, paths, individual_shifts, orig_imgs) in tqdm(anomaly_map_iter, desc="Evaluating"):
# Check for shutdown request
if self._shutdown_requested:
print(f"\n⚠️ Evaluation interrupted at batch {batch_idx}. Shutting down...")
break
for key in anomaly_maps_batch.keys():
all_anomaly_maps[key].append(anomaly_maps_batch[key])
all_masks.append(mask)
            all_labels.append(mask)  # Pixel masks; image-level labels are derived from these downstream
all_image_paths.extend(paths)
# Plot shift analysis if TTA is enabled (similar to decodiff_evaluator.py)
if args.pad_px is not None and individual_shifts is not None and orig_imgs is not None and self.save_plot_path:
try:
# Reorganize individual_shifts from per-image structure to per-shift structure
# Current: individual_shifts[img_idx] = list of shift dicts
# Expected: individual_shifts = list of shift dicts, each with arithmetic_combined [B, H, W]
if len(individual_shifts) > 0 and len(individual_shifts[0]) > 0:
batch_size = len(individual_shifts)
num_shifts = len(individual_shifts[0])
# Reorganize: group by shift instead of by image
shifts_for_plot = []
for shift_idx in range(num_shifts):
# Collect anomaly maps for all images for this shift
arithmetic_maps = []
crop_originals = []
crop_reconstructions = []
# Get shift info from first image (all images have same shifts)
first_shift_data = individual_shifts[0][shift_idx]
shift = first_shift_data['shift']
position = first_shift_data['position']
shift_method = first_shift_data.get('shift_method', 'interpolate')
# Collect data from all images for this shift
for img_idx in range(batch_size):
if shift_idx < len(individual_shifts[img_idx]):
shift_data = individual_shifts[img_idx][shift_idx]
# arithmetic_combined is [H, W] for this image+shift
arithmetic_maps.append(shift_data['arithmetic_combined'])
crop_originals.append(shift_data.get('crop_original'))
crop_reconstructions.append(shift_data.get('crop_reconstruction'))
# Stack to create [B, H, W]
if arithmetic_maps:
arithmetic_combined_batch = np.stack(arithmetic_maps, axis=0) # [B, H, W]
shifts_for_plot.append({
'shift': shift,
'position': position,
'arithmetic_combined': arithmetic_combined_batch, # [B, H, W]
'crop_original': crop_originals[0] if crop_originals else None, # Use first for display
'crop_reconstruction': crop_reconstructions[0] if crop_reconstructions else None,
'shift_method': shift_method
})
# Get averaged anomaly map for entire batch
averaged_anomaly_map = anomaly_maps_batch['anomaly_arithmetic']
# Get fuse_method from args for display in plot title
fuse_method = getattr(args, 'fuse_method', 'mean')
if not self._disable_file_saving:
# Queue plot saving for background thread (non-blocking)
if self.save_queue:
# Convert tensors to numpy/cpu before queuing to avoid GPU memory issues
orig_imgs_cpu = orig_imgs.detach().cpu().numpy() if isinstance(orig_imgs, torch.Tensor) else orig_imgs
averaged_anomaly_map_cpu = averaged_anomaly_map.copy() if isinstance(averaged_anomaly_map, np.ndarray) else averaged_anomaly_map.detach().cpu().numpy()
# Deep copy shifts_for_plot to avoid issues with shared references
shifts_for_plot_cpu = copy.deepcopy(shifts_for_plot)
# Convert any tensors in shift data to numpy
for shift_data in shifts_for_plot_cpu:
if 'crop_original' in shift_data and shift_data['crop_original'] is not None:
if isinstance(shift_data['crop_original'], torch.Tensor):
shift_data['crop_original'] = shift_data['crop_original'].detach().cpu().numpy()
if 'crop_reconstruction' in shift_data and shift_data['crop_reconstruction'] is not None:
if isinstance(shift_data['crop_reconstruction'], torch.Tensor):
shift_data['crop_reconstruction'] = shift_data['crop_reconstruction'].detach().cpu().numpy()
self.save_queue.put({
'kind': 'plot',
'individual_shifts': shifts_for_plot_cpu,
'averaged_anomaly_map': averaged_anomaly_map_cpu,
'batch_idx': batch_idx,
'orig_imgs': orig_imgs_cpu,
'save_path': self.save_plot_path,
'max_shifts_to_show': 16,
'fuse_method': fuse_method
})
else:
# Fallback to synchronous if queue not available
PR.plot_shift_analysis(
individual_shifts=shifts_for_plot,
averaged_anomaly_map=averaged_anomaly_map,
batch_idx=batch_idx,
orig_imgs=orig_imgs,
save_path=self.save_plot_path,
max_shifts_to_show=16,
fuse_method=fuse_method
)
else:
print(f" [SKIP] Shift analysis plot (file saving disabled)")
except Exception as e:
print(f"Warning: Failed to plot shift analysis: {e}")
import traceback
traceback.print_exc()
batch_idx += 1
# Concatenate
for key in all_anomaly_maps.keys():
all_anomaly_maps[key] = np.concatenate(all_anomaly_maps[key], axis=0)
all_masks = np.concatenate(all_masks, axis=0)
all_labels = np.concatenate(all_labels, axis=0)
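        # calculate_metrics is assumed to return a tuple ordered as printed
        # below: (pixel AUROC, pixel AUPRO, pixel F1, pixel AP,
        # image AUROC, image AP, image F1).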
# Calculate and print metrics
for key in all_anomaly_maps.keys():
metrics = self.calculate_metrics(all_masks, all_anomaly_maps[key])
self.log_message(f"\n{key} metrics:")
self.log_message(f" Pixel AUROC: {metrics[0]:.4f}")
self.log_message(f" Pixel AUPRO: {metrics[1]:.4f}")
self.log_message(f" Pixel F1: {metrics[2]:.4f}")
self.log_message(f" Pixel AP: {metrics[3]:.4f}")
self.log_message(f" Image AUROC: {metrics[4]:.4f}")
self.log_message(f" Image AP: {metrics[5]:.4f}")
self.log_message(f" Image F1: {metrics[6]:.4f}")
# Save evaluation images if enabled
if getattr(args, 'save_evaluation_images', False) and self.checkpoint_manager:
category = getattr(args, 'object_category', 'unknown')
annotation_dir = getattr(args, 'annotation_dir', None)
npy_cache_dir = getattr(args, 'npy_cache_dir', None)
patch_size = getattr(args, 'patch_size', 256)
# Map CLI parameters to evaluation parameters
# anomaly_threshold -> anomaly_binary_threshold (threshold for binary mask)
# anomaly_min_area -> anomaly_pixel_num_threshold (minimum pixels for defect)
anomaly_binary_threshold = getattr(args, 'anomaly_threshold', getattr(args, 'anomaly_binary_threshold', 5))
anomaly_pixel_num_threshold = getattr(args, 'anomaly_min_area', getattr(args, 'anomaly_pixel_num_threshold', 10))
overlay_alpha = getattr(args, 'image_overlay_alpha', 0.8)
grid_thickness = getattr(args, 'grid_thickness', 1)
self.save_evaluation_images_for_category(
anomaly_maps=all_anomaly_maps,
annotation_dir=annotation_dir,
category=category,
image_paths=all_image_paths,
npy_cache_dir=npy_cache_dir,
patch_size=patch_size,
anomaly_binary_threshold=anomaly_binary_threshold,
anomaly_pixel_num_threshold=anomaly_pixel_num_threshold,
overlay_alpha=overlay_alpha,
grid_thickness=grid_thickness
)
# Create confusion matrix if enabled
if getattr(args, 'enable_confusion_matrix', False):
annotation_dir = getattr(args, 'annotation_dir', None)
if not annotation_dir:
self.log_message('Warning: Cannot generate confusion matrix without annotation_dir')
else:
self.log_message('\nGenerating confusion matrix...')
try:
patch_results_by_image = {}
# If evaluation images were saved, load from JSONs
if getattr(args, 'save_evaluation_images', False) and self.checkpoint_manager:
evaluation_results_dir = self.checkpoint_manager.evaluation_results_dir
if evaluation_results_dir.exists():
for json_file in evaluation_results_dir.glob('*__evaluation.json'):
with open(json_file, 'r') as f:
eval_data = json.load(f)
image_path = eval_data.get('image_path', '')
patch_analysis = eval_data.get('patch_analysis', [])
if image_path and patch_analysis:
patch_results_by_image[image_path] = patch_analysis
                        # If no saved JSONs, compute patch results on the fly
                        if not patch_results_by_image:
                            # These parameters are only bound above when
                            # save_evaluation_images is enabled; re-read them here
                            # to avoid a NameError on this code path.
                            patch_size = getattr(args, 'patch_size', 256)
                            anomaly_binary_threshold = getattr(args, 'anomaly_threshold', getattr(args, 'anomaly_binary_threshold', 5))
                            anomaly_pixel_num_threshold = getattr(args, 'anomaly_min_area', getattr(args, 'anomaly_pixel_num_threshold', 10))
                            predictions = all_anomaly_maps['anomaly_arithmetic']
                            ground_truth_map = self.load_ground_truth_map(annotation_dir)
                            for idx, image_path in enumerate(all_image_paths):
pred_mask = predictions[idx]
if isinstance(pred_mask, torch.Tensor):
pred_mask = pred_mask.detach().cpu().numpy()
ground_truth_defective_patches = self.get_ground_truth_for_image(
ground_truth_map, image_path
)
patch_results = self.classify_patches(
pred_mask, ground_truth_defective_patches, patch_size,
anomaly_binary_threshold, anomaly_pixel_num_threshold
)
patch_results_by_image[image_path] = patch_results
if patch_results_by_image:
# Determine output directory
if self.save_plot_path:
output_dir = str(self.save_plot_path)
elif self.checkpoint_manager:
output_dir = str(self.checkpoint_manager.results_dir)
else:
output_dir = './confusion_matrix_output'
create_confusion_matrix_from_patch_results(
patch_results_by_image=patch_results_by_image,
output_dir=output_dir,
patch_size=patch_size
)
else:
self.log_message('Warning: No evaluation results found for confusion matrix generation')
except Exception as e:
self.log_message(f'Warning: Failed to create confusion matrix: {e}')
import traceback
traceback.print_exc()
# Save visualization variants if enabled
save_image_variants = getattr(args, 'save_image_variants', None)
if save_image_variants and self.save_plot_path:
try:
from dioodmi.utils.utils import path_to_safe_filename
from dioodmi.utils.visualization_variants import create_anomaly_visualization_variants
                # Set defaults (save_image_variants is guaranteed truthy by the guard above)
save_colormaps = getattr(args, 'save_colormaps', None) or ['jet']
save_normalization = getattr(args, 'save_normalization', 'minmax')
normal_center = getattr(args, 'normal_center', 125.0)
save_binary_thresholds = getattr(args, 'save_binary_thresholds', None) or [5.0]
# Handle 'all' variant option
if 'all' in save_image_variants:
save_image_variants = ['continuous', 'binary', 'grayscale', 'absolute']
# Create image_level directory for variants
image_level_dir = Path(self.save_plot_path) / "marked_images" / "image_level"
image_level_dir.mkdir(parents=True, exist_ok=True)
# Save variants for each image
for idx, image_path in enumerate(all_image_paths):
safe_name = path_to_safe_filename(image_path)
# Save variants for each map type
for map_type in ['anomaly_arithmetic', 'anomaly_geometric', 'latent_discrepancy', 'image_discrepancy']:
if map_type in all_anomaly_maps and idx < len(all_anomaly_maps[map_type]):
map_data = all_anomaly_maps[map_type][idx]
# Ensure it's a numpy array
if isinstance(map_data, torch.Tensor):
map_data = map_data.detach().cpu().numpy()
# For grayscale variant, use signed difference if available
signed_diff_data = None
if 'grayscale' in save_image_variants:
if map_type == 'image_discrepancy' and 'image_signed_diff' in all_anomaly_maps:
if idx < len(all_anomaly_maps['image_signed_diff']):
signed_diff_data = all_anomaly_maps['image_signed_diff'][idx]
if isinstance(signed_diff_data, torch.Tensor):
signed_diff_data = signed_diff_data.detach().cpu().numpy()
elif map_type == 'latent_discrepancy' and 'latent_signed_diff' in all_anomaly_maps:
if idx < len(all_anomaly_maps['latent_signed_diff']):
signed_diff_data = all_anomaly_maps['latent_signed_diff'][idx]
if isinstance(signed_diff_data, torch.Tensor):
signed_diff_data = signed_diff_data.detach().cpu().numpy()
# Create variants
variants = create_anomaly_visualization_variants(
map_data,
variants=save_image_variants,
colormaps=save_colormaps,
normalization=save_normalization,
normal_center=normal_center,
thresholds=save_binary_thresholds,
signed_diff_data=signed_diff_data # Pass signed difference for grayscale
)
# Save each variant
for variant_name, variant_img in variants.items():
variant_path = image_level_dir / f"{safe_name}__{map_type}__{variant_name}.png"
PILImage.fromarray(variant_img).save(variant_path)
except Exception as e:
print(f"Warning: Failed to save visualization variants: {e}")
        # Plot saving hook (not yet implemented)
        if self.save_plot_path:
            # TODO: Implement plot saving
            self.log_message(f"Plot saving to {self.save_plot_path} is not yet implemented")
# Wait for all images to be saved before returning
if self.save_plot_path and self.save_queue and not self._disable_file_saving:
self.wait_for_saves()
# Cleanup CUDA resources if shutdown was requested
if self._shutdown_requested:
try:
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
except Exception:
pass # Ignore errors during cleanup
print("Cleanup complete. Exiting...")
return all_anomaly_maps, all_masks, all_labels
def detect_and_evaluate_without_cache(self, args):
"""Mode 1: Fast evaluation without caching (original evaluate behavior).
Composition: Inference → Process → Aggregate
"""
# Mode detection for full_image_eval
if getattr(args, "full_image_eval", False):
            return self._evaluate_full_image_tiled(args, getattr(args, 'resume', getattr(args, 'enable_resume', False)))
# Setup
test_dataset, test_loader, diffusion = self._setup_evaluation(args)
# Composition: Inference → Process → Aggregate
inference_iter = self._run_inference_iter(args, test_dataset, test_loader, diffusion)
anomaly_map_iter = self._compute_anomaly_maps_iter(inference_iter, args, diffusion)
all_anomaly_maps, all_masks, all_labels = self._aggregate_and_compute_metrics(anomaly_map_iter, args)
return all_anomaly_maps, all_masks, all_labels
def detect_and_evaluate(self, args):
"""Mode 2: Standard workflow with detection result caching (DEFAULT).
Composition: Inference → Save → Process → Aggregate
"""
# Setup
test_dataset, test_loader, diffusion = self._setup_evaluation(args)
save_dir = args.npy_cache_dir or os.path.join(args.save_plot_path, "npy_cache")
os.makedirs(save_dir, exist_ok=True)
# Composition: Inference → Save → Process → Aggregate
inference_iter = self._run_inference_iter(args, test_dataset, test_loader, diffusion)
saving_iter = self._save_fundamentals_iter(inference_iter, save_dir, args)
anomaly_map_iter = self._compute_anomaly_maps_iter(saving_iter, args, diffusion)
all_anomaly_maps, all_masks, all_labels = self._aggregate_and_compute_metrics(anomaly_map_iter, args)
self.log_message(f"\n✓ Detection results cached to: {save_dir}")
self.log_message(f"✓ Use --mode evaluate_only --npy_cache_dir {save_dir} to reprocess")
return all_anomaly_maps, all_masks, all_labels
def detect_only(self, args):
"""Mode 3: Detection only, save results for later evaluation.
Composition: Inference → Save
TTA fundamentals are now computed in _run_inference_iter and saved by _save_fundamentals_iter.
"""
# Setup
test_dataset, test_loader, diffusion = self._setup_evaluation(args)
save_dir = args.npy_cache_dir or os.path.join(args.save_plot_path, "npy_cache")
os.makedirs(save_dir, exist_ok=True)
# Composition: Inference → Save (TTA fundamentals computed and saved inline)
inference_iter = self._run_inference_iter(args, test_dataset, test_loader, diffusion)
saving_iter = self._save_fundamentals_iter(inference_iter, save_dir, args)
# Consume iterator to trigger saving
count = sum(1 for _ in saving_iter)
self.log_message(f"\n✓ Detection completed: {count} batches processed")
self.log_message(f"✓ Results saved to: {save_dir}")
self.log_message(f"✓ Use --mode evaluate_only --npy_cache_dir {save_dir} to compute metrics")
def evaluate_only(self, args):
"""Mode 4: Evaluate previously saved detection results.
Composition: Load → Process → Aggregate
"""
# Validate
save_dir = args.npy_cache_dir
if not os.path.exists(save_dir):
raise ValueError(f"NPY cache directory not found: {save_dir}")
        # Setup (need the dataset for ground truth masks; the same call also
        # provides the diffusion object, so setup runs only once)
        test_dataset, test_loader, diffusion = self._setup_evaluation(args)
        # Composition: Load → Process → Aggregate
        loading_iter = self._load_fundamentals_iter_dispatcher(save_dir, test_dataset, args)
        anomaly_map_iter = self._compute_anomaly_maps_iter(loading_iter, args, diffusion)
all_anomaly_maps, all_masks, all_labels = self._aggregate_and_compute_metrics(anomaly_map_iter, args)
self.log_message(f"\n✓ Evaluation completed using cached results from: {save_dir}")
self.log_message(f"✓ Processed {len(all_masks)} images")
return all_anomaly_maps, all_masks, all_labels
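    # Hedged sketch of the two-phase workflow these modes enable (attribute
    # names follow this module; the cache path is illustrative):
    #
    #   args.mode = 'detect_only'
    #   args.npy_cache_dir = './npy_cache'
    #   evaluator.evaluate(args)       # runs inference, caches fundamentals
    #
    #   args.mode = 'evaluate_only'
    #   evaluator.evaluate(args)       # recomputes maps and metrics from cache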
def evaluate(self, args):
"""
Main evaluation method with mode dispatch.
This method processes the entire test dataset and generates anomaly detection results.
Supports 4 modes for different workflows:
- detect_and_evaluate_without_cache: Fast, no caching
- detect_and_evaluate: Standard with NPY cache (DEFAULT)
- detect_only: Save fundamentals for later
- evaluate_only: Reprocess cached fundamentals
Also supports full_image_eval mode for tiled evaluation of large images.
"""
# Wrap in try-except to catch KeyboardInterrupt
try:
# Get mode (default to detect_and_evaluate for backward compatibility)
mode = getattr(args, 'mode', 'detect_and_evaluate')
# Mode dispatch
if mode == "detect_and_evaluate_without_cache":
# Fast pipeline without caching
if getattr(args, 'full_image_eval', False):
                    return self._evaluate_full_image_tiled(args, getattr(args, 'enable_resume', getattr(args, 'resume', False)))
else:
return self.detect_and_evaluate_without_cache(args)
elif mode == "detect_and_evaluate":
# Standard workflow with caching (DEFAULT)
if getattr(args, 'full_image_eval', False):
raise NotImplementedError("detect_and_evaluate mode not supported with full_image_eval yet")
else:
return self.detect_and_evaluate(args)
elif mode == "detect_only":
# Detection only, save for later evaluation
if getattr(args, 'full_image_eval', False):
raise NotImplementedError("detect_only mode not supported with full_image_eval yet")
else:
return self.detect_only(args)
elif mode == "evaluate_only":
# Evaluate previously saved detection results
if getattr(args, 'full_image_eval', False):
raise NotImplementedError("evaluate_only mode not supported with full_image_eval yet")
else:
return self.evaluate_only(args)
else:
raise ValueError(
f"Unknown mode: {mode}. Must be one of: "
"detect_and_evaluate_without_cache, detect_and_evaluate, detect_only, evaluate_only"
)
except KeyboardInterrupt:
# Catch KeyboardInterrupt explicitly (may be raised during blocking operations)
# This is a fallback - the signal handler should have already set _shutdown_requested
if not self._shutdown_requested:
self.log_message("\n⚠️ KeyboardInterrupt caught. Initiating graceful shutdown...", level='warning')
self._shutdown_requested = True
finally:
# Cleanup CUDA resources
if self._shutdown_requested:
try:
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
except Exception:
pass # Ignore errors during cleanup
# Wait for pending saves before shutdown (prevents data loss)
if self.save_plot_path and self.save_queue:
queue_size = self.get_save_queue_size()
if queue_size > 0:
self.log_message(f"⏳ Waiting for {queue_size} pending save operations to complete...")
self.wait_for_saves(timeout=30) # Shorter timeout for interrupts
# Shutdown save worker if running
if self.save_queue and self.save_worker_thread:
self.shutdown_save_worker()
self.log_message("Evaluation interrupted by user. Cleanup complete.")
# Evaluation image saving methods (ported from decodiff_evaluator.py)
def create_binary_mask(self, anomaly_map, threshold=5):
"""Create binary mask from anomaly map using threshold (0-255 range)"""
# Auto-detect anomaly map range and normalize threshold accordingly
amap_max = np.max(anomaly_map)
if amap_max <= 1.0:
# Anomaly map is in 0-1 range, normalize threshold from 0-255 to 0-1
normalized_threshold = threshold / 255.0
else:
# Anomaly map is in 0-255 range, use threshold as-is
normalized_threshold = threshold
return (anomaly_map > normalized_threshold).astype(np.float32)
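    # Range auto-detection in create_binary_mask, on toy inputs:
    #
    #   amap01 = np.array([[0.01, 0.5]])     # values in [0, 1]
    #   self.create_binary_mask(amap01, 5)   # threshold becomes 5/255 ≈ 0.0196 -> [[0., 1.]]
    #   amap255 = np.array([[1.0, 200.0]])   # values in [0, 255]
    #   self.create_binary_mask(amap255, 5)  # threshold stays 5 -> [[0., 1.]]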
def count_anomaly_pixels(self, binary_mask):
"""Count number of positive pixels in binary mask"""
return int(np.sum(binary_mask))
def load_single_annotation(self, annotation_path):
"""Load a single annotation file and return its data
Args:
annotation_path: Path to the annotation JSON file
Returns:
dict: Annotation data with keys 'image_path', 'defective_patches', 'grid_size'
None: If file doesn't exist or can't be loaded
"""
if not os.path.exists(annotation_path):
return None
try:
with open(annotation_path, 'r') as f:
annotation = json.load(f)
return annotation
except Exception as e:
print(f"Warning: Error reading annotation file {annotation_path}: {e}")
return None
def load_ground_truth_map(self, annotation_dir):
"""Load all annotations and build ground truth map
Args:
annotation_dir: Directory containing annotation JSON files
Returns:
dict: Mapping from image_path to set of (grid_row, grid_col) tuples
"""
ground_truth_map = {}
if not annotation_dir or not os.path.exists(annotation_dir):
return ground_truth_map
# Find all annotation files
annotation_pattern = os.path.join(annotation_dir, "*__annotations.json")
annotation_files = glob(annotation_pattern)
for annotation_file in annotation_files:
annotation = self.load_single_annotation(annotation_file)
if annotation is None:
continue
# Get the original image path from the annotation
image_path = annotation.get("image_path")
if image_path:
# Convert defective patches to set of tuples
defective_patches = set(tuple(patch) for patch in annotation.get("defective_patches", []))
ground_truth_map[image_path] = defective_patches
return ground_truth_map
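    # Shape of an annotation file this loader expects (key names are the ones
    # read above; the path and patch list are illustrative):
    #
    #   {
    #     "image_path": "wood/test/hole/000.png",
    #     "defective_patches": [[0, 1], [2, 3]],
    #     "grid_size": 256
    #   }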
def get_ground_truth_for_image(self, ground_truth_map, image_path):
"""Get defective patches for a specific image
Args:
ground_truth_map: Dictionary mapping image paths to defective patch sets
image_path: Path to the image
Returns:
set: Set of (grid_row, grid_col) tuples for defective patches
"""
return ground_truth_map.get(image_path, set())
def generate_patch_coordinates(self, height, width, patch_size):
"""Generate 8-value coordinates for all patches in an image
Args:
height: Image height
width: Image width
patch_size: Size of each patch
Returns:
List of tuples: [(grid_row, grid_col, coords_8_values), ...]
where coords_8_values = [x1, y1, x2, y2, x3, y3, x4, y4]
representing the 4 corners of each patch
"""
n_rows = (height + patch_size - 1) // patch_size
n_cols = (width + patch_size - 1) // patch_size
patch_coordinates = []
for r in range(n_rows):
for c in range(n_cols):
# Calculate patch boundaries
y1 = r * patch_size
y2 = min(y1 + patch_size, height)
x1 = c * patch_size
x2 = min(x1 + patch_size, width)
# Create 8-value coordinates for axis-aligned rectangle
# Format: [x1, y1, x2, y2, x3, y3, x4, y4] (4 corners)
coords_8_values = [
x1, y1, # Top-left
x2, y1, # Top-right
x2, y2, # Bottom-right
x1, y2 # Bottom-left
]
patch_coordinates.append((r, c, coords_8_values))
return patch_coordinates
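    # Worked example: a 600x512 (H x W) image with patch_size=256 gives a 3x2
    # grid, and edge patches are clipped to the image bounds:
    #
    #   coords = self.generate_patch_coordinates(600, 512, 256)
    #   len(coords)   # -> 6
    #   coords[-1]    # -> (2, 1, [256, 512, 512, 512, 512, 600, 256, 600])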
def extract_patch(self, image, coords_8_values):
"""Extract a single patch from image using 8-value coordinates
Args:
image: Input image as numpy array
coords_8_values: List of 8 values [x1, y1, x2, y2, x3, y3, x4, y4]
Returns:
Extracted patch as numpy array
"""
x1, y1, x2, y2, x3, y3, x4, y4 = coords_8_values
# Check if patch is parallel to image axes (faster extraction)
is_parallel = (y1 == y2 and y3 == y4 and x1 == x4 and x2 == x3)
if is_parallel:
# Fast path: rectangular patch aligned with image axes
# Use top-left and bottom-right corners
x_min, y_min = int(x1), int(y1)
x_max, y_max = int(x3), int(y3)
# Ensure coordinates are within image bounds
height, width = image.shape[:2]
x_min = max(0, min(x_min, width - 1))
y_min = max(0, min(y_min, height - 1))
x_max = max(x_min + 1, min(x_max, width))
y_max = max(y_min + 1, min(y_max, height))
# Extract patch using array slicing
patch = image[y_min:y_max, x_min:x_max]
return patch
else:
# TODO: Slow path for rotated patches - implement perspective transform
# For now, raise an error to indicate this is not yet supported
raise NotImplementedError("Rotated/perspective patches not yet supported")
def classify_single_patch(self, pred_patch, grid_row, grid_col, ground_truth_defective_patches,
anomaly_binary_threshold=5, anomaly_pixel_num_threshold=10):
"""Classify a single patch as TP/FP/FN/TN using annotation-based ground truth
Args:
pred_patch: Prediction patch as numpy array
grid_row: Row position in grid
grid_col: Column position in grid
ground_truth_defective_patches: Set of (row, col) tuples representing defective patches
anomaly_binary_threshold: Threshold for creating binary mask (0-255)
anomaly_pixel_num_threshold: Minimum pixels needed for defect classification
Returns:
dict: Classification result with keys:
- status: 'TP', 'FP', 'FN', or 'TN'
- anomaly_pixels: Number of anomaly pixels found
- pred_score: Mean prediction score
- gt_present: Whether ground truth has defects
"""
# Create binary mask and count anomaly pixels
binary_mask = self.create_binary_mask(pred_patch, anomaly_binary_threshold)
anomaly_pixels = self.count_anomaly_pixels(binary_mask)
# Determine if patch is predicted as defective using pixel counting
pred_defective = anomaly_pixels > anomaly_pixel_num_threshold
# Determine if ground truth has defects using annotations
gt_defective = (grid_row, grid_col) in ground_truth_defective_patches
# Classify patch
if pred_defective and gt_defective:
status = "TP"
elif pred_defective and not gt_defective:
status = "FP"
elif not pred_defective and gt_defective:
status = "FN"
else:
status = "TN"
return {
'status': status,
'anomaly_pixels': anomaly_pixels,
'pred_score': float(np.mean(pred_patch)),
'gt_present': bool(gt_defective)
}
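    # Decision table for classify_single_patch:
    #
    #   pred_defective   gt_defective   status
    #   True             True           TP
    #   True             False          FP
    #   False            True           FN
    #   False            False          TN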
def classify_patches(self, pred_mask, ground_truth_defective_patches, patch_size,
anomaly_binary_threshold=5, anomaly_pixel_num_threshold=10):
"""Classify patches for a single image into TP/FP/FN/TN categories using annotations
This orchestrator function uses the separate patch coordinate generation,
extraction, and classification functions to follow single responsibility principle.
Args:
pred_mask: Prediction mask as numpy array
ground_truth_defective_patches: Set of (row, col) tuples representing defective patches
patch_size: Size of patches to extract
anomaly_binary_threshold: Threshold for creating binary mask (0-255)
anomaly_pixel_num_threshold: Minimum pixels needed for defect classification
"""
h, w = pred_mask.shape
# Generate coordinates for all patches using pure function
patch_coordinates = self.generate_patch_coordinates(h, w, patch_size)
patch_results = []
for grid_row, grid_col, coords_8_values in patch_coordinates:
# Extract prediction patch using pure function
pred_patch = self.extract_patch(pred_mask, coords_8_values)
# Classify single patch using annotation-based ground truth
classification_result = self.classify_single_patch(
pred_patch, grid_row, grid_col, ground_truth_defective_patches,
anomaly_binary_threshold, anomaly_pixel_num_threshold
)
# Add grid position and coordinate information
# Note: coords_8_values[0:2] gives top-left corner (x1, y1)
x1, y1 = coords_8_values[0], coords_8_values[1]
patch_result = {
'grid_row': int(grid_row),
'grid_col': int(grid_col),
'coords': (int(x1), int(y1)),
**classification_result # Unpack status, anomaly_pixels, pred_score, gt_present
}
patch_results.append(patch_result)
return patch_results
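    # Hedged usage of the patch pipeline on one prediction mask (the ground
    # truth set and thresholds here are illustrative):
    #
    #   results = self.classify_patches(pred_mask, {(0, 1)}, patch_size=256,
    #                                   anomaly_binary_threshold=5,
    #                                   anomaly_pixel_num_threshold=10)
    #   sum(r['status'] == 'TP' for r in results)  # true-positive patch count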
def save_evaluation_images_for_category(self, anomaly_maps: dict, annotation_dir: str,
category: str, image_paths: list,
npy_cache_dir: str = None,
patch_size: int = 256,
anomaly_binary_threshold: int = 5,
anomaly_pixel_num_threshold: int = 10,
overlay_alpha: float = 0.8, grid_thickness: int = 1):
"""Save evaluation images for a single category using existing confusion matrix pipeline"""
self.log_message('\nSaving evaluation images...')
# Use the arithmetic anomaly map as the primary prediction
predictions = anomaly_maps['anomaly_arithmetic']
# Load ground truth annotations if available
ground_truth_map = self.load_ground_truth_map(annotation_dir) if annotation_dir else {}
# Load NPY manifest if npy_cache_dir is provided
npy_manifest = {}
if npy_cache_dir:
manifest_path = os.path.join(npy_cache_dir, 'fundamentals_manifest.csv')
if os.path.exists(manifest_path):
import csv
with open(manifest_path, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
# Map both original_file_path and full_file_path to npy_filename for flexible lookup
# Handles both simple names ("000") and full paths ("wood/test/hole/000.png")
original_path = row['original_file_path']
full_path = row.get('full_file_path', '')
npy_filename = row['npy_filename']
npy_full_path = os.path.join(npy_cache_dir, npy_filename)
# Add multiple lookup keys for robust path matching
npy_manifest[original_path] = npy_full_path
if full_path:
npy_manifest[full_path] = npy_full_path
# Also add basename without extension
full_basename = os.path.splitext(os.path.basename(full_path))[0]
npy_manifest[full_basename] = npy_full_path
self.log_message(f"Processing {len(predictions)} images for {category}...")
# Process each image
for idx, (pred_mask, image_path) in enumerate(zip(predictions, image_paths)):
# Ensure prediction mask is numpy array
if isinstance(pred_mask, torch.Tensor):
pred_mask = pred_mask.detach().cpu().numpy()
# Get ground truth defective patches for this image
ground_truth_defective_patches = self.get_ground_truth_for_image(
ground_truth_map, image_path
)
# Classify patches for this image using annotations
patch_results = self.classify_patches(
pred_mask, ground_truth_defective_patches, patch_size,
anomaly_binary_threshold, anomaly_pixel_num_threshold
)
# Convert patch results to sets for visualization
predicted_defective_set = set()
overlapping_set = set()
for result in patch_results:
grid_row = result['grid_row']
grid_col = result['grid_col']
status = result['status']
if status in ['TP', 'FP']: # Predicted as defective
predicted_defective_set.add((grid_row, grid_col))
if status == 'TP': # True positive = overlapping
overlapping_set.add((grid_row, grid_col))
# Create combined anomaly maps for saving
combined_anomaly_maps = {
'arithmetic': anomaly_maps['anomaly_arithmetic'][idx] if 'anomaly_arithmetic' in anomaly_maps else pred_mask,
'geometric': anomaly_maps['anomaly_geometric'][idx] if 'anomaly_geometric' in anomaly_maps else pred_mask,
'latent': anomaly_maps['latent_discrepancy'][idx] if 'latent_discrepancy' in anomaly_maps else pred_mask,
'image': anomaly_maps['image_discrepancy'][idx] if 'image_discrepancy' in anomaly_maps else pred_mask
}
# Save evaluation images using the real image path
try:
self.save_evaluation_images(
image_path=image_path,
anomaly_maps=combined_anomaly_maps,
patch_results=patch_results,
predicted_defective_set=predicted_defective_set,
ground_truth_defective_set=ground_truth_defective_patches,
overlapping_set=overlapping_set,
npy_manifest=npy_manifest,
npy_cache_dir=npy_cache_dir,
patch_size=patch_size,
anomaly_binary_threshold=anomaly_binary_threshold,
overlay_alpha=overlay_alpha,
grid_thickness=grid_thickness,
save_raw_anomaly_maps=getattr(self, 'save_raw_anomaly_maps', False),
save_image_variants=getattr(self, 'save_image_variants', None),
save_colormaps=getattr(self, 'save_colormaps', None),
save_normalization=getattr(self, 'save_normalization', 'minmax'),
normal_center=getattr(self, 'normal_center', 125.0),
save_binary_thresholds=getattr(self, 'save_binary_thresholds', None)
)
except Exception as e:
print(f"Warning: Could not save images for {image_path}: {e}")
self.log_message(f"Evaluation images saved for category: {category}")
def save_evaluation_images(self, image_path: str, anomaly_maps: dict, patch_results: list,
predicted_defective_set: set, ground_truth_defective_set: set,
overlapping_set: set, npy_manifest: dict = None,
npy_cache_dir: str = None,
patch_size: int = 256,
anomaly_binary_threshold: int = 5,
overlay_alpha: float = 0.8, grid_thickness: int = 1,
save_raw_anomaly_maps: bool = False,
save_image_variants: Optional[List[str]] = None,
save_colormaps: Optional[List[str]] = None,
save_normalization: str = 'minmax',
normal_center: float = 125.0,
save_binary_thresholds: Optional[List[float]] = None):
"""Save all evaluation images in organized subfolder structure"""
if not self.checkpoint_manager:
return
# Determine image status from patch results
image_status = determine_image_status(patch_results)
safe_name = path_to_safe_filename(image_path)
# Load original image from NPY cache or disk
original_img = None
# Try loading from NPY cache first if manifest is provided
if npy_manifest:
# Try multiple lookup strategies for robust path matching
lookup_keys = [
image_path, # Full path as provided
os.path.splitext(os.path.basename(image_path))[0], # Basename without extension
os.path.basename(image_path) # Basename with extension
]
npy_path = None
for key in lookup_keys:
if key in npy_manifest:
npy_path = npy_manifest[key]
break
if npy_path:
try:
# Handle both .npy and .npz files
if npy_path.endswith('.npz'):
# Load NPZ archive and extract 'original_image' key
with np.load(npy_path) as data:
if 'original_image' in data:
original_tensor = data['original_image'] # Shape: (C, H, W), dtype: float32
else:
raise KeyError(f"'original_image' key not found in {npy_path}")
else:
# Load regular NPY file
original_tensor = np.load(npy_path) # Shape: (C, H, W), dtype: float32
# Convert from (C, H, W) to (H, W, C) and rescale to [0, 255]
original_img = np.transpose(original_tensor, (1, 2, 0)) # (H, W, C)
# Rescale from [-1, 1] or [0, 1] to [0, 255] with proper rounding
if original_img.min() < 0: # Likely in [-1, 1] range
original_img = np.clip(np.round((original_img + 1.0) * 127.5), 0, 255).astype(np.uint8)
else: # Likely in [0, 1] range
original_img = np.clip(np.round(original_img * 255.0), 0, 255).astype(np.uint8)
except Exception as e:
print(f"Warning: Could not load NPY/NPZ image from {npy_path}: {e}")
# Fallback to loading from disk if NPY failed or not available
if original_img is None:
try:
if os.path.exists(image_path):
original_img = np.array(PILImage.open(image_path).convert('RGB'))
except Exception as e:
print(f"Warning: Could not load image from disk {image_path}: {e}")
# Final fallback: Search NPZ files in cache directory to resurrect original image
if original_img is None and npy_cache_dir and os.path.exists(npy_cache_dir):
try:
import glob
# Try to find any NPZ file in the cache directory
npz_files = glob.glob(os.path.join(npy_cache_dir, '*.npz'))
# Try to match by image index or basename
image_basename = os.path.splitext(os.path.basename(image_path))[0]
for npz_file in npz_files:
try:
with np.load(npz_file) as data:
if 'original_image' in data:
# Check if this might be the right file by examining the filename
npz_basename = os.path.splitext(os.path.basename(npz_file))[0]
# Match if the image basename appears in the NPZ filename
if image_basename in npz_basename or npz_basename.startswith(image_basename):
original_tensor = data['original_image']
# Convert from (C, H, W) to (H, W, C) and rescale to [0, 255]
original_img = np.transpose(original_tensor, (1, 2, 0))
# Rescale from [-1, 1] or [0, 1] to [0, 255] with proper rounding
if original_img.min() < 0:
original_img = np.clip(np.round((original_img + 1.0) * 127.5), 0, 255).astype(np.uint8)
else:
original_img = np.clip(np.round(original_img * 255.0), 0, 255).astype(np.uint8)
print(f"✓ Resurrected original image from cache: {npz_file}")
break
                    except Exception:
                        continue  # Try next NPZ file
if original_img is None:
raise RuntimeError(f"Could not find original image in NPZ cache for {image_path}")
except Exception as e:
raise RuntimeError(f"Failed to resurrect original image from cache for {image_path}: {e}")
# Save original image (without any overlays or markings)
self._save_to_subfolders(original_img, safe_name, image_status, 'original', 'original.png')
# Save marked image (with patch rectangles)
marked_img = draw_patch_rectangles_on_image(
original_img, predicted_defective_set,
ground_truth_defective_set, overlapping_set,
patch_size, grid_thickness
)
self._save_to_subfolders(marked_img, safe_name, image_status, 'marked', 'marked.png')
# Set defaults for variant saving
if save_image_variants is None:
save_image_variants = ['continuous', 'binary']
if save_colormaps is None:
save_colormaps = ['jet'] # Default for eval-evaluator
if save_binary_thresholds is None:
save_binary_thresholds = [anomaly_binary_threshold]
# Handle 'all' variant option
if 'all' in save_image_variants:
save_image_variants = ['continuous', 'binary', 'grayscale', 'absolute']
# Check if we should use new variant system
use_variant_system = (len(save_image_variants) > 2 or
len(save_colormaps) > 1 or
len(save_binary_thresholds) > 1 or
'grayscale' in save_image_variants or
'absolute' in save_image_variants)
# Save each type of anomaly map
for map_type in ['arithmetic', 'geometric', 'latent', 'image']:
if map_type not in anomaly_maps:
continue
map_data = anomaly_maps[map_type]
# Ensure it's a numpy array
if isinstance(map_data, torch.Tensor):
map_data = map_data.detach().cpu().numpy()
if use_variant_system:
# Use new variant system
from dioodmi.utils.visualization_variants import create_anomaly_visualization_variants
# Create variants
variants = create_anomaly_visualization_variants(
map_data,
variants=save_image_variants,
colormaps=save_colormaps,
normalization=save_normalization,
normal_center=normal_center,
thresholds=save_binary_thresholds
)
# Save each variant
for variant_name, variant_img in variants.items():
# Save to status folder
self._save_to_subfolders(
variant_img, safe_name, image_status,
f'anomaly_maps/{map_type}', f'{map_type}__{variant_name}.png'
)
else:
# Use legacy system (backward compatible)
# Save raw anomaly map
anomaly_img = create_anomaly_visualization(
map_data, is_binary=False
)
self._save_to_subfolders(
anomaly_img, safe_name, image_status,
f'anomaly_maps/{map_type}', f'{map_type}.png'
)
# Save binary version and overlays for arithmetic/geometric
if map_type in ['arithmetic', 'geometric']:
# Save binary version
binary_img = create_anomaly_visualization(
map_data, is_binary=True, threshold=anomaly_binary_threshold
)
self._save_to_subfolders(
binary_img, safe_name, image_status,
f'binary/{map_type}', f'{map_type}_binary.png'
)
# Save overlay
overlay_img = create_anomaly_overlay(
original_img, map_data, alpha=overlay_alpha,
threshold=anomaly_binary_threshold
)
self._save_to_subfolders(
overlay_img, safe_name, image_status,
f'overlays/{map_type}', f'{map_type}_overlay.png'
)
# Save marked overlay
marked_overlay = draw_patch_rectangles_on_image(
overlay_img, predicted_defective_set,
ground_truth_defective_set, overlapping_set,
patch_size, grid_thickness
)
self._save_to_subfolders(
marked_overlay, safe_name, image_status,
'overlays/marked', f'{map_type}_marked_overlay.png'
)
# Save raw anomaly maps as NPY files if enabled
if save_raw_anomaly_maps and self.checkpoint_manager:
try:
npy_save_dir = Path(self.checkpoint_manager.results_dir) / "anomaly_maps_npy"
npy_save_dir.mkdir(parents=True, exist_ok=True)
# Save raw anomaly maps (before any visualization)
for map_type in ['arithmetic', 'geometric', 'latent', 'image']:
if map_type in anomaly_maps:
# Ensure it's a numpy array
map_data = anomaly_maps[map_type]
if isinstance(map_data, torch.Tensor):
map_data = map_data.detach().cpu().numpy()
# Save raw array
np.save(npy_save_dir / f"{safe_name}__{map_type}.npy", map_data)
# Save binary version if applicable
if map_type in ['arithmetic', 'geometric']:
binary_map = (map_data > anomaly_binary_threshold).astype(np.float32)
np.save(npy_save_dir / f"{safe_name}__{map_type}_binary.npy", binary_map)
except Exception as e:
print(f"Warning: Failed to save raw anomaly maps for {image_path}: {e}")
# Create and save concatenated 4-panel images (original, overlay, anomaly_map, binary)
try:
for map_type in ['arithmetic', 'geometric']:
if map_type not in anomaly_maps:
continue
map_data = anomaly_maps[map_type]
if isinstance(map_data, torch.Tensor):
map_data = map_data.detach().cpu().numpy()
# Create the 4 panels
# 1. Original image
panel_original = original_img
# 2. Overlay image (original + anomaly map overlay)
panel_overlay = create_anomaly_overlay(
original_img, map_data, alpha=overlay_alpha,
threshold=anomaly_binary_threshold
)
# 3. Anomaly map (continuous with colormap)
panel_anomaly_map = create_anomaly_visualization(
map_data, is_binary=False
)
# 4. Binary image
panel_binary = create_anomaly_visualization(
map_data, is_binary=True, threshold=anomaly_binary_threshold
)
# Ensure all panels have the same height
h, w, _ = panel_original.shape
panel_overlay = self._resize_if_needed(panel_overlay, h, w)
panel_anomaly_map = self._resize_if_needed(panel_anomaly_map, h, w)
panel_binary = self._resize_if_needed(panel_binary, h, w)
# Concatenate horizontally: [overlay | original | binary | anomaly_map]
concatenated_img = np.concatenate([
panel_overlay, panel_original, panel_binary, panel_anomaly_map
], axis=1)
# Save concatenated image
self._save_to_subfolders(
concatenated_img, safe_name, image_status,
'concatenated', f'{map_type}_concatenated.png'
)
except Exception as e:
print(f"Warning: Failed to create concatenated images for {image_path}: {e}")
# Save evaluation results JSON
self._save_evaluation_json(image_path, patch_results, safe_name, patch_size)
def _resize_if_needed(self, image: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
"""Resize image if dimensions don't match target size"""
h, w = image.shape[:2]
if h != target_h or w != target_w:
pil_img = PILImage.fromarray(image)
pil_img = pil_img.resize((target_w, target_h), PILImage.LANCZOS)
return np.array(pil_img)
return image
def _save_to_subfolders(self, image_array: np.ndarray, base_name: str,
status: str, subfolder: str, suffix: str):
"""Helper to save image to both status and image_level folders"""
# Save to status-specific subfolder (create directory on-demand)
status_folder_path = self.checkpoint_manager.get_status_folder(status) / subfolder
status_folder_path.mkdir(parents=True, exist_ok=True)
status_path = status_folder_path / f"{base_name}_{suffix}"
PILImage.fromarray(image_array).save(status_path)
# Also save to image_level (without status classification)
image_level_folder_path = self.checkpoint_manager.get_image_level_folder() / subfolder
image_level_folder_path.mkdir(parents=True, exist_ok=True)
image_level_path = image_level_folder_path / f"{base_name}_{suffix}"
PILImage.fromarray(image_array).save(image_level_path)
def _save_evaluation_json(self, image_path: str, patch_results: list,
safe_name: str, patch_size: int):
"""Save evaluation results as JSON"""
patch_analysis = []
        for result in patch_results:
            # Include the patch's pixel coords so downstream tools don't need to recompute them
            x, y = result.get('coords', (0, 0))
            patch_analysis.append({
                "grid_row": result.get('grid_row', 0),
                "grid_col": result.get('grid_col', 0),
                "coords": [int(x), int(y)],
                "anomaly_pixels": result.get('anomaly_pixels', 0),
                "pred_score": result.get('pred_score', 0.0),
                "status": result.get('status', 'TN')
            })
# Ensure evaluation results directory exists
self.checkpoint_manager.evaluation_results_dir.mkdir(parents=True, exist_ok=True)
result_filename = f"{safe_name}__evaluation.json"
result_path = self.checkpoint_manager.evaluation_results_dir / result_filename
evaluation_result = {
"image_path": image_path,
"patch_analysis": patch_analysis,
"grid_size": patch_size
}
with open(result_path, 'w') as f:
json.dump(evaluation_result, f, indent=2)
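    # Layout of the JSON written above (keys match the dict built in this
    # method; values are illustrative):
    #
    #   {
    #     "image_path": "wood/test/hole/000.png",
    #     "grid_size": 256,
    #     "patch_analysis": [
    #       {"grid_row": 0, "grid_col": 0, "coords": [0, 0],
    #        "anomaly_pixels": 0, "pred_score": 0.02, "status": "TN"}
    #     ]
    #   }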
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""HAEDOSA Automation Engine CLI for DiOodMi.
Unified command-line interface orchestrating DiOodMi operations.
Follows HAEDOSA MS12 automation standards with export-based architecture.
HAE provides:
- Orchestration of core verbs (train, eval, reconstruct)
- Workflow automation (benchmarking, multi-category training)
- Example management
- Command export (shows standalone equivalents)
"""
import os
import subprocess
import sys
from pathlib import Path
# Configure UTF-8 encoding for all I/O operations (Windows compatibility)
# This must be set before any console output
if sys.platform.startswith('win'):
os.environ.setdefault('PYTHONIOENCODING', 'utf-8')
# Set Windows console code page to UTF-8
try:
import ctypes
kernel32 = ctypes.windll.kernel32
kernel32.SetConsoleOutputCP(65001) # UTF-8 code page
kernel32.SetConsoleCP(65001)
except Exception:
pass # Fallback if ctypes fails
# Reconfigure stdout/stderr to UTF-8
try:
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
except AttributeError:
# Python < 3.7 fallback
import io
if not isinstance(sys.stdout, io.TextIOWrapper):
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
if not isinstance(sys.stderr, io.TextIOWrapper):
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace')
import click
from click.core import Command
from rich.console import Console
from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TextColumn
from .export import show_standalone_alternatives
# Configure Console with UTF-8 support for Windows
# Disable legacy Windows rendering to avoid Unicode encoding errors with emojis
console = Console(legacy_windows=False, file=sys.stdout)
def preprocess_latent_sizes_args(args):
"""Preprocess command line arguments to handle --latent_sizes with multiple space-separated values.
Converts: --latent_sizes 16 32 64 --ncls ...
To: --latent_sizes 16 --latent_sizes 32 --latent_sizes 64 --ncls ...
"""
result = []
i = 0
while i < len(args):
if args[i] == '--latent_sizes' and i + 1 < len(args):
# Found --latent_sizes, collect all following non-option args
i += 1
# Collect all following non-option arguments until we hit another option
while i < len(args) and not args[i].startswith('--'):
result.append('--latent_sizes')
result.append(args[i])
i += 1
            # i already points at the next option token, so skip the
            # default append/advance at the bottom of the loop
            continue
result.append(args[i])
i += 1
return result
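# Concrete behavior of preprocess_latent_sizes_args (a pure function, so this
# is safe to check directly):
#
#   preprocess_latent_sizes_args(['--latent_sizes', '16', '32', '--ncls', '5'])
#   # -> ['--latent_sizes', '16', '--latent_sizes', '32', '--ncls', '5']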
def parse_directions(ctx, param, value):
"""Callback to parse directions option, handling both space-separated string and multiple flags.
With multiple=True, Click collects all following non-option arguments into a tuple.
This callback processes that tuple to handle space-separated values.
"""
if not value:
return value
# With multiple=True, Click should collect all following non-option args as a tuple
# e.g., --directions h v diag becomes ('h', 'v', 'diag')
if isinstance(value, tuple):
# Filter out empty strings and strip whitespace
result = tuple(d.strip() for d in value if d and d.strip())
return result if result else value
# If it's a string (from quoted input like --directions "h v diag"), split it
if isinstance(value, str):
result = tuple(d.strip() for d in value.split() if d.strip())
return result if result else (value,)
return value
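# parse_directions accepts both forms equivalently (ctx/param are unused):
#
#   parse_directions(None, None, ('h', 'v', 'diag'))  # -> ('h', 'v', 'diag')
#   parse_directions(None, None, 'h v diag')          # -> ('h', 'v', 'diag')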
@click.group()
@click.version_option()
def main():
"""HAEDOSA Automation Engine for DiOodMi.
Orchestrates diffusion-based anomaly detection workflows.
"""
pass
@main.command('train')
@click.option('--model', type=click.Choice(['ddpm', 'ldm', 'decodiff']),
default='decodiff', help='Model type to train')
@click.option('--dataset', type=click.Choice(['mvtec', 'visa', 'custom']),
default='mvtec', help='Dataset to use')
@click.option('--object-category', '--category', 'category',
help='Dataset category (e.g., bottle, wood, all)')
@click.option('--epochs', type=int, default=800, help='Number of training epochs')
@click.option('--batch-size', type=int, help='Global batch size')
@click.option('--data-dir', type=click.Path(exists=True),
help='Path to dataset directory')
@click.option('--csv-split-file', type=click.Path(exists=True),
help='CSV file defining custom dataset splits')
@click.option('--output-dir', type=click.Path(), default='./results',
help='Output directory for checkpoints')
@click.option('--resume-dir', type=click.Path(exists=True),
help='Directory containing checkpoint to resume from')
@click.option('--model-size', type=click.Choice(['UNet_XS', 'UNet_S', 'UNet_M', 'UNet_L', 'UNet_XL']),
help='Model size (for DeCo-Diff)')
@click.option('--image-size', type=int, help='Image size')
@click.option('--crop-size', type=int, help='Center crop size')
@click.option('--center-crop/--no-center-crop', default=None,
help='Use center crop (default: auto-detect based on image/crop size)')
@click.option('--num-workers', type=int, default=0,
help='Number of data loader workers')
@click.option('--image-loading-strategy', type=click.Choice(['resize_first', 'keep_original', 'adaptive']),
help='Image loading strategy (resize_first, keep_original, or adaptive)')
@click.option('--config', type=click.Path(exists=True),
help='Config file path (not yet implemented)')
@click.option('--distributed/--no-distributed', default=False,
help='Use distributed training via torchrun')
@click.option('--gpus', type=int, default=1,
help='Number of GPUs for distributed training')
@click.option('--dry-run/--no-dry-run', default=False,
help='Show commands without executing')
@click.option('--model-name', type=str,
help='Model name for checkpoint directory')
@click.option('--ckpt-every', type=int,
help='Save checkpoint every N epochs')
@click.option('--bootstrap-samples', type=int,
help='Number of bootstrap samples')
@click.option('--bootstrap-seed', type=int,
help='Bootstrap random seed')
@click.option('--first-n', type=int,
help='Take first N samples')
@click.option('--last-n', type=int,
help='Take last N samples')
@click.option('--range-start', type=int,
help='Start index for range sampling')
@click.option('--range-count', type=int,
help='Count for range sampling')
@click.option('--random-n', type=int,
help='Random N samples (without replacement)')
@click.option('--random-n-with-replacement', type=int,
help='Random N samples (with replacement)')
@click.option('--random-seed', type=int,
help='Random seed for sampling')
def train(model, dataset, category, epochs, batch_size, data_dir, csv_split_file, output_dir,
resume_dir, model_size, image_size, crop_size, center_crop, num_workers,
image_loading_strategy, config, distributed, gpus, dry_run,
model_name, ckpt_every, bootstrap_samples, bootstrap_seed,
first_n, last_n, range_start, range_count, random_n, random_n_with_replacement, random_seed):
"""Train a diffusion model for anomaly detection.
Orchestrates dioodmi-train with appropriate configuration.
Examples:
hae train --model decodiff --dataset mvtec --category bottle
hae train --model decodiff --distributed --gpus 2 --dry-run
"""
console.print(f"[bold green]🏋️ Training {model.upper()} on {dataset}[/bold green]\n")
# Build command (platform-specific)
# Windows: use system Python (py -3.11 -m) to access system-installed packages
# Linux/uv: use dioodmi-train command from uv environment
if sys.platform.startswith('win'):
cmd = ["py", "-3.11", "-m", "dioodmi.cli.train"]
else:
cmd = ["dioodmi-train"]
cmd.extend(["--model", model, "--dataset", dataset])
if category:
cmd.extend(["--object_category", category])
if epochs:
cmd.extend(["--epochs", str(epochs)])
if batch_size:
cmd.extend(["--global_batch_size", str(batch_size)])
if data_dir:
cmd.extend(["--data_dir", data_dir])
if csv_split_file:
cmd.extend(["--csv_split_file", csv_split_file])
if output_dir:
cmd.extend(["--output_dir", output_dir])
if resume_dir:
cmd.extend(["--resume_dir", resume_dir])
if model_size:
cmd.extend(["--model_size", model_size])
if image_size:
cmd.extend(["--image_size", str(image_size)])
if crop_size:
cmd.extend(["--crop_size", str(crop_size)])
if center_crop is not None:
cmd.extend(["--center_crop", "true" if center_crop else "false"])
if num_workers is not None:
cmd.extend(["--num_workers", str(num_workers)])
if image_loading_strategy:
cmd.extend(["--image_loading_strategy", image_loading_strategy])
if model_name:
cmd.extend(["--model_name", model_name])
if ckpt_every is not None:
cmd.extend(["--ckpt_every", str(ckpt_every)])
if bootstrap_samples is not None:
cmd.extend(["--bootstrap-samples", str(bootstrap_samples)])
if bootstrap_seed is not None:
cmd.extend(["--bootstrap-seed", str(bootstrap_seed)])
if first_n is not None:
cmd.extend(["--first-n", str(first_n)])
if last_n is not None:
cmd.extend(["--last-n", str(last_n)])
if range_start is not None:
cmd.extend(["--range-start", str(range_start)])
if range_count is not None:
cmd.extend(["--range-count", str(range_count)])
if random_n is not None:
cmd.extend(["--random-n", str(random_n)])
if random_n_with_replacement is not None:
cmd.extend(["--random-n-with-replacement", str(random_n_with_replacement)])
if random_seed is not None:
cmd.extend(["--random-seed", str(random_seed)])
# Show standalone alternatives
show_standalone_alternatives(cmd, f"Train {model.upper()} on {dataset}" + (f"/{category}" if category else ""))
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
# Execute command
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
if distributed:
# Use torchrun for distributed training
torchrun_cmd = [
"torchrun",
f"--nproc_per_node={gpus}",
"--nnodes=1",
"--node_rank=0"
] + cmd
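        # e.g. with --gpus 2 on Linux this yields:
        #   torchrun --nproc_per_node=2 --nnodes=1 --node_rank=0 dioodmi-train --model ...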
console.print(f"[yellow]Using distributed training on {gpus} GPUs[/yellow]")
console.print(f"[dim]Command: {' '.join(torchrun_cmd)}[/dim]\n")
result = subprocess.run(torchrun_cmd)
else:
result = subprocess.run(cmd)
sys.exit(result.returncode)
@main.command('eval')
# Config file option (FIRST - allows loading all parameters from JSON)
@click.option('--config', type=click.Path(exists=True),
help='JSON config file (base + environment overrides)')
# Core parameters
@click.option('--model-path', '--checkpoint', 'model_path',
type=click.Path(exists=True),
help='Path to trained model checkpoint')
@click.option('--dataset', type=click.Choice(['mvtec', 'mvtec2', 'sdp', 'sdp_orig', 'sdp_orig_test',
'sdp_ok_ng', 'sdp_ok_ng_crop', 'sdp_ok_ng_resize_1024',
'sdp_ok_ng_resize_2048', 'visa', 'pcb', 'avp', 'iris']),
help='Dataset to evaluate')
@click.option('--object-category', '--category', 'category',
help='Dataset category')
@click.option('--anomaly-class',
help='Anomaly class filter')
@click.option('--data-dir', type=click.Path(exists=True),
help='Path to dataset directory')
@click.option('--csv-split-file',
help='Custom CSV split file (overrides dataset default)')
# Model configuration
@click.option('--model-size', type=click.Choice(['UNet_XS','UNet_S','UNet_M','UNet_L','UNet_XL']),
help='Model size')
@click.option('--vae-type', type=click.Choice(['ema', 'mse']),
help='VAE type (ema or mse)')
@click.option('--image-size', type=int,
help='Image size (integer)')
@click.option('--crop-size', type=int,
help='Crop size')
@click.option('--center-crop', is_flag=True, default=None,
              help='Enable center crop')
# Evaluation parameters
# Strategy parameters
@click.option('--augmentation-strategy',
type=click.Choice(['center_crop', 'center_crop_with_affine',
'random_crop', 'multi_stage', 'none']),
help='Augmentation strategy for evaluation')
@click.option('--image-loading-strategy',
type=click.Choice(['resize_first', 'keep_original', 'adaptive']),
help='Image loading strategy for evaluation')
@click.option('--reverse-steps', type=int,
help='Number of reverse diffusion steps')
@click.option('--batch-size', type=int,
help='Batch size for evaluation')
@click.option('--num-workers', type=int,
help='Number of DataLoader workers')
# Anomaly detection parameters
@click.option('--anomaly-threshold', type=float,
help='Threshold for anomaly detection (0-255 range)')
@click.option('--anomaly-min-area', type=int,
help='Minimum area for detected anomaly regions')
# Confusion matrix parameters
@click.option('--enable-confusion-matrix', is_flag=True, default=None,
help='Enable confusion matrix evaluation')
@click.option('--patch-size', type=int,
help='Patch size for confusion matrix')
@click.option('--anomaly-binary-threshold', type=int,
help='Binary threshold for anomaly detection')
@click.option('--anomaly-pixel-num-threshold', type=int,
help='Minimum pixel count for anomaly')
@click.option('--annotation-dir', type=click.Path(),
help='Directory containing annotation JSON files')
# Output parameters
@click.option('--save-plot-path', '--output-dir', 'output_dir',
type=click.Path(),
help='Results output directory')
@click.option('--save-evaluation-images', is_flag=True, default=None,
help='Save evaluation images in organized subfolders')
@click.option('--save-reconstructions', is_flag=True, default=None,
help='Save reconstructed images (VAE and Deco-Diff) as PNG')
@click.option('--save-reconstructions-npy', is_flag=True, default=None,
help='Save raw reconstruction arrays as NPY files')
@click.option('--save-raw-anomaly-maps', is_flag=True, default=None,
help='Save raw anomaly map arrays as NPY files (image-level)')
@click.option('--save-image-variants', type=str, multiple=True,
help='Which image variants to save (can specify multiple: continuous, binary, grayscale, absolute, all)')
@click.option('--save-colormaps', type=str, multiple=True,
help='Colormaps to use for continuous maps (can specify multiple: jet, hot, viridis, plasma, magma, inferno, turbo)')
@click.option('--save-normalization', type=click.Choice(['minmax', 'zscore', 'percentile', 'raw']),
help='Normalization method for visualization')
@click.option('--normal-center', type=float,
help='Center value for grayscale normalization (default: 125.0)')
@click.option('--save-binary-thresholds', type=float, multiple=True,
help='Threshold values for binary maps (can specify multiple, e.g., 5.0 10.0 15.0)')
@click.option('--save-intermediate-images', is_flag=True, default=None,
help='Save intermediate images (orig_decoded, recon)')
@click.option('--save-npy-files', is_flag=True, default=None,
              help='Save anomaly maps as .npy files')
@click.option('--npy-save-dir', type=click.Path(),
help='Directory to save .npy files')
# Full image evaluation
@click.option('--full-image-eval', is_flag=True, default=None,
help='Evaluate whole image by tiling/cropping and stitching')
@click.option('--tile-overlap', type=int,
help='Overlap pixels between tiles for blending')
@click.option('--tile-batch', type=int,
help='How many tiles to process at once on GPU')
# Shift augmentation parameters
@click.option('--pad-px', type=int,
help='Pixel padding for shift augmentation')
@click.option('--stride', type=int,
help='Stride for shift augmentation')
@click.option('--directions', type=str, multiple=True, callback=parse_directions,
help='Shift directions (space-separated: h v diag all rhombus, or use multiple --directions flags)')
@click.option('--shift-method', type=str,
help='Shift method (interpolate, mirror, etc.)')
@click.option('--fuse-method', type=str,
help='Fuse method (mean, pct25, etc.)')
@click.option('--img-enlarge-px', type=int,
help='Image enlargement in pixels')
# TTA fundamentals saving
@click.option('--save-tta-fundamentals', is_flag=True, default=None,
              help='Save per-shift TTA fundamentals (x0, encoded, image_samples, latent_samples)')
@click.option('--tta-cache-dir', type=str, default=None,
help='Directory to save TTA fundamentals NPY files and CSV manifest')
# Dataset sampling options
@click.option('--bootstrap-samples', type=int, default=None,
help='Bootstrap sampling: number of samples (with replacement)')
@click.option('--bootstrap-seed', type=int, default=42,
help='Random seed for bootstrap sampling')
@click.option('--first-n', type=int, default=None,
help='Take first N samples from dataset')
@click.option('--last-n', type=int, default=None,
help='Take last N samples from dataset')
@click.option('--range-start', type=int, default=None,
help='Range sampling: start index (0-based)')
@click.option('--range-count', type=int, default=10,
help='Range sampling: number of samples (default: 10)')
@click.option('--random-n', type=int, default=None,
help='Random N samples without replacement')
@click.option('--random-n-with-replacement', type=int, default=None,
help='Random N samples with replacement')
@click.option('--random-seed', type=int, default=42,
help='Random seed for random sampling methods')
# Utility
@click.option('--dry-run/--no-dry-run', default=False,
help='Show commands without executing')
def eval(config, model_path, dataset, category, anomaly_class, data_dir, csv_split_file,
model_size, vae_type, image_size, crop_size, center_crop,
augmentation_strategy, image_loading_strategy,
reverse_steps, batch_size, num_workers,
anomaly_threshold, anomaly_min_area,
enable_confusion_matrix, patch_size, anomaly_binary_threshold,
anomaly_pixel_num_threshold, annotation_dir,
output_dir, save_evaluation_images, save_reconstructions, save_reconstructions_npy,
save_raw_anomaly_maps, save_image_variants, save_colormaps, save_normalization,
normal_center, save_binary_thresholds,
save_intermediate_images, save_npy_files, npy_save_dir,
full_image_eval, tile_overlap, tile_batch,
pad_px, stride, directions, shift_method, fuse_method, img_enlarge_px,
save_tta_fundamentals, tta_cache_dir,
bootstrap_samples, bootstrap_seed, first_n, last_n, range_start, range_count,
random_n, random_n_with_replacement, random_seed,
dry_run):
"""Evaluate a trained model on test data.
Supports JSON config files with CLI parameter overrides.
Examples:
# With JSON config
hae eval --config config/eval_250923_MVTec.json
# Config + CLI overrides
hae eval --config config/eval.json --batch-size 4 --num-workers 0
# Pure CLI (no config)
hae eval --model-path models/bottle.pth --dataset mvtec --category bottle
"""
# 1. Load config if provided
config_dict = {}
if config:
from dioodmi.cli.config import load_config_with_env, detect_environment
from pathlib import Path
config_dict = load_config_with_env(Path(config))
console.print(f"[cyan]📄 Loaded config:[/cyan] {config}")
console.print(f" Environment: {detect_environment()}")
# 2. Build parameter dictionary from CLI arguments
cli_params = {
'model_path': model_path,
'dataset': dataset,
'object_category': category,
'anomaly_class': anomaly_class,
'data_dir': data_dir,
'csv_split_file': csv_split_file,
'model_size': model_size,
'vae_type': vae_type,
'image_size': image_size,
'crop_size': crop_size,
'center_crop': center_crop,
'augmentation_strategy': augmentation_strategy,
'image_loading_strategy': image_loading_strategy,
'reverse_steps': reverse_steps,
'batch_size': batch_size,
'num_workers': num_workers,
'anomaly_threshold': anomaly_threshold,
'anomaly_min_area': anomaly_min_area,
'enable_confusion_matrix': enable_confusion_matrix,
'patch_size': patch_size,
'anomaly_binary_threshold': anomaly_binary_threshold,
'anomaly_pixel_num_threshold': anomaly_pixel_num_threshold,
'annotation_dir': annotation_dir,
'save_plot_path': output_dir,
'save_evaluation_images': save_evaluation_images,
'save_reconstructions': save_reconstructions,
'save_reconstructions_npy': save_reconstructions_npy,
'save_raw_anomaly_maps': save_raw_anomaly_maps,
'save_image_variants': list(save_image_variants) if save_image_variants else None,
'save_colormaps': list(save_colormaps) if save_colormaps else None,
'save_normalization': save_normalization,
'normal_center': normal_center,
'save_binary_thresholds': list(save_binary_thresholds) if save_binary_thresholds else None,
'save_intermediate_images': save_intermediate_images,
'save_npy_files': save_npy_files,
'npy_save_dir': npy_save_dir,
'full_image_eval': full_image_eval,
'tile_overlap': tile_overlap,
'tile_batch': tile_batch,
'pad_px': pad_px,
'stride': stride,
'directions': directions,
'shift_method': shift_method,
'fuse_method': fuse_method,
'img_enlarge_px': img_enlarge_px,
'save_tta_fundamentals': save_tta_fundamentals,
'tta_cache_dir': tta_cache_dir,
# Sampling parameters
'bootstrap_samples': bootstrap_samples,
'bootstrap_seed': bootstrap_seed,
'first_n': first_n,
'last_n': last_n,
'range_start': range_start,
'range_count': range_count,
'random_n': random_n,
'random_n_with_replacement': random_n_with_replacement,
'random_seed': random_seed,
}
# 3. Merge: config values first, then override with CLI args
final_params = config_dict.copy()
for key, value in cli_params.items():
if value is not None: # CLI arg was explicitly provided
final_params[key] = value
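    #    Caveat: options declared with a non-None default (e.g. --bootstrap-seed 42)
    #    always look "explicitly provided" here and therefore shadow config values;
    #    only None-default options fall through to the config file.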
# 4. Validate required parameters
if 'model_path' not in final_params or final_params['model_path'] is None:
console.print("[red]❌ Error: --model-path required (or specify in config)[/red]")
sys.exit(1)
# Display evaluation info
eval_dataset = final_params.get('dataset', 'mvtec')
eval_category = final_params.get('object_category', '')
console.print(f"[bold blue]📊 Evaluating model on {eval_dataset}" +
(f"/{eval_category}" if eval_category else "") + "[/bold blue]\n")
# 5. Build platform-specific command
if sys.platform.startswith('win'):
cmd = ["py", "-3.11", "-m", "dioodmi.cli.eval"]
else:
cmd = ["dioodmi-eval"]
# 6. Add all parameters to command
cmd.extend(["--model_path", str(final_params['model_path'])])
# Dataset parameters
if 'dataset' in final_params:
cmd.extend(["--dataset", final_params['dataset']])
if 'object_category' in final_params:
cmd.extend(["--object_category", final_params['object_category']])
if 'anomaly_class' in final_params:
cmd.extend(["--anomaly_class", final_params['anomaly_class']])
if 'data_dir' in final_params:
cmd.extend(["--data_dir", final_params['data_dir']])
if 'csv_split_file' in final_params:
cmd.extend(["--csv_split_file", final_params['csv_split_file']])
# Model configuration
if 'model_size' in final_params:
cmd.extend(["--model_size", final_params['model_size']])
if 'vae_type' in final_params:
cmd.extend(["--vae_type", final_params['vae_type']])
if 'image_size' in final_params:
cmd.extend(["--image_size", str(final_params['image_size'])])
if 'crop_size' in final_params:
cmd.extend(["--crop_size", str(final_params['crop_size'])])
# Boolean flags (convert to true/false strings for dioodmi.cli.eval)
if 'center_crop' in final_params and final_params['center_crop']:
cmd.extend(["--center_crop", "true"])
# Strategy parameters
if 'augmentation_strategy' in final_params and final_params['augmentation_strategy']:
cmd.extend(["--augmentation_strategy", final_params['augmentation_strategy']])
if 'image_loading_strategy' in final_params and final_params['image_loading_strategy']:
cmd.extend(["--image_loading_strategy", final_params['image_loading_strategy']])
if 'enable_confusion_matrix' in final_params and final_params['enable_confusion_matrix'] is not None:
cmd.extend(["--enable_confusion_matrix",
"true" if final_params['enable_confusion_matrix'] else "false"])
if 'save_evaluation_images' in final_params and final_params['save_evaluation_images'] is not None:
cmd.extend(["--save_evaluation_images",
"true" if final_params['save_evaluation_images'] else "false"])
if 'save_reconstructions' in final_params and final_params['save_reconstructions'] is not None:
cmd.extend(["--save_reconstructions",
"true" if final_params['save_reconstructions'] else "false"])
if 'save_reconstructions_npy' in final_params and final_params['save_reconstructions_npy'] is not None:
cmd.extend(["--save_reconstructions_npy",
"true" if final_params['save_reconstructions_npy'] else "false"])
if 'save_raw_anomaly_maps' in final_params and final_params['save_raw_anomaly_maps'] is not None:
cmd.extend(["--save_raw_anomaly_maps",
"true" if final_params['save_raw_anomaly_maps'] else "false"])
if 'save_image_variants' in final_params and final_params['save_image_variants']:
cmd.extend(["--save-image-variants"] + [str(v) for v in final_params['save_image_variants']])
if 'save_colormaps' in final_params and final_params['save_colormaps']:
cmd.extend(["--save-colormaps"] + [str(v) for v in final_params['save_colormaps']])
if 'save_normalization' in final_params and final_params['save_normalization']:
cmd.extend(["--save-normalization", str(final_params['save_normalization'])])
if 'normal_center' in final_params and final_params['normal_center'] is not None:
cmd.extend(["--normal-center", str(final_params['normal_center'])])
if 'save_binary_thresholds' in final_params and final_params['save_binary_thresholds']:
cmd.extend(["--save-binary-thresholds"] + [str(v) for v in final_params['save_binary_thresholds']])
if 'save_intermediate_images' in final_params and final_params['save_intermediate_images'] is not None:
cmd.extend(["--save_intermediate_images",
"true" if final_params['save_intermediate_images'] else "false"])
if 'save_npy_files' in final_params and final_params['save_npy_files']:
cmd.extend(["--save_npy_files", "true"])
if 'full_image_eval' in final_params and final_params['full_image_eval'] is not None:
cmd.extend(["--full_image_eval",
"true" if final_params['full_image_eval'] else "false"])
# Evaluation parameters
if 'reverse_steps' in final_params:
cmd.extend(["--reverse_steps", str(final_params['reverse_steps'])])
if 'batch_size' in final_params:
cmd.extend(["--batch_size", str(final_params['batch_size'])])
if 'num_workers' in final_params:
cmd.extend(["--num_workers", str(final_params['num_workers'])])
# Anomaly detection parameters
if 'anomaly_threshold' in final_params:
cmd.extend(["--anomaly_threshold", str(final_params['anomaly_threshold'])])
if 'anomaly_min_area' in final_params:
cmd.extend(["--anomaly_min_area", str(final_params['anomaly_min_area'])])
# Confusion matrix parameters
if 'patch_size' in final_params:
cmd.extend(["--patch_size", str(final_params['patch_size'])])
if 'anomaly_binary_threshold' in final_params:
cmd.extend(["--anomaly_binary_threshold", str(final_params['anomaly_binary_threshold'])])
if 'anomaly_pixel_num_threshold' in final_params:
cmd.extend(["--anomaly_pixel_num_threshold", str(final_params['anomaly_pixel_num_threshold'])])
if 'annotation_dir' in final_params:
cmd.extend(["--annotation_dir", final_params['annotation_dir']])
# Output parameters
if 'save_plot_path' in final_params:
cmd.extend(["--save_plot_path", final_params['save_plot_path']])
if 'npy_save_dir' in final_params:
cmd.extend(["--npy_save_dir", final_params['npy_save_dir']])
# Full image evaluation parameters
if 'tile_overlap' in final_params:
cmd.extend(["--tile_overlap", str(final_params['tile_overlap'])])
if 'tile_batch' in final_params:
cmd.extend(["--tile_batch", str(final_params['tile_batch'])])
    # Shift augmentation parameters (pad_px is forwarded only when set and enables TTA;
    # the remaining shift options are forwarded whenever present)
if 'pad_px' in final_params and final_params['pad_px'] is not None:
cmd.extend(["--pad_px", str(final_params['pad_px'])])
if 'img_enlarge_px' in final_params:
cmd.extend(["--img_enlarge_px", str(final_params['img_enlarge_px'])])
if 'stride' in final_params:
cmd.extend(["--stride", str(final_params['stride'])])
if 'directions' in final_params and final_params['directions']:
# dioodmi-eval uses argparse nargs="+" so it expects: --directions h v diag
directions = final_params['directions']
if isinstance(directions, (tuple, list)):
cmd.extend(["--directions"] + list(directions))
else:
# Handle string case (from config file): "h v diag" -> ["h", "v", "diag"]
cmd.extend(["--directions"] + directions.split())
if 'shift_method' in final_params:
cmd.extend(["--shift_method", final_params['shift_method']])
if 'fuse_method' in final_params:
cmd.extend(["--fuse_method", final_params['fuse_method']])
# TTA fundamentals saving
if 'save_tta_fundamentals' in final_params and final_params['save_tta_fundamentals']:
cmd.append("--save_tta_fundamentals")
if 'tta_cache_dir' in final_params and final_params['tta_cache_dir']:
cmd.extend(["--tta_cache_dir", str(final_params['tta_cache_dir'])])
# Dataset sampling parameters
if 'bootstrap_samples' in final_params and final_params['bootstrap_samples'] is not None:
cmd.extend(["--bootstrap_samples", str(final_params['bootstrap_samples'])])
if 'bootstrap_seed' in final_params:
cmd.extend(["--bootstrap_seed", str(final_params['bootstrap_seed'])])
if 'first_n' in final_params and final_params['first_n'] is not None:
cmd.extend(["--first_n", str(final_params['first_n'])])
if 'last_n' in final_params and final_params['last_n'] is not None:
cmd.extend(["--last_n", str(final_params['last_n'])])
if 'range_start' in final_params and final_params['range_start'] is not None:
cmd.extend(["--range_start", str(final_params['range_start'])])
if 'range_count' in final_params:
cmd.extend(["--range_count", str(final_params['range_count'])])
if 'random_n' in final_params and final_params['random_n'] is not None:
cmd.extend(["--random_n", str(final_params['random_n'])])
if 'random_n_with_replacement' in final_params and final_params['random_n_with_replacement'] is not None:
cmd.extend(["--random_n_with_replacement", str(final_params['random_n_with_replacement'])])
if 'random_seed' in final_params:
cmd.extend(["--random_seed", str(final_params['random_seed'])])
# 7. Show standalone alternative
show_standalone_alternatives(cmd, f"Evaluate on {eval_dataset}" +
(f"/{eval_category}" if eval_category else ""))
# 8. Execute
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@main.command('anomaly-detect')
# Config file option (FIRST - allows loading all parameters from JSON)
@click.option('--config', type=click.Path(exists=True),
help='JSON config file (base + environment overrides)')
# Mode selection
@click.option('--mode', type=click.Choice(['detect_and_evaluate_without_cache', 'detect_and_evaluate',
'detect_only', 'evaluate_only']),
default='detect_and_evaluate',
help='Execution mode (default: detect_and_evaluate)')
@click.option('--npy-cache-dir', type=click.Path(),
help='Directory for NPY cache files (defaults to {save_plot_path}/npy_cache)')
# Core parameters
@click.option('--model-path', '--checkpoint', 'model_path',
type=click.Path(exists=True),
help='Path to trained model checkpoint')
@click.option('--dataset', type=click.Choice(['mvtec', 'mvtec2', 'sdp', 'sdp_orig', 'sdp_orig_test',
'sdp_ok_ng', 'sdp_ok_ng_crop', 'sdp_ok_ng_resize_1024',
'sdp_ok_ng_resize_2048', 'visa', 'pcb', 'avp', 'iris']),
default='mvtec',
help='Dataset to evaluate')
@click.option('--object-category', '--category', 'category',
default='all',
help='Dataset category')
@click.option('--anomaly-class',
default='all',
help='Anomaly class filter')
@click.option('--data-dir', type=click.Path(exists=True),
help='Path to dataset directory')
@click.option('--csv-split-file',
help='Custom CSV split file (overrides dataset default)')
# Model configuration
@click.option('--model-size', type=click.Choice(['UNet_XS','UNet_S','UNet_M','UNet_L','UNet_XL']),
default='UNet_L',
help='Model size')
@click.option('--vae-type', type=click.Choice(['ema', 'mse']),
default='ema',
help='VAE type (ema or mse)')
@click.option('--image-size', type=str,
help='Image size (integer or "None" for dataset default)')
@click.option('--crop-size', type=int, default=256,
help='Crop size')
@click.option('--center-crop/--no-center-crop', default=True,
              help='Enable center crop (disable with --no-center-crop)')
# Strategy parameters
@click.option('--augmentation-strategy',
type=click.Choice(['center_crop', 'center_crop_with_affine',
'random_crop', 'multi_stage', 'none']),
help='Augmentation strategy for evaluation')
@click.option('--image-loading-strategy',
type=click.Choice(['resize_first', 'keep_original', 'adaptive']),
help='Image loading strategy (preprocessing before augmentation)')
# Evaluation parameters
@click.option('--reverse-steps', type=int, default=5,
help='Number of reverse diffusion steps')
@click.option('--batch-size', type=int, default=1,
help='DataLoader batch size (process B images together)')
@click.option('--num-workers', type=int, default=4,
help='Number of DataLoader workers')
# TTA parameters (batched version - AnomalyDetector specific)
@click.option('--pad-px', type=int, default=None,
help='Shift range for TTA (±pad_px pixels). Set to enable TTA, None to disable')
@click.option('--img-enlarge-px', type=int, default=4,
help='Image enlargement for interpolation method')
@click.option('--stride', type=int, default=1,
help='Shift increment (1=every pixel, 2=skip pixels)')
@click.option('--directions', multiple=True, callback=parse_directions,
help='Shift directions (space-separated: h v diag all rhombus, or use multiple --directions flags)')
@click.option('--shift-method', type=click.Choice(['interpolate', 'mirror']),
default='interpolate',
help='Shift augmentation method')
@click.option('--fuse-method', type=click.Choice(['mean', 'median', 'lowest', 'pct25', 'pct75']),
default='mean',
help='Method to fuse multiple shifted predictions')
@click.option('--tta-batch-size', type=int, default=8,
help='Batch size for TTA shift processing (combined with DataLoader batch)')
# Anomaly detection parameters
@click.option('--anomaly-threshold', type=float, default=20,
help='Threshold for anomaly detection (0-255 range, for full_image_eval)')
@click.option('--anomaly-min-area', type=int, default=10,
help='Minimum area for detected anomaly regions (for full_image_eval)')
# Output parameters
@click.option('--save-plot-path', '--output-dir', 'output_dir',
type=click.Path(),
help='Results output directory')
@click.option('--save-evaluation-images', type=str, default=None,
              help='Save evaluation images in organized subfolders (true/false)')
@click.option('--annotation-dir', type=click.Path(),
help='Directory containing annotation JSON files for TP/FP/FN/TN classification')
@click.option('--save-reconstructions', is_flag=True, default=None,
help='Save reconstructed images (VAE and Deco-Diff) as PNG')
@click.option('--save-reconstructions-npy', is_flag=True, default=None,
help='Save raw reconstruction arrays as NPY files')
@click.option('--save-raw-anomaly-maps', is_flag=True, default=None,
help='Save raw anomaly map arrays as NPY files (image-level)')
@click.option('--save-image-variants', type=str, multiple=True,
help='Which image variants to save (can specify multiple: continuous, binary, grayscale, absolute, all)')
@click.option('--save-colormaps', type=str, multiple=True,
help='Colormaps to use for continuous maps (can specify multiple: jet, hot, viridis, plasma, magma, inferno, turbo)')
@click.option('--save-normalization', type=click.Choice(['minmax', 'zscore', 'percentile', 'raw']),
help='Normalization method for visualization')
@click.option('--normal-center', type=float,
help='Center value for grayscale normalization (default: 125.0)')
@click.option('--save-binary-thresholds', type=float, multiple=True,
help='Threshold values for binary maps (can specify multiple, e.g., 5.0 10.0 15.0)')
@click.option('--checkpoint-enabled', is_flag=True, default=None,
              help='Enable checkpoint functionality for image saving')
# NPY file saving (for regression testing)
@click.option('--save-npy-files', is_flag=True, default=None,
              help='Save anomaly maps as .npy files (for regression testing)')
@click.option('--npy-save-dir', type=click.Path(),
help='Directory to save .npy files')
# Filename generation strategy
@click.option('--filename-strategy', type=click.Choice(['auto', 'simple', 'anomaly_type', 'full_path',
'incremental', 'dataset_index']),
default='auto',
help='Filename generation strategy')
@click.option('--regression-test-mode', is_flag=True, default=None,
              help='Enable regression test mode (forces simple filename strategy)')
# Dataset sampling options
@click.option('--bootstrap-samples', type=int, default=None,
help='Bootstrap sampling: number of samples (with replacement)')
@click.option('--bootstrap-seed', type=int, default=42,
help='Random seed for bootstrap sampling')
@click.option('--first-n', type=int, default=None,
help='Take first N samples from dataset')
@click.option('--last-n', type=int, default=None,
help='Take last N samples from dataset')
@click.option('--range-start', type=int, default=None,
help='Range sampling: start index (0-based)')
@click.option('--range-count', type=int, default=10,
help='Range sampling: number of samples (default: 10)')
@click.option('--random-n', type=int, default=None,
help='Random N samples without replacement')
@click.option('--random-n-with-replacement', type=int, default=None,
help='Random N samples with replacement')
@click.option('--random-seed', type=int, default=42,
help='Random seed for random sampling methods')
# I/O optimization parameters
@click.option('--use_consolidated_npz', is_flag=True, default=None,
              help='Use single compressed .npz file per image instead of multiple .npy files (4-6× faster I/O)')
# Utility
@click.option('--dry-run/--no-dry-run', default=False,
help='Show commands without executing')
def anomaly_detect(config, mode, npy_cache_dir, model_path, dataset, category, anomaly_class, data_dir, csv_split_file,
model_size, vae_type, image_size, crop_size, center_crop,
augmentation_strategy, image_loading_strategy,
reverse_steps, batch_size, num_workers,
pad_px, img_enlarge_px, stride, directions, shift_method, fuse_method, tta_batch_size,
anomaly_threshold, anomaly_min_area,
output_dir, save_evaluation_images, annotation_dir, save_reconstructions, save_reconstructions_npy,
save_raw_anomaly_maps, save_image_variants, save_colormaps, save_normalization,
normal_center, save_binary_thresholds, checkpoint_enabled,
save_npy_files, npy_save_dir,
filename_strategy, regression_test_mode,
bootstrap_samples, bootstrap_seed, first_n, last_n, range_start, range_count,
random_n, random_n_with_replacement, random_seed,
use_consolidated_npz,
dry_run):
"""Run anomaly detection using modern AnomalyDetector (consolidated evaluator).
Uses the modern AnomalyDetector implementation which consolidates DecodiffEvaluator
and DecodiffEvaluateProcessor with efficient batched TTA (~6.6× speedup).
Features:
- Batched TTA for ~6.6× speedup vs sequential TTA
- Configurable shifts: --pad-px, --stride, --directions
- Clean architecture without legacy code
Examples:
# Basic evaluation
hae anomaly-detect --model-path models/bottle.pth --dataset mvtec --category bottle
# With TTA enabled
hae anomaly-detect --model-path models/bottle.pth --pad-px 4 --tta-batch-size 8
# With config file
hae anomaly-detect --config config/eval.json --pad-px 4
"""
# 1. Load config if provided
config_dict = {}
if config:
from dioodmi.cli.config import load_config_with_env, detect_environment
from pathlib import Path
config_dict = load_config_with_env(Path(config))
console.print(f"[cyan]📄 Loaded config:[/cyan] {config}")
console.print(f" Environment: {detect_environment()}")
# 2. Build parameter dictionary from CLI arguments
cli_params = {
'mode': mode,
'npy_cache_dir': npy_cache_dir,
'model_path': model_path,
'dataset': dataset,
'object_category': category,
'anomaly_class': anomaly_class,
'data_dir': data_dir,
'csv_split_file': csv_split_file,
'model_size': model_size,
'vae_type': vae_type,
'image_size': image_size,
'crop_size': crop_size,
'center_crop': center_crop,
'augmentation_strategy': augmentation_strategy,
'image_loading_strategy': image_loading_strategy,
'reverse_steps': reverse_steps,
'batch_size': batch_size,
'num_workers': num_workers,
'pad_px': pad_px,
'img_enlarge_px': img_enlarge_px,
'stride': stride,
'directions': directions,
'shift_method': shift_method,
'fuse_method': fuse_method,
'tta_batch_size': tta_batch_size,
'anomaly_threshold': anomaly_threshold,
'anomaly_min_area': anomaly_min_area,
'save_plot_path': output_dir,
'save_evaluation_images': save_evaluation_images,
'annotation_dir': annotation_dir,
'save_reconstructions': save_reconstructions,
'save_reconstructions_npy': save_reconstructions_npy,
'save_raw_anomaly_maps': save_raw_anomaly_maps,
'save_image_variants': list(save_image_variants) if save_image_variants else None,
'save_colormaps': list(save_colormaps) if save_colormaps else None,
'save_normalization': save_normalization,
'normal_center': normal_center,
'save_binary_thresholds': list(save_binary_thresholds) if save_binary_thresholds else None,
'checkpoint_enabled': checkpoint_enabled,
'save_npy_files': save_npy_files,
'npy_save_dir': npy_save_dir,
'filename_strategy': filename_strategy,
'regression_test_mode': regression_test_mode,
# Sampling parameters
'bootstrap_samples': bootstrap_samples,
'bootstrap_seed': bootstrap_seed,
'first_n': first_n,
'last_n': last_n,
'range_start': range_start,
'range_count': range_count,
'random_n': random_n,
'random_n_with_replacement': random_n_with_replacement,
'random_seed': random_seed,
'use_consolidated_npz': use_consolidated_npz,
}
# 3. Merge: config values first, then override with CLI args
final_params = config_dict.copy()
for key, value in cli_params.items():
if value is not None: # CLI arg was explicitly provided
final_params[key] = value
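    # Note: as in eval(), options with non-None defaults shadow config values here.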
# 4. Validate required parameters
if 'model_path' not in final_params or final_params['model_path'] is None:
console.print("[red]❌ Error: --model-path required (or specify in config)[/red]")
sys.exit(1)
# Display evaluation info
eval_dataset = final_params.get('dataset', 'mvtec')
eval_category = final_params.get('object_category', 'all')
console.print(f"[bold green]🔍 Anomaly Detection (Modern Version)[/bold green]")
console.print(f"[bold]Dataset:[/bold] {eval_dataset}" +
(f"/{eval_category}" if eval_category != 'all' else ""))
console.print(f"[bold]Model:[/bold] {final_params.get('model_size', 'UNet_L')}\n")
# Show TTA status
if final_params.get('pad_px') is not None:
directions_display = final_params.get('directions', ('h', 'v'))
if isinstance(directions_display, tuple):
directions_display = ' '.join(directions_display)
console.print(f"[cyan]✓ TTA Enabled:[/cyan]")
console.print(f" Shift range: ±{final_params['pad_px']} pixels")
console.print(f" Stride: {final_params.get('stride', 1)}")
console.print(f" Directions: {directions_display}")
console.print(f" TTA batch size: {final_params.get('tta_batch_size', 8)}")
console.print(f" Expected speedup: ~6.6× vs sequential TTA\n")
else:
console.print(f"[yellow]✗ TTA Disabled (use --pad-px to enable)[/yellow]\n")
# 5. Build platform-specific command
if sys.platform.startswith('win'):
cmd = ["py", "-3.11", "-m", "dioodmi.cli.ad"]
else:
cmd = ["dioodmi-ad"]
# 6. Add all parameters to command
    # Mode selection (forwarded first)
if 'mode' in final_params:
cmd.extend(["--mode", final_params['mode']])
if 'npy_cache_dir' in final_params and final_params['npy_cache_dir']:
cmd.extend(["--npy_cache_dir", final_params['npy_cache_dir']])
cmd.extend(["--model_path", str(final_params['model_path'])])
# Dataset parameters
if 'dataset' in final_params:
cmd.extend(["--dataset", final_params['dataset']])
if 'object_category' in final_params:
cmd.extend(["--object_category", final_params['object_category']])
if 'anomaly_class' in final_params:
cmd.extend(["--anomaly_class", final_params['anomaly_class']])
if 'data_dir' in final_params:
cmd.extend(["--data_dir", final_params['data_dir']])
if 'csv_split_file' in final_params:
cmd.extend(["--csv_split_file", final_params['csv_split_file']])
# Model configuration
if 'model_size' in final_params:
cmd.extend(["--model_size", final_params['model_size']])
if 'vae_type' in final_params:
cmd.extend(["--vae_type", final_params['vae_type']])
if 'image_size' in final_params:
if final_params['image_size'] is None:
cmd.extend(["--image_size", "None"])
else:
cmd.extend(["--image_size", str(final_params['image_size'])])
if 'crop_size' in final_params:
cmd.extend(["--crop_size", str(final_params['crop_size'])])
if 'center_crop' in final_params:
cmd.extend(["--center_crop", "true" if final_params['center_crop'] else "false"])
# Strategy parameters
if 'augmentation_strategy' in final_params and final_params['augmentation_strategy']:
cmd.extend(["--augmentation_strategy", final_params['augmentation_strategy']])
if 'image_loading_strategy' in final_params and final_params['image_loading_strategy']:
cmd.extend(["--image_loading_strategy", final_params['image_loading_strategy']])
# Evaluation parameters
if 'reverse_steps' in final_params:
cmd.extend(["--reverse_steps", str(final_params['reverse_steps'])])
if 'batch_size' in final_params:
cmd.extend(["--batch_size", str(final_params['batch_size'])])
if 'num_workers' in final_params:
cmd.extend(["--num_workers", str(final_params['num_workers'])])
    # TTA parameters (pad_px is forwarded only when set and enables TTA;
    # the remaining TTA options are forwarded whenever present)
if 'pad_px' in final_params and final_params['pad_px'] is not None:
cmd.extend(["--pad_px", str(final_params['pad_px'])])
if 'img_enlarge_px' in final_params:
cmd.extend(["--img_enlarge_px", str(final_params['img_enlarge_px'])])
if 'stride' in final_params:
cmd.extend(["--stride", str(final_params['stride'])])
if 'directions' in final_params and final_params['directions']:
# dioodmi-ad uses argparse nargs="+" so it expects: --directions h v diag
directions = final_params['directions']
if isinstance(directions, (tuple, list)):
cmd.extend(["--directions"] + list(directions))
else:
# Handle string case (from config file): "h v diag" -> ["h", "v", "diag"]
cmd.extend(["--directions"] + directions.split())
if 'shift_method' in final_params:
cmd.extend(["--shift_method", final_params['shift_method']])
if 'fuse_method' in final_params:
cmd.extend(["--fuse_method", final_params['fuse_method']])
if 'tta_batch_size' in final_params:
cmd.extend(["--tta_batch_size", str(final_params['tta_batch_size'])])
# Anomaly detection parameters
if 'anomaly_threshold' in final_params and final_params['anomaly_threshold'] is not None:
cmd.extend(["--anomaly_threshold", str(final_params['anomaly_threshold'])])
if 'anomaly_min_area' in final_params and final_params['anomaly_min_area'] is not None:
cmd.extend(["--anomaly_min_area", str(final_params['anomaly_min_area'])])
# Output parameters
if 'save_plot_path' in final_params:
cmd.extend(["--save_plot_path", final_params['save_plot_path']])
if 'save_evaluation_images' in final_params:
# Convert to boolean and pass as string value
save_eval_bool = str(final_params['save_evaluation_images']).lower() in ('true', '1', 'yes')
cmd.extend(["--save_evaluation_images", "true" if save_eval_bool else "false"])
if 'annotation_dir' in final_params and final_params['annotation_dir'] is not None:
cmd.extend(["--annotation_dir", final_params['annotation_dir']])
if 'save_reconstructions' in final_params and final_params['save_reconstructions'] is not None:
cmd.extend(["--save_reconstructions", "true" if final_params['save_reconstructions'] else "false"])
if 'save_reconstructions_npy' in final_params and final_params['save_reconstructions_npy'] is not None:
cmd.extend(["--save_reconstructions_npy", "true" if final_params['save_reconstructions_npy'] else "false"])
if 'save_raw_anomaly_maps' in final_params and final_params['save_raw_anomaly_maps'] is not None:
cmd.extend(["--save_raw_anomaly_maps", "true" if final_params['save_raw_anomaly_maps'] else "false"])
if 'save_image_variants' in final_params and final_params['save_image_variants']:
cmd.extend(["--save-image-variants"] + [str(v) for v in final_params['save_image_variants']])
if 'save_colormaps' in final_params and final_params['save_colormaps']:
cmd.extend(["--save-colormaps"] + [str(v) for v in final_params['save_colormaps']])
if 'save_normalization' in final_params and final_params['save_normalization']:
cmd.extend(["--save-normalization", str(final_params['save_normalization'])])
if 'normal_center' in final_params and final_params['normal_center'] is not None:
cmd.extend(["--normal-center", str(final_params['normal_center'])])
if 'save_binary_thresholds' in final_params and final_params['save_binary_thresholds']:
cmd.extend(["--save-binary-thresholds"] + [str(v) for v in final_params['save_binary_thresholds']])
if 'checkpoint_enabled' in final_params and final_params['checkpoint_enabled'] is not None:
cmd.extend(["--checkpoint_enabled", "true" if final_params['checkpoint_enabled'] else "false"])
# NPY file saving (for regression testing)
if 'save_npy_files' in final_params and final_params['save_npy_files']:
cmd.extend(["--save_npy_files", "true"])
if 'npy_save_dir' in final_params and final_params['npy_save_dir']:
cmd.extend(["--npy_save_dir", final_params['npy_save_dir']])
# Filename strategy
if 'filename_strategy' in final_params:
cmd.extend(["--filename_strategy", final_params['filename_strategy']])
if 'regression_test_mode' in final_params and final_params['regression_test_mode']:
cmd.append("--regression_test_mode")
# Sampling parameters
if 'bootstrap_samples' in final_params and final_params['bootstrap_samples'] is not None:
cmd.extend(["--bootstrap-samples", str(final_params['bootstrap_samples'])])
if 'bootstrap_seed' in final_params:
cmd.extend(["--bootstrap-seed", str(final_params['bootstrap_seed'])])
if 'first_n' in final_params and final_params['first_n'] is not None:
cmd.extend(["--first-n", str(final_params['first_n'])])
if 'last_n' in final_params and final_params['last_n'] is not None:
cmd.extend(["--last-n", str(final_params['last_n'])])
if 'range_start' in final_params and final_params['range_start'] is not None:
cmd.extend(["--range-start", str(final_params['range_start'])])
if 'range_count' in final_params:
cmd.extend(["--range-count", str(final_params['range_count'])])
    if 'random_n' in final_params and final_params['random_n'] is not None:
        cmd.extend(["--random-n", str(final_params['random_n'])])
    if 'random_n_with_replacement' in final_params and final_params['random_n_with_replacement'] is not None:
        cmd.extend(["--random-n-with-replacement", str(final_params['random_n_with_replacement'])])
    if 'random_seed' in final_params:
        cmd.extend(["--random-seed", str(final_params['random_seed'])])
# I/O optimization parameters
if 'use_consolidated_npz' in final_params and final_params['use_consolidated_npz']:
cmd.extend(["--use_consolidated_npz", "true"])
# 7. Show standalone alternative
show_standalone_alternatives(cmd, f"Anomaly Detection on {eval_dataset}" +
(f"/{eval_category}" if eval_category != 'all' else ""))
# 8. Execute
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@main.command('eval-process')
# Config file option (FIRST - allows loading all parameters from JSON)
@click.option('--config', type=click.Path(exists=True),
help='JSON config file (base + environment overrides)')
# Mode selection (REQUIRED)
@click.option('--mode', type=click.Choice(['save_only', 'process_only', 'save_and_process',
'full_pipeline', 'full_pipeline_with_saving_npy',
'save_and_process_with_bootstrap']),
help='Execution mode (REQUIRED unless specified in config)')
# Core parameters
@click.option('--pretrained', '--model-path', 'pretrained',
type=click.Path(exists=True),
help='Path to trained model checkpoint (REQUIRED for save modes)')
@click.option('--annotation-dir', type=click.Path(),
help='Directory containing annotation files (REQUIRED)')
@click.option('--dataset', type=click.Choice(['mvtec', 'mvtec2', 'sdp', 'sdp_orig', 'sdp_orig_test',
'sdp_ok_ng', 'sdp_ok_ng_crop', 'sdp_ok_ng_resize_1024',
'sdp_ok_ng_resize_2048', 'visa', 'pcb', 'avp', 'iris']),
default='mvtec',
help='Dataset to evaluate')
@click.option('--object-class', default='all',
help='Object class to process')
@click.option('--anomaly-class',
help='Anomaly class filter')
@click.option('--data-dir', type=click.Path(exists=True),
help='Path to dataset directory')
@click.option('--csv-split-file',
help='Custom CSV split file (overrides dataset default)')
# Model configuration
@click.option('--model-size', type=click.Choice(['UNet_XS','UNet_S','UNet_M','UNet_L','UNet_XL']),
default='UNet_L',
help='Model size')
@click.option('--vae-type', type=click.Choice(['ema', 'mse']),
default='ema',
help='VAE type (ema or mse)')
@click.option('--image-size', type=int,
help='Image size (integer)')
@click.option('--center-size', type=int,
help='Center size')
@click.option('--center-crop', type=str,
help='Center crop (string)')
# Processing parameters
@click.option('--patch-size', type=int, default=128,
help='Patch size for image processing')
@click.option('--stride', type=int,
help='Stride for patch extraction (None = no overlap)')
@click.option('--irregular-patch', is_flag=True, default=None,
              help='Use irregular patch for image processing')
@click.option('--reverse-steps', type=int, default=5,
help='Number of reverse diffusion steps')
@click.option('--batch-size', type=int, default=64,
help='Batch size for processing')
@click.option('--batch-num', type=int, default=12,
help='Number of batches to process')
@click.option('--num-workers', type=int, default=0,
help='Number of DataLoader workers')
@click.option('--split', type=str, default='test',
help='Data split to process')
# Anomaly detection parameters
@click.option('--anomaly-binary-threshold', type=int, default=5,
help='Binary threshold for anomaly detection')
@click.option('--anomaly-pixel-num-threshold', type=int, default=0,
help='Pixel number threshold')
@click.option('--adaptive-threshold', type=float, default=0.1,
help='Adaptive threshold for contour-based masks')
# Output parameters
@click.option('--results-dir', type=click.Path(),
help='Results directory (default: auto-generated)')
@click.option('--tag', type=str,
help='Custom tag for output directory')
@click.option('--enable-excel-report', is_flag=True, default=None,
              help='Generate Excel report')
@click.option('--enable-save-image-results', is_flag=True, default=None,
              help='Save image results')
@click.option('--enable-save-whole-image-results', is_flag=True, default=None,
              help='Save whole image results')
@click.option('--save-reconstructions', is_flag=True, default=None,
help='Save reconstructed images (VAE and Deco-Diff) as PNG')
@click.option('--save-reconstructions-npy', is_flag=True, default=None,
help='Save raw reconstruction arrays as NPY files')
@click.option('--save-raw-anomaly-maps', is_flag=True, default=None,
help='Save raw anomaly map arrays as NPY files (image-level)')
@click.option('--save-image-variants', type=str, multiple=True,
help='Which image variants to save (can specify multiple: continuous, binary, grayscale, absolute, all)')
@click.option('--save-colormaps', type=str, multiple=True,
help='Colormaps to use for continuous maps (can specify multiple: jet, hot, viridis, plasma, magma, inferno, turbo)')
@click.option('--save-normalization', type=click.Choice(['minmax', 'zscore', 'percentile', 'raw']),
help='Normalization method for visualization')
@click.option('--normal-center', type=float,
help='Center value for grayscale normalization (default: 125.0)')
@click.option('--save-binary-thresholds', type=float, multiple=True,
help='Threshold values for binary maps (can specify multiple, e.g., 5.0 10.0 15.0)')
@click.option('--enable-confusion-matrix', is_flag=True, default=None,
              help='Create confusion matrix')
# Filename generation strategy
@click.option('--filename-strategy', type=click.Choice(['auto', 'simple', 'anomaly_type', 'full_path',
'incremental', 'dataset_index']),
default='auto',
help='Filename generation strategy')
@click.option('--regression-test-mode', is_flag=True, default=None,
              help='Enable regression test mode')
# Bootstrap parameters
@click.option('--bootstrap-samples', type=int,
help='Enable bootstrap sampling: save only N randomly sampled patches')
@click.option('--bootstrap-seed', type=int, default=42,
help='Random seed for bootstrap sampling')
@click.option('--bootstrap-sampling-trial-number', type=int, default=1,
help='Number of bootstrap resampling trials')
# Utility
@click.option('--dry-run/--no-dry-run', default=False,
help='Show commands without executing')
def eval_process(config, mode, pretrained, annotation_dir, dataset, object_class, anomaly_class,
data_dir, csv_split_file, model_size, vae_type, image_size, center_size,
center_crop, patch_size, stride, irregular_patch, reverse_steps, batch_size,
batch_num, num_workers, split, anomaly_binary_threshold,
anomaly_pixel_num_threshold, adaptive_threshold, results_dir, tag,
enable_excel_report, enable_save_image_results, enable_save_whole_image_results,
save_reconstructions, save_reconstructions_npy, save_raw_anomaly_maps,
save_image_variants, save_colormaps, save_normalization, normal_center,
save_binary_thresholds, enable_confusion_matrix,
filename_strategy, regression_test_mode, bootstrap_samples, bootstrap_seed,
bootstrap_sampling_trial_number, dry_run):
"""Run evaluation processing using DecodiffEvaluateProcessor.
Provides multiple execution modes for evaluation and processing:
- save_only: Save .npy files and diff images only
- process_only: Process existing .npy files for categorization
- save_and_process: Save and immediately process
- full_pipeline: Complete pipeline without saving intermediates
- full_pipeline_with_saving_npy: Complete pipeline with saving NPY files
- save_and_process_with_bootstrap: Bootstrap evaluation with two-stage sampling
Examples:
# Full pipeline mode
hae eval-process --mode full_pipeline --pretrained models/bottle.pth --annotation-dir ./annotations
# Process only existing NPY files
hae eval-process --mode process_only --annotation-dir ./annotations
# With config file
hae eval-process --config config/process.json --mode full_pipeline
"""
# 1. Load config if provided
config_dict = {}
if config:
from dioodmi.cli.config import load_config_with_env, detect_environment
from pathlib import Path
config_dict = load_config_with_env(Path(config))
console.print(f"[cyan]📄 Loaded config:[/cyan] {config}")
console.print(f" Environment: {detect_environment()}")
# 2. Build parameter dictionary from CLI arguments
cli_params = {
'mode': mode,
'pretrained': pretrained,
'annotation_dir': annotation_dir,
'dataset': dataset,
'object_class': object_class,
'anomaly_class': anomaly_class,
'data_dir': data_dir,
'csv_split_file': csv_split_file,
'model_size': model_size,
'vae_type': vae_type,
'image_size': image_size,
'center_size': center_size,
'center_crop': center_crop,
'patch_size': patch_size,
'stride': stride,
'irregular_patch': irregular_patch,
'reverse_steps': reverse_steps,
'batch_size': batch_size,
'batch_num': batch_num,
'num_workers': num_workers,
'split': split,
'anomaly_binary_threshold': anomaly_binary_threshold,
'anomaly_pixel_num_threshold': anomaly_pixel_num_threshold,
'adaptive_threshold': adaptive_threshold,
'results_dir': results_dir,
'tag': tag,
'enable_excel_report': enable_excel_report,
'enable_save_image_results': enable_save_image_results,
'enable_save_whole_image_results': enable_save_whole_image_results,
'save_reconstructions': save_reconstructions,
'save_reconstructions_npy': save_reconstructions_npy,
'save_raw_anomaly_maps': save_raw_anomaly_maps,
'save_image_variants': list(save_image_variants) if save_image_variants else None,
'save_colormaps': list(save_colormaps) if save_colormaps else None,
'save_normalization': save_normalization,
'normal_center': normal_center,
'save_binary_thresholds': list(save_binary_thresholds) if save_binary_thresholds else None,
'enable_confusion_matrix': enable_confusion_matrix,
'filename_strategy': filename_strategy,
'regression_test_mode': regression_test_mode,
'bootstrap_samples': bootstrap_samples,
'bootstrap_seed': bootstrap_seed,
'bootstrap_sampling_trial_number': bootstrap_sampling_trial_number,
}
# 3. Merge: config values first, then override with CLI args
final_params = config_dict.copy()
for key, value in cli_params.items():
if value is not None: # CLI arg was explicitly provided
final_params[key] = value
# 4. Validate required parameters
if 'mode' not in final_params or final_params['mode'] is None:
console.print("[red]❌ Error: --mode required (or specify in config)[/red]")
sys.exit(1)
if 'annotation_dir' not in final_params or final_params['annotation_dir'] is None:
console.print("[red]❌ Error: --annotation-dir required (or specify in config)[/red]")
sys.exit(1)
# Display evaluation info
eval_mode = final_params.get('mode', 'unknown')
eval_dataset = final_params.get('dataset', 'mvtec')
eval_object = final_params.get('object_class', 'all')
console.print(f"[bold cyan]⚙️ Evaluation Processing (DecodiffEvaluateProcessor)[/bold cyan]")
console.print(f"[bold]Mode:[/bold] {eval_mode}")
console.print(f"[bold]Dataset:[/bold] {eval_dataset}" +
(f"/{eval_object}" if eval_object != 'all' else ""))
console.print(f"[bold]Model:[/bold] {final_params.get('model_size', 'UNet_L')}\n")
# 5. Build platform-specific command
if sys.platform.startswith('win'):
cmd = ["py", "-3.11", "-m", "dioodmi.cli.process"]
else:
cmd = ["dioodmi-process"]
# 6. Add all parameters to command
cmd.extend(["--mode", final_params['mode']])
if 'pretrained' in final_params and final_params['pretrained']:
cmd.extend(["--pretrained", str(final_params['pretrained'])])
if 'annotation_dir' in final_params:
cmd.extend(["--annotation_dir", final_params['annotation_dir']])
if 'dataset' in final_params:
cmd.extend(["--dataset", final_params['dataset']])
if 'object_class' in final_params:
cmd.extend(["--object_class", final_params['object_class']])
if 'anomaly_class' in final_params:
cmd.extend(["--anomaly_class", final_params['anomaly_class']])
if 'data_dir' in final_params:
cmd.extend(["--data_dir", final_params['data_dir']])
if 'csv_split_file' in final_params:
cmd.extend(["--csv_split_file", final_params['csv_split_file']])
# Model configuration
if 'model_size' in final_params:
cmd.extend(["--model_size", final_params['model_size']])
if 'vae_type' in final_params:
cmd.extend(["--vae_type", final_params['vae_type']])
if 'image_size' in final_params:
cmd.extend(["--image_size", str(final_params['image_size'])])
if 'center_size' in final_params:
cmd.extend(["--center_size", str(final_params['center_size'])])
if 'center_crop' in final_params:
cmd.extend(["--center_crop", final_params['center_crop']])
# Processing parameters
if 'patch_size' in final_params:
cmd.extend(["--patch_size", str(final_params['patch_size'])])
if 'stride' in final_params and final_params['stride'] is not None:
        # Filter out stringified nulls from YAML substitution:
        # YAML `null` loads as Python None, and str(None) yields "None",
        # so "none"/"null"/"" must not be forwarded as literal stride values
stride_val = final_params['stride']
stride_str = str(stride_val).strip().lower()
if stride_str not in ('null', 'none', ''):
cmd.extend(["--stride", str(stride_val)])
if 'irregular_patch' in final_params and final_params['irregular_patch']:
cmd.append("--irregular_patch")
if 'reverse_steps' in final_params:
cmd.extend(["--reverse_steps", str(final_params['reverse_steps'])])
if 'batch_size' in final_params:
cmd.extend(["--batch_size", str(final_params['batch_size'])])
if 'batch_num' in final_params:
cmd.extend(["--batch_num", str(final_params['batch_num'])])
if 'num_workers' in final_params:
cmd.extend(["--num_workers", str(final_params['num_workers'])])
if 'split' in final_params:
cmd.extend(["--split", final_params['split']])
# Anomaly detection parameters
if 'anomaly_binary_threshold' in final_params:
cmd.extend(["--anomaly_binary_threshold", str(final_params['anomaly_binary_threshold'])])
if 'anomaly_pixel_num_threshold' in final_params:
cmd.extend(["--anomaly_pixel_num_threshold", str(final_params['anomaly_pixel_num_threshold'])])
if 'adaptive_threshold' in final_params:
cmd.extend(["--adaptive_threshold", str(final_params['adaptive_threshold'])])
# Output parameters
if 'results_dir' in final_params:
cmd.extend(["--results_dir", final_params['results_dir']])
if 'tag' in final_params:
cmd.extend(["--tag", final_params['tag']])
if 'enable_excel_report' in final_params and final_params['enable_excel_report']:
cmd.append("--enable_excel_report")
if 'enable_save_image_results' in final_params and final_params['enable_save_image_results']:
cmd.append("--enable_save_image_results")
if 'enable_save_whole_image_results' in final_params and final_params['enable_save_whole_image_results']:
cmd.append("--enable_save_whole_image_results")
if 'save_reconstructions' in final_params and final_params['save_reconstructions']:
cmd.append("--save_reconstructions")
if 'save_reconstructions_npy' in final_params and final_params['save_reconstructions_npy']:
cmd.append("--save_reconstructions_npy")
if 'save_raw_anomaly_maps' in final_params and final_params['save_raw_anomaly_maps']:
cmd.append("--save_raw_anomaly_maps")
if 'save_image_variants' in final_params and final_params['save_image_variants']:
cmd.extend(["--save-image-variants"] + [str(v) for v in final_params['save_image_variants']])
if 'save_colormaps' in final_params and final_params['save_colormaps']:
cmd.extend(["--save-colormaps"] + [str(v) for v in final_params['save_colormaps']])
if 'save_normalization' in final_params and final_params['save_normalization']:
cmd.extend(["--save-normalization", str(final_params['save_normalization'])])
if 'normal_center' in final_params and final_params['normal_center'] is not None:
cmd.extend(["--normal-center", str(final_params['normal_center'])])
if 'save_binary_thresholds' in final_params and final_params['save_binary_thresholds']:
cmd.extend(["--save-binary-thresholds"] + [str(v) for v in final_params['save_binary_thresholds']])
if 'enable_confusion_matrix' in final_params and final_params['enable_confusion_matrix']:
cmd.append("--enable_confusion_matrix")
# Filename strategy
if 'filename_strategy' in final_params:
cmd.extend(["--filename_strategy", final_params['filename_strategy']])
if 'regression_test_mode' in final_params and final_params['regression_test_mode']:
cmd.append("--regression_test_mode")
# Bootstrap parameters
if 'bootstrap_samples' in final_params and final_params['bootstrap_samples'] is not None:
cmd.extend(["--bootstrap_samples", str(final_params['bootstrap_samples'])])
if 'bootstrap_seed' in final_params:
cmd.extend(["--bootstrap_seed", str(final_params['bootstrap_seed'])])
if 'bootstrap_sampling_trial_number' in final_params:
cmd.extend(["--bootstrap_sampling_trial_number", str(final_params['bootstrap_sampling_trial_number'])])
# 7. Show standalone alternative
show_standalone_alternatives(cmd, f"Evaluation Processing ({eval_mode}) on {eval_dataset}")
# 8. Execute
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@main.command('reconstruct')
@click.option('--output-dir', type=click.Path(), required=True,
help='Output directory for results')
@click.option('--model-name', help='Model name')
@click.option('--vqvae-checkpoint', type=click.Path(exists=True),
help='Path to VQ-VAE checkpoint for LDM')
@click.option('--num-inference-steps', type=int, default=100,
help='Number of diffusion inference steps')
@click.option('--batch-size', type=int, default=256,
help='Batch size for reconstruction')
@click.option('--dry-run/--no-dry-run', default=False,
help='Show commands without executing')
def reconstruct(output_dir, model_name, vqvae_checkpoint, num_inference_steps,
batch_size, dry_run):
"""Reconstruct images using diffusion model.
Orchestrates dioodmi-reconstruct with appropriate configuration.
Examples:
hae reconstruct --output-dir results/ --model-name my_model
hae reconstruct --output-dir results/ --vqvae-checkpoint models/vqvae.pth
"""
console.print(f"[bold cyan]🔄 Reconstructing images[/bold cyan]\n")
# Build command
cmd = ["dioodmi-reconstruct", "--output_dir", output_dir]
if model_name:
cmd.extend(["--model_name", model_name])
if vqvae_checkpoint:
cmd.extend(["--vqvae_checkpoint", vqvae_checkpoint])
if num_inference_steps:
cmd.extend(["--num_inference_steps", str(num_inference_steps)])
if batch_size:
cmd.extend(["--batch_size", str(batch_size)])
# Show standalone alternatives
show_standalone_alternatives(cmd, "Reconstruct images with diffusion")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
# Execute command
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
class BenchmarkPreprocessingCommand(click.Command):
"""Command that preprocesses --latent_sizes arguments to handle space-separated values."""
def make_context(self, info_name, args, parent=None, **extra):
args = preprocess_latent_sizes_args(list(args))
return super().make_context(info_name, args, parent=parent, **extra)
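# `preprocess_latent_sizes_args` (defined earlier in this module) is assumed to
# rewrite space-separated values into the repeated form that click's
# multiple=True options expect, e.g.:
#   ["--latent_sizes", "16", "32", "64"]
#   -> ["--latent_sizes", "16", "--latent_sizes", "32", "--latent_sizes", "64"]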
@main.command('benchmark', cls=BenchmarkPreprocessingCommand)
@click.option('--dataset', type=click.Choice(['mvtec', 'visa']), required=False,
help='Dataset to benchmark (for evaluation benchmarking)')
@click.option('--models', multiple=True,
help='Models to benchmark (can specify multiple, for evaluation benchmarking)')
@click.option('--categories', multiple=True,
help='Categories to benchmark (default: all, for evaluation benchmarking)')
@click.option('--metrics', multiple=True, default=['auroc'],
help='Metrics to measure (for evaluation benchmarking)')
@click.option('--output', type=click.Path(), default='benchmarks/',
help='Benchmark results output directory')
# Inference speed benchmarking options
@click.option('--model', type=str,
help='Model to benchmark for inference speed (UNet_L, TSFormer_L, WaveMamba_L, UNetFaster_L, compare)')
@click.option('--latent_sizes', multiple=True, type=int,
help='Latent sizes to test (can specify space-separated, e.g., --latent_sizes 16 32 64)')
@click.option('--ncls', type=int, default=15,
help='Number of classes (default: 15)')
@click.option('--batch_size', type=int, default=4,
help='Batch size (default: 4)')
@click.option('--num_warmup', type=int, default=10,
help='Number of warmup iterations (default: 10)')
@click.option('--num_iterations', type=int, default=100,
help='Number of benchmark iterations (default: 100)')
@click.option('--output_dir', type=click.Path(),
help='Output directory for inference speed benchmark results')
def benchmark(dataset, models, categories, metrics, output, model, latent_sizes, ncls, batch_size, num_warmup, num_iterations, output_dir):
"""Run performance benchmarks on multiple models.
Supports two modes:
1. Evaluation benchmarking: Compare models on datasets (--dataset required)
2. Inference speed benchmarking: Measure inference speed across different sizes (--model required)
Examples:
# Evaluation benchmarking
hae benchmark --dataset mvtec --models ddpm --models decodiff
hae benchmark --dataset mvtec --categories bottle --categories cable
# Inference speed benchmarking
hae benchmark --model UNet_L --latent_sizes 16 32 64 128 --ncls 15 --batch_size 4 --output_dir ./results
hae benchmark --model compare --output_dir ./results
"""
# Determine which mode to use
if model is not None:
# Inference speed benchmarking mode
_run_inference_speed_benchmark(model, latent_sizes, ncls, batch_size, num_warmup, num_iterations, output_dir or output)
elif dataset is not None:
# Evaluation benchmarking mode
console.print(f"[bold yellow]⚡ Benchmarking on {dataset}[/bold yellow]\n")
if models:
console.print(f"Models: {', '.join(models)}")
if categories:
console.print(f"Categories: {', '.join(categories)}")
console.print(f"Metrics: {', '.join(metrics)}")
# TODO: Implement benchmarking workflow
console.print("\n[red]Evaluation benchmarking workflow not yet implemented in HAE[/red]")
console.print("This will orchestrate multiple dioodmi-eval runs and aggregate results")
else:
console.print("[red]Error: Either --dataset (for evaluation benchmarking) or --model (for inference speed benchmarking) must be specified[/red]")
console.print("\nUse --help for usage information")
sys.exit(1)
def _run_inference_speed_benchmark(model, latent_sizes, ncls, batch_size, num_warmup, num_iterations, output_dir):
"""Run inference speed benchmarking."""
from pathlib import Path
console.print(f"[bold yellow]⚡ Inference Speed Benchmarking[/bold yellow]\n")
console.print(f"Model: {model}")
if latent_sizes:
console.print(f"Latent sizes: {', '.join(map(str, latent_sizes))}")
console.print(f"Classes: {ncls}")
console.print(f"Batch size: {batch_size}")
console.print(f"Warmup iterations: {num_warmup}")
console.print(f"Benchmark iterations: {num_iterations}")
console.print(f"Output directory: {output_dir}\n")
# Find the benchmark script
# Try multiple paths: relative to current dir, then relative to hae-py location
script_name = "benchmark_inference_speed_multi_size.py"
script_path = None
# Try relative to current working directory
cwd_script = Path.cwd() / "experiments" / "definitions" / "scripts" / script_name
if cwd_script.exists():
script_path = cwd_script
# Try relative to hae-py source location (go up from hae/cli.py to project root)
if script_path is None:
hae_py_root = Path(__file__).parent.parent.parent.parent
hae_script = hae_py_root / "experiments" / "definitions" / "scripts" / script_name
if hae_script.exists():
script_path = hae_script
    # Try "definitions/scripts" relative to the current directory (resolved to absolute)
if script_path is None:
abs_script = Path("definitions/scripts") / script_name
if abs_script.exists():
script_path = abs_script.resolve()
if script_path is None or not script_path.exists():
console.print(f"[red]Error: Benchmark script not found[/red]")
console.print(f"Looked for: {script_name}")
console.print(f" - {cwd_script}")
console.print(f" - {hae_script if 'hae_script' in locals() else 'N/A'}")
console.print("\nMake sure you're running from the project root directory")
sys.exit(1)
# Build command arguments (without Python interpreter for display)
cmd_args = [str(script_path), "--model", model]
if latent_sizes:
# Convert tuple to list of strings for the script (which uses nargs='+')
cmd_args.extend(["--latent_sizes"] + [str(s) for s in latent_sizes])
cmd_args.extend([
"--ncls", str(ncls),
"--batch_size", str(batch_size),
"--num_warmup", str(num_warmup),
"--num_iterations", str(num_iterations),
"--output_dir", str(output_dir)
])
# Show standalone alternatives (display only, for user reference)
console.print("[bold]💡 Standalone alternatives:[/bold]")
console.print(f" Direct: {' '.join(cmd_args)}")
console.print(f" Python: python3 {' '.join(cmd_args)}")
console.print()
# Execute with sys.executable for cross-platform compatibility
# This works on both Windows (where .py files can't be executed directly)
# and NixOS/Unix (ensures using the same Python environment)
exec_cmd = [sys.executable] + cmd_args
console.print(f"[bold]Executing benchmark...[/bold]\n")
try:
result = subprocess.run(exec_cmd, check=True)
console.print(f"\n[green]✓ Benchmark completed successfully[/green]")
console.print(f"Results saved to: {output_dir}")
except subprocess.CalledProcessError as e:
console.print(f"\n[red]✗ Benchmark failed with exit code {e.returncode}[/red]")
sys.exit(e.returncode)
except FileNotFoundError:
console.print(f"\n[red]✗ Benchmark script not found: {script_path}[/red]")
console.print("Make sure you're running from the project root directory")
sys.exit(1)
@main.group()
def examples():
"""Manage and run example configurations.
Examples provide pre-configured workflows for common use cases.
"""
pass
@examples.command('list')
def examples_list():
"""List available examples."""
console.print("[bold]📚 Available Examples[/bold]\n")
# TODO: Load from examples/ directory
table = Table(title="Example Configurations")
table.add_column("Name", style="cyan")
table.add_column("Description", style="green")
table.add_column("Dataset", style="yellow")
# Placeholder examples
table.add_row("mvtec-bottle", "Train DeCo-Diff on MVTec bottle", "MVTec-AD")
table.add_row("visa-candle", "Train DeCo-Diff on VisA candle", "VisA")
table.add_row("quick-test", "Quick test run (1 epoch)", "MVTec-AD")
console.print(table)
console.print("\n[dim]Use 'hae examples run <name>' to execute an example[/dim]")
@examples.command('info')
@click.argument('name')
def examples_info(name):
"""Show details about an example."""
console.print(f"[bold]📖 Example: {name}[/bold]\n")
# TODO: Load from examples/ directory
console.print("[red]Example info not yet implemented[/red]")
console.print(f"Will show details about example '{name}'")
@examples.command('run')
@click.argument('name')
@click.option('--train/--no-train', default=False, help='Run training')
@click.option('--eval/--no-eval', default=False, help='Run evaluation')
@click.option('--dry-run/--no-dry-run', default=False,
help='Show commands without executing')
def examples_run(name, train, eval, dry_run):
"""Run an example configuration.
Examples:
hae examples run mvtec-bottle --train
hae examples run mvtec-bottle --eval
hae examples run quick-test --train --eval
"""
console.print(f"[bold]🎯 Running Example: {name}[/bold]\n")
operations = []
if train:
operations.append("Train")
if eval:
operations.append("Eval")
if not operations:
console.print("[red]Error: Specify at least one operation (--train or --eval)[/red]")
sys.exit(1)
# TODO: Load example config and execute
console.print("[red]Example execution not yet implemented[/red]")
console.print(f"Will execute: {', '.join(operations)}")
@examples.command('export')
@click.argument('name')
@click.option('--output', type=click.Path(), default='.',
help='Directory to export configs to')
def examples_export(name, output):
"""Export example configs to a directory.
Examples:
hae examples export mvtec-bottle
hae examples export mvtec-bottle --output my-experiment/
"""
console.print(f"[bold]📤 Exporting Example: {name}[/bold]\n")
# TODO: Copy example configs
console.print("[red]Example export not yet implemented[/red]")
console.print(f"Will export configs to: {output}")
@main.group()
def dev():
"""Development tools and utilities.
Provides interactive tools, code quality checks, and utilities for DiOodMi development.
"""
pass
@dev.command('shell')
def dev_shell():
"""Start interactive Python shell with DiOodMi loaded.
Example:
hae dev shell
"""
console.print("[bold cyan]🐍 Starting Python shell with DiOodMi[/bold cyan]\n")
# Build command for IPython with dioodmi preloaded
cmd = ["python", "-m", "IPython", "-i", "-c",
"import dioodmi; from dioodmi.trainers import *; from dioodmi.networks import *; print('DiOodMi loaded')"]
show_standalone_alternatives(["python", "-m", "IPython"], "Interactive Python shell")
result = subprocess.run(cmd)
sys.exit(result.returncode)
@dev.command('jupyter')
@click.option('--port', type=int, default=8888, help='Port for Jupyter server')
@click.option('--no-browser', is_flag=True, help='Don\'t open browser automatically')
def dev_jupyter(port, no_browser):
"""Launch Jupyter Lab for interactive development.
Example:
hae dev jupyter
hae dev jupyter --port 9000 --no-browser
"""
console.print(f"[bold cyan]📓 Launching Jupyter Lab on port {port}[/bold cyan]\n")
cmd = ["jupyter", "lab", f"--port={port}"]
if no_browser:
cmd.append("--no-browser")
show_standalone_alternatives(cmd[:2], f"Jupyter Lab on port {port}")
result = subprocess.run(cmd)
sys.exit(result.returncode)
@dev.command('format')
@click.option('--check', is_flag=True, help='Check only, don\'t modify files')
@click.option('--path', type=click.Path(exists=True), default='.',
help='Path to format')
def dev_format(check, path):
"""Format code with ruff.
Example:
hae dev format
hae dev format --check
"""
console.print("[bold green]✨ Formatting code with ruff[/bold green]\n")
cmd = ["ruff", "format"]
if check:
cmd.append("--check")
cmd.append(path)
show_standalone_alternatives(cmd, "Format code with ruff")
result = subprocess.run(cmd)
if result.returncode == 0:
console.print("\n[green]✓ Formatting complete[/green]")
sys.exit(result.returncode)
@dev.command('lint')
@click.option('--fix', is_flag=True, help='Automatically fix issues')
@click.option('--path', type=click.Path(exists=True), default='.',
help='Path to lint')
def dev_lint(fix, path):
"""Lint code with ruff.
Example:
hae dev lint
hae dev lint --fix
"""
console.print("[bold yellow]🔍 Linting code with ruff[/bold yellow]\n")
cmd = ["ruff", "check"]
if fix:
cmd.append("--fix")
cmd.append(path)
show_standalone_alternatives(cmd, "Lint code with ruff")
result = subprocess.run(cmd)
if result.returncode == 0:
console.print("\n[green]✓ No linting issues found[/green]")
sys.exit(result.returncode)
@dev.command('test')
@click.option('--path', type=click.Path(exists=True), default='tests/',
help='Test path or file')
@click.option('--coverage', is_flag=True, help='Run with coverage report')
@click.option('--verbose', '-v', is_flag=True, help='Verbose output')
@click.option('--keyword', '-k', help='Run tests matching keyword')
def dev_test(path, coverage, verbose, keyword):
"""Run tests with pytest.
Example:
hae dev test
hae dev test --coverage
hae dev test -k "test_decodiff"
"""
console.print("[bold blue]🧪 Running tests with pytest[/bold blue]\n")
cmd = ["pytest"]
if coverage:
cmd.extend(["--cov=dioodmi", "--cov-report=term-missing"])
if verbose:
cmd.append("-v")
if keyword:
cmd.extend(["-k", keyword])
cmd.append(path)
show_standalone_alternatives(cmd[:1], "Run pytest tests")
result = subprocess.run(cmd)
sys.exit(result.returncode)
@dev.command('typecheck')
@click.option('--path', type=click.Path(exists=True), default='.',
help='Path to type check')
def dev_typecheck(path):
"""Run mypy type checking.
Example:
hae dev typecheck
"""
console.print("[bold magenta]🔎 Type checking with mypy[/bold magenta]\n")
cmd = ["mypy", path]
show_standalone_alternatives(cmd, "Type check with mypy")
result = subprocess.run(cmd)
if result.returncode == 0:
console.print("\n[green]✓ Type checking passed[/green]")
sys.exit(result.returncode)
@dev.command('clean')
@click.option('--deep', is_flag=True, help='Also remove .venv, build artifacts')
def dev_clean(deep):
"""Clean build artifacts and caches.
Example:
hae dev clean
hae dev clean --deep
"""
console.print("[bold red]🧹 Cleaning build artifacts[/bold red]\n")
import shutil
from pathlib import Path
patterns = [
"**/__pycache__",
"**/*.pyc",
"**/*.pyo",
"**/.pytest_cache",
"**/.ruff_cache",
"**/.mypy_cache",
"**/htmlcov",
"**/.coverage",
]
if deep:
patterns.extend([
"**/.venv",
"**/build",
"**/dist",
"**/*.egg-info",
])
removed_count = 0
for pattern in patterns:
        # Materialize matches first so deletions don't disturb the directory walk
        for path in list(Path(".").glob(pattern)):
if path.exists():
console.print(f" Removing: {path}")
if path.is_dir():
shutil.rmtree(path)
else:
path.unlink()
removed_count += 1
console.print(f"\n[green]✓ Cleaned {removed_count} items[/green]")
@dev.command('docs')
@click.option('--serve', is_flag=True, help='Serve docs locally')
@click.option('--port', type=int, default=8000, help='Port for doc server')
def dev_docs(serve, port):
"""Generate documentation.
Example:
hae dev docs
hae dev docs --serve
"""
console.print("[bold cyan]📚 Generating documentation[/bold cyan]\n")
# TODO: Implement documentation generation
console.print("[yellow]Documentation generation not yet implemented[/yellow]")
console.print("Will use Sphinx or mkdocs to generate from docstrings")
@dev.command('check-deps')
def dev_check_deps():
"""Check dependencies for updates and security issues.
Example:
hae dev check-deps
"""
console.print("[bold yellow]📦 Checking dependencies[/bold yellow]\n")
# Check for outdated packages
console.print("[cyan]Checking for updates...[/cyan]")
cmd = ["uv", "pip", "list", "--outdated"]
show_standalone_alternatives(cmd, "Check outdated packages")
subprocess.run(cmd)
console.print("\n[cyan]Security audit coming soon...[/cyan]")
@dev.command('profile')
@click.argument('script', type=click.Path(exists=True))
@click.option('--sort', default='cumulative', help='Sort order (cumulative, time, calls)')
def dev_profile(script, sort):
"""Profile a Python script for performance analysis.
Example:
hae dev profile project/train_decodiff.py
"""
console.print(f"[bold magenta]⚡ Profiling {script}[/bold magenta]\n")
cmd = ["python", "-m", "cProfile", "-s", sort, script]
show_standalone_alternatives(cmd, f"Profile {script}")
result = subprocess.run(cmd)
sys.exit(result.returncode)
@dev.command('check-cuda')
@click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing')
def dev_check_cuda(dry_run):
"""Check CUDA availability and GPU information.
Displays PyTorch CUDA version, cuDNN version, number of GPUs,
and detailed information about each GPU (name, compute capability, memory).
Example:
hae dev check-cuda
"""
console.print("[bold green]🎮 Checking CUDA Setup[/bold green]\n")
project_root = Path(__file__).parent.parent.parent.parent.parent
script_path = project_root / "check-cuda.py"
cmd = ["python", str(script_path)]
show_standalone_alternatives(cmd, "Check CUDA availability and GPU info")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@dev.command('test-cufft')
@click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing')
def dev_test_cufft(dry_run):
"""Test CUDA FFT functionality.
Quick test to verify cuFFT (CUDA Fast Fourier Transform) is working
correctly by running a simple FFT operation on GPU.
Example:
hae dev test-cufft
"""
console.print("[bold green]📊 Testing cuFFT[/bold green]\n")
project_root = Path(__file__).parent.parent.parent.parent.parent
script_path = project_root / "cufftTest.py"
cmd = ["python", str(script_path)]
show_standalone_alternatives(cmd, "Test CUDA FFT functionality")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@dev.command('sort-interactive')
@click.option('--pred-jsonl', required=True, type=click.Path(exists=True), help='Prediction JSONL file')
@click.option('--out-dir', required=True, type=click.Path(), help='Output directory for sorted images')
@click.option('--zoom', type=float, default=1.0, help='Zoom factor for display')
@click.option('--resume', is_flag=True, help='Resume from previous session')
@click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing')
def dev_sort_interactive(pred_jsonl, out_dir, zoom, resume, dry_run):
"""Interactive image sorter for anomaly detection results.
Opens an interactive GUI for manually sorting and categorizing
anomaly detection results (good, false positive, missed, etc.).
Examples:
hae dev sort-interactive --pred-jsonl predictions.jsonl --out-dir ./sorted
hae dev sort-interactive --pred-jsonl pred.jsonl --out-dir ./sorted --zoom 1.5 --resume
"""
console.print("[bold blue]🖱️ Interactive Sorter[/bold blue]\n")
project_root = Path(__file__).parent.parent.parent.parent.parent
script_path = project_root / "interactive_sorter.py"
cmd = [
"python", str(script_path),
"--pred_jsonl", pred_jsonl,
"--out_dir", out_dir,
"--zoom", str(zoom)
]
if resume:
cmd.append("--resume")
show_standalone_alternatives(cmd, "Interactive image sorter")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@dev.command('test-evolving-patch')
@click.argument('args', nargs=-1, type=click.UNPROCESSED)
@click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing')
def dev_test_evolving_patch(args, dry_run):
"""Test evolving patch position encoding.
Tests the position encoding mechanism for patches that evolve
during training (dropout, augmentation, etc.).
Example:
hae dev test-evolving-patch
hae dev test-evolving-patch --help
"""
console.print("[bold magenta]🧪 Testing Evolving Patch[/bold magenta]\n")
project_root = Path(__file__).parent.parent.parent.parent.parent
script_path = project_root / "scripts" / "test_evolving_patch.py"
cmd = ["python", str(script_path)] + list(args)
show_standalone_alternatives(cmd, "Test evolving patch position encoding")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@main.group()
def dvc():
"""DVC (Data Version Control) operations.
Manage datasets, models, pipelines, and experiments with DVC.
"""
pass
@dvc.command('status')
@click.option('--cloud', is_flag=True, help='Show status of cloud storage')
def dvc_status(cloud):
"""Show DVC status (tracked files, changes).
Example:
hae dvc status
hae dvc status --cloud
"""
console.print("[bold cyan]📊 DVC Status[/bold cyan]\n")
cmd = ["dvc", "status"]
if cloud:
cmd.append("--cloud")
show_standalone_alternatives(cmd, "Check DVC status")
result = subprocess.run(cmd)
sys.exit(result.returncode)
@dvc.command('add')
@click.argument('path', type=click.Path(exists=True))
@click.option('--recursive', '-r', is_flag=True, help='Add directory recursively')
def dvc_add(path, recursive):
"""Track file or directory with DVC.
Example:
hae dvc add data/mvtec-dataset/
hae dvc add models/checkpoint.pth
"""
console.print(f"[bold green]➕ Adding {path} to DVC[/bold green]\n")
cmd = ["dvc", "add", path]
if recursive:
cmd.append("-r")
show_standalone_alternatives(cmd, f"Track {path} with DVC")
result = subprocess.run(cmd)
if result.returncode == 0:
console.print(f"\n[green]✓ Added {path} to DVC[/green]")
console.print(f"[yellow]Don't forget to:[/yellow]")
console.print(f" git add {path}.dvc .gitignore")
console.print(f" git commit -m 'Add {path} to DVC'")
sys.exit(result.returncode)
@dvc.command('push')
@click.option('--remote', '-r', help='Push to specific remote')
@click.option('--all-branches', is_flag=True, help='Push all branches')
@click.option('--all-tags', is_flag=True, help='Push all tags')
def dvc_push(remote, all_branches, all_tags):
"""Push tracked data to remote storage.
Example:
hae dvc push
hae dvc push --remote myremote
"""
console.print("[bold blue]⬆️ Pushing to DVC remote[/bold blue]\n")
cmd = ["dvc", "push"]
if remote:
cmd.extend(["--remote", remote])
if all_branches:
cmd.append("--all-branches")
if all_tags:
cmd.append("--all-tags")
show_standalone_alternatives(cmd, "Push data to DVC remote")
result = subprocess.run(cmd)
if result.returncode == 0:
console.print("\n[green]✓ Push complete[/green]")
sys.exit(result.returncode)
@dvc.command('pull')
@click.option('--remote', '-r', help='Pull from specific remote')
@click.option('--all-branches', is_flag=True, help='Pull all branches')
@click.option('--all-tags', is_flag=True, help='Pull all tags')
def dvc_pull(remote, all_branches, all_tags):
"""Pull tracked data from remote storage.
Example:
hae dvc pull
hae dvc pull --remote myremote
"""
console.print("[bold blue]⬇️ Pulling from DVC remote[/bold blue]\n")
cmd = ["dvc", "pull"]
if remote:
cmd.extend(["--remote", remote])
if all_branches:
cmd.append("--all-branches")
if all_tags:
cmd.append("--all-tags")
show_standalone_alternatives(cmd, "Pull data from DVC remote")
result = subprocess.run(cmd)
if result.returncode == 0:
console.print("\n[green]✓ Pull complete[/green]")
sys.exit(result.returncode)
@dvc.command('checkout')
@click.option('--force', is_flag=True, help='Force checkout even with conflicts')
def dvc_checkout(force):
"""Checkout data files tracked by DVC.
Example:
hae dvc checkout
hae dvc checkout --force
"""
console.print("[bold cyan]🔄 Checking out DVC files[/bold cyan]\n")
cmd = ["dvc", "checkout"]
if force:
cmd.append("--force")
show_standalone_alternatives(cmd, "Checkout DVC tracked files")
result = subprocess.run(cmd)
if result.returncode == 0:
console.print("\n[green]✓ Checkout complete[/green]")
sys.exit(result.returncode)
@dvc.command('repro')
@click.option('--force', is_flag=True, help='Reproduce even if dependencies unchanged')
@click.option('--downstream', is_flag=True, help='Reproduce downstream stages')
@click.argument('target', required=False)
def dvc_repro(force, downstream, target):
"""Reproduce DVC pipeline.
Example:
hae dvc repro
hae dvc repro train
hae dvc repro --force --downstream
"""
console.print("[bold magenta]🔁 Reproducing DVC pipeline[/bold magenta]\n")
cmd = ["dvc", "repro"]
if force:
cmd.append("--force")
if downstream:
cmd.append("--downstream")
if target:
cmd.append(target)
show_standalone_alternatives(cmd, "Reproduce DVC pipeline")
result = subprocess.run(cmd)
if result.returncode == 0:
console.print("\n[green]✓ Pipeline reproduced[/green]")
sys.exit(result.returncode)
@dvc.command('run')
@click.option('--name', '-n', required=True, help='Stage name')
@click.option('--deps', '-d', multiple=True, help='Dependencies')
@click.option('--outs', '-o', multiple=True, help='Outputs')
@click.option('--metrics', '-m', multiple=True, help='Metrics files')
@click.option('--params', '-p', multiple=True, help='Parameters')
@click.argument('command', nargs=-1, required=True)
def dvc_run(name, deps, outs, metrics, params, command):
"""Create a DVC pipeline stage.
Example:
hae dvc run -n train -d data/ -o models/ -m metrics.json -- python train.py
"""
console.print(f"[bold green]▶️ Creating stage '{name}'[/bold green]\n")
cmd = ["dvc", "run", "-n", name]
for dep in deps:
cmd.extend(["-d", dep])
for out in outs:
cmd.extend(["-o", out])
for metric in metrics:
cmd.extend(["-m", metric])
for param in params:
cmd.extend(["-p", param])
cmd.extend(list(command))
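    # e.g. `hae dvc run -n train -d data/ -o models/ -- python train.py`
    # builds: dvc run -n train -d data/ -o models/ python train.py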
show_standalone_alternatives(cmd, f"Create DVC stage '{name}'")
result = subprocess.run(cmd)
if result.returncode == 0:
console.print(f"\n[green]✓ Stage '{name}' created[/green]")
sys.exit(result.returncode)
@dvc.group('exp')
def dvc_exp():
"""DVC experiment tracking commands."""
pass
@dvc_exp.command('run')
@click.option('--name', '-n', help='Experiment name')
@click.option('--queue', is_flag=True, help='Queue experiment instead of running')
@click.option('--set-param', '-S', multiple=True, help='Override parameter (key=value)')
def dvc_exp_run(name, queue, set_param):
"""Run a DVC experiment.
Example:
hae dvc exp run -n my_exp
hae dvc exp run -S learning_rate=0.001 -S epochs=100
"""
console.print("[bold magenta]🧪 Running DVC experiment[/bold magenta]\n")
cmd = ["dvc", "exp", "run"]
if name:
cmd.extend(["--name", name])
if queue:
cmd.append("--queue")
for param in set_param:
cmd.extend(["-S", param])
show_standalone_alternatives(cmd, "Run DVC experiment")
result = subprocess.run(cmd)
sys.exit(result.returncode)
@dvc_exp.command('show')
@click.option('--all-branches', is_flag=True, help='Show experiments from all branches')
@click.option('--all-commits', is_flag=True, help='Show experiments from all commits')
def dvc_exp_show(all_branches, all_commits):
"""Show experiments and their metrics.
Example:
hae dvc exp show
hae dvc exp show --all-commits
"""
console.print("[bold cyan]📊 DVC Experiments[/bold cyan]\n")
cmd = ["dvc", "exp", "show"]
if all_branches:
cmd.append("--all-branches")
if all_commits:
cmd.append("--all-commits")
show_standalone_alternatives(cmd, "Show DVC experiments")
result = subprocess.run(cmd)
sys.exit(result.returncode)
@dvc_exp.command('diff')
@click.argument('experiment', required=False)
@click.option('--all', is_flag=True, help='Show all metrics and params')
def dvc_exp_diff(experiment, all):
"""Show difference between experiments.
Example:
hae dvc exp diff
hae dvc exp diff exp-abc123
"""
console.print("[bold yellow]🔍 Experiment Diff[/bold yellow]\n")
cmd = ["dvc", "exp", "diff"]
if experiment:
cmd.append(experiment)
if all:
cmd.append("--all")
show_standalone_alternatives(cmd, "Compare DVC experiments")
result = subprocess.run(cmd)
sys.exit(result.returncode)
@dvc_exp.command('apply')
@click.argument('experiment', required=True)
def dvc_exp_apply(experiment):
"""Apply an experiment to workspace.
Example:
hae dvc exp apply exp-abc123
"""
console.print(f"[bold green]✨ Applying experiment {experiment}[/bold green]\n")
cmd = ["dvc", "exp", "apply", experiment]
show_standalone_alternatives(cmd, f"Apply experiment {experiment}")
result = subprocess.run(cmd)
if result.returncode == 0:
console.print(f"\n[green]✓ Experiment {experiment} applied[/green]")
sys.exit(result.returncode)
def human_size(bytes_size):
"""Convert bytes to human-readable format."""
if bytes_size is None:
return "?"
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
if bytes_size < 1024.0:
return f"{bytes_size:.1f}{unit}"
bytes_size /= 1024.0
return f"{bytes_size:.1f}PB"
def parse_dvc_files(filter_path=None):
"""Parse all .dvc files and return metadata."""
import yaml
dvc_files = []
for dvc_file in Path('.').rglob('*.dvc'):
# Skip .dvc directory itself
if dvc_file.is_dir():
continue
if filter_path and filter_path not in str(dvc_file):
continue
try:
with open(dvc_file) as f:
content = yaml.safe_load(f)
if 'outs' in content and len(content['outs']) > 0:
out = content['outs'][0]
data_path = dvc_file.parent / out.get('path', '')
dvc_files.append({
'dvc_file': str(dvc_file),
'data_path': str(data_path),
'size': out.get('size'),
'nfiles': out.get('nfiles', 1),
'exists': data_path.exists()
})
except Exception as e:
console.print(f"[yellow]Warning: Could not parse {dvc_file}: {e}[/yellow]")
return sorted(dvc_files, key=lambda x: x['dvc_file'])
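# A minimal .dvc file in the standard DVC format this parser reads:
#   outs:
#   - md5: 1a2b3c...  (hash elided)
#     size: 5312000000
#     nfiles: 5354
#     path: mvtec-dataset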
def interactive_dvc_selection(dvc_files):
"""Interactive selection of DVC files for pull/push."""
from InquirerPy import inquirer
# Create choices with status indicators
choices = []
for item in dvc_files:
status = "✓" if item['exists'] else "✗"
size = human_size(item['size'])
label = f"{status} {item['dvc_file']:60s} ({size})"
choices.append({
'name': label,
'value': item['dvc_file'],
'enabled': False
})
# Multi-select checkbox
    selected = inquirer.checkbox(
        message="Select datasets (space to toggle, enter to confirm):",
        choices=choices,
        # `len(x) > 0 or "msg"` always returned a truthy value (a non-empty
        # string), so validation could never fail; InquirerPy expects a
        # boolean validator plus invalid_message instead
        validate=lambda x: len(x) > 0,
        invalid_message="Please select at least one dataset"
    ).execute()
if not selected:
return None, None
# Ask for action
action = inquirer.select(
message="What would you like to do?",
choices=[
{'name': 'Pull (download from remote)', 'value': 'pull'},
{'name': 'Push (upload to remote)', 'value': 'push'},
{'name': 'Cancel', 'value': 'cancel'}
]
).execute()
return selected, action
@dvc.command('list')
@click.option('--select', '-s', is_flag=True, help='Interactive selection for pull/push')
@click.option('--filter', help='Filter by path (datasets, fixtures, golden)')
@click.option('--remote-only', is_flag=True, help='Show only remote datasets')
@click.option('--local-only', is_flag=True, help='Show only local datasets')
@click.option('--recursive', '-R', is_flag=True, help='(Compatibility - always recursive)')
@click.option('--tree', '-T', is_flag=True, help='(Compatibility - ignored)')
@click.option('--size', is_flag=True, help='(Compatibility - always shown)')
def dvc_list_files(select, filter, remote_only, local_only, recursive, tree, size):
"""List all DVC-tracked datasets with detailed information.
Examples:
hae dvc list # Show all datasets
hae dvc list --select # Interactive pull/push
hae dvc list --filter datasets # Filter by path
hae dvc list --remote-only # Show only remote datasets
"""
console.print("[bold cyan]📁 DVC Tracked Datasets[/bold cyan]\n")
# Parse all .dvc files
dvc_files = parse_dvc_files(filter_path=filter)
# Apply filters
if remote_only:
dvc_files = [f for f in dvc_files if not f['exists']]
if local_only:
dvc_files = [f for f in dvc_files if f['exists']]
if not dvc_files:
console.print("[yellow]No DVC files found matching criteria[/yellow]")
return
# Create Rich table
table = Table(show_header=True, header_style="bold magenta")
table.add_column("Status", style="cyan", width=12)
table.add_column("Size", style="yellow", width=10)
table.add_column("Files", style="green", width=10)
table.add_column("DVC File", style="white")
for item in dvc_files:
status = "[green]✓ LOCAL[/green]" if item['exists'] else "[red]✗ REMOTE[/red]"
size_str = human_size(item['size'])
nfiles = str(item['nfiles']) if item['nfiles'] > 1 else "-"
table.add_row(status, size_str, nfiles, item['dvc_file'])
console.print(table)
console.print(f"\n[bold]Total:[/bold] {len(dvc_files)} DVC-tracked datasets")
# Interactive selection mode
if select:
console.print()
selected_files, action = interactive_dvc_selection(dvc_files)
if action == 'cancel' or not selected_files:
console.print("[yellow]Operation cancelled[/yellow]")
return
# Execute DVC command
cmd = ["dvc", action] + selected_files
console.print(f"\n[cyan]Running:[/cyan] {' '.join(cmd)}\n")
result = subprocess.run(cmd)
if result.returncode == 0:
console.print(f"\n[green]✓ Successfully {action}ed {len(selected_files)} dataset(s)[/green]")
else:
console.print(f"\n[red]✗ {action} failed with code {result.returncode}[/red]")
sys.exit(result.returncode)
else:
# Show pull command hint for remote files
remote_count = sum(1 for f in dvc_files if not f['exists'])
if remote_count > 0:
console.print(f"\n[yellow]💡 {remote_count} datasets are remote. Use --select for interactive pull:[/yellow]")
console.print(" hae dvc list --select")
@dvc.command('dag')
def dvc_dag():
"""Show DVC pipeline DAG.
Example:
hae dvc dag
"""
console.print("[bold cyan]🔀 DVC Pipeline DAG[/bold cyan]\n")
cmd = ["dvc", "dag"]
show_standalone_alternatives(cmd, "Show pipeline DAG")
result = subprocess.run(cmd)
sys.exit(result.returncode)
@dvc.command('metrics')
@click.option('--show-md', is_flag=True, help='Show as markdown table')
@click.option('--all-branches', is_flag=True, help='Show metrics from all branches')
def dvc_metrics(show_md, all_branches):
"""Show metrics from tracked files.
Example:
hae dvc metrics
hae dvc metrics --show-md
"""
console.print("[bold magenta]📈 DVC Metrics[/bold magenta]\n")
cmd = ["dvc", "metrics", "show"]
if show_md:
cmd.append("--show-md")
if all_branches:
cmd.append("--all-branches")
show_standalone_alternatives(cmd, "Show DVC metrics")
result = subprocess.run(cmd)
sys.exit(result.returncode)
@dvc.command('params')
@click.option('--all-branches', is_flag=True, help='Show params from all branches')
def dvc_params(all_branches):
"""Show parameters from tracked files.
Example:
hae dvc params
"""
console.print("[bold yellow]⚙️ DVC Parameters[/bold yellow]\n")
cmd = ["dvc", "params", "show"]
if all_branches:
cmd.append("--all-branches")
show_standalone_alternatives(cmd, "Show DVC parameters")
result = subprocess.run(cmd)
sys.exit(result.returncode)
# ============================================================================
# OOD Detection Commands
# ============================================================================
@main.command('ood-detection')
@click.option('--output-dir', required=True, type=click.Path(),
help='Location for models/results')
@click.option('--model-name', required=True,
help='Name of model to analyze')
@click.option('--max-t', type=int, default=1000,
help='Maximum T to consider reconstructions from')
@click.option('--min-t', type=int, default=0,
help='Minimum T to consider reconstructions from')
@click.option('--t-skip', type=int, default=1,
help='Only use every n reconstructions')
@click.option('--seed', type=int, default=2,
help='Random seed to use')
def ood_detection(output_dir, model_name, max_t, min_t, t_skip, seed):
"""Perform OOD (Out-of-Distribution) detection analysis.
Analyzes reconstruction errors at different timesteps to detect OOD samples.
Example:
hae ood-detection --output-dir ./results --model-name decodiff_mvtec
hae ood-detection --output-dir ./results --model-name model_name --max-t 500
"""
console.print("[bold cyan]🔍 OOD Detection Analysis[/bold cyan]\n")
cmd = [
"python", "ood_detection.py",
"--output_dir", str(output_dir),
"--model_name", model_name,
"--max_t", str(max_t),
"--min_t", str(min_t),
"--t_skip", str(t_skip),
"--seed", str(seed)
]
show_standalone_alternatives(cmd, "Perform OOD detection analysis")
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@main.command('ood-results')
@click.option('--output-dir', required=True, type=click.Path(),
help='Location for models/results')
@click.option('--model-name', required=True,
help='Name of model to analyze')
@click.option('--out-data', default=None,
help='Custom OOD data CSV')
@click.option('--max-t', type=int, default=1000,
help='Maximum T to consider reconstructions from')
@click.option('--min-t', type=int, default=0,
help='Minimum T to consider reconstructions from')
@click.option('--t-skip', type=int, default=1,
help='Only use every n reconstructions')
@click.option('--seed', type=int, default=2,
help='Random seed to use')
def ood_results(output_dir, model_name, out_data, max_t, min_t, t_skip, seed):
"""Process and visualize OOD detection results.
Creates Excel reports, ROC curves, and PR curves for OOD analysis.
Example:
hae ood-results --output-dir ./results --model-name decodiff_mvtec
hae ood-results --output-dir ./results --model-name model_name --out-data custom.csv
"""
console.print("[bold magenta]📊 OOD Results Processing[/bold magenta]\n")
cmd = [
"python", "ood_results.py",
"--output_dir", str(output_dir),
"--model_name", model_name,
"--max_t", str(max_t),
"--min_t", str(min_t),
"--t_skip", str(t_skip),
"--seed", str(seed)
]
if out_data:
cmd.extend(["--out_data", out_data])
show_standalone_alternatives(cmd, "Process OOD detection results")
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@main.command('ood-graph')
@click.option('--in-csv', required=True, type=click.Path(exists=True),
help='Path to in-distribution CSV')
@click.option('--out-csv', required=True, type=click.Path(exists=True),
help='Path to out-of-distribution CSV')
@click.option('--save', type=click.Path(),
help='Path to save the figure (optional, shows otherwise)')
def ood_graph(in_csv, out_csv, save):
"""Draw LPIPS / MSE comparison grids for OOD analysis.
Creates 2x2 grid comparing in-distribution vs out-distribution samples
across LPIPS and MSE metrics.
Example:
hae ood-graph --in-csv results/in.csv --out-csv results/out.csv
hae ood-graph --in-csv results/in.csv --out-csv results/out.csv --save fig.png
"""
console.print("[bold green]📈 OOD Visualization[/bold green]\n")
cmd = [
"python", "ood_graph.py",
"--in_csv", str(in_csv),
"--out_csv", str(out_csv)
]
if save:
cmd.extend(["--save", str(save)])
show_standalone_alternatives(cmd, "Draw LPIPS / MSE comparison grids")
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
# ============================================================================
# Data Preprocessing Commands
# ============================================================================
@main.group()
def preprocess():
"""Data preprocessing utilities.
Commands for cropping, splitting, and annotating datasets.
"""
pass
@preprocess.command('crop')
@click.option('--input-dir', required=True, type=click.Path(exists=True),
help='Input directory containing images')
@click.option('--output-dir', required=True, type=click.Path(),
help='Output directory for cropped patches')
@click.option('--patch-size', required=True, type=int,
help='Patch side length (e.g., 128)')
@click.option('--patches-per-image', required=True, type=int,
help='Number of patches to extract per image')
@click.option('--is-grayscale', type=int, default=0,
help='1 for grayscale, 0 for color')
@click.option('--ext', default='png',
help='Output file extension')
@click.option('--seed', type=int,
help='Random seed for reproducibility')
def preprocess_crop(input_dir, output_dir, patch_size, patches_per_image, is_grayscale, ext, seed):
"""Extract random crops from images.
Creates smaller patches from larger images, useful for augmentation
and creating training datasets.
Example:
hae preprocess crop --input-dir ./data/raw --output-dir ./data/crops \\
--patch-size 128 --patches-per-image 2
hae preprocess crop --input-dir ./images --output-dir ./patches \\
--patch-size 256 --patches-per-image 5 --seed 42
"""
console.print("[bold cyan]✂️ Random Crop Extraction[/bold cyan]\n")
cmd = [
"python", "-m", "dioodmi.data.random_crop",
"--input_dir", str(input_dir),
"--output_dir", str(output_dir),
"--patch_size", str(patch_size),
"--patches_per_image", str(patches_per_image),
"--is_grayscale", str(is_grayscale),
"--ext", ext
]
if seed is not None:
cmd.extend(["--seed", str(seed)])
show_standalone_alternatives(cmd, "Extract random crops from images")
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@preprocess.command('split')
@click.argument('folder', type=click.Path(exists=True))
@click.option('--valid-ratio', type=float, default=0.2,
help='Ratio for validation set (default: 0.2)')
@click.option('--test-ratio', type=float, default=0.01,
help='Ratio for test set (default: 0.01)')
def preprocess_split(folder, valid_ratio, test_ratio):
"""Split dataset into train/valid/test CSV files.
Creates CSV files with image paths for train, validation, and test sets.
Example:
hae preprocess split ./data/images
hae preprocess split ./data/images --valid-ratio 0.1 --test-ratio 0.05
"""
console.print("[bold yellow]📂 Dataset Splitting[/bold yellow]\n")
cmd = [
"python", "-m", "dioodmi.data.split_data",
str(folder),
"--valid_ratio", str(valid_ratio),
"--test_ratio", str(test_ratio)
]
show_standalone_alternatives(cmd, "Split dataset into train/valid/test sets")
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@preprocess.command('annotate')
@click.option('--input-dir', '-i', required=True, type=click.Path(exists=True),
help='Directory containing images to annotate')
@click.option('--output-dir', '-o', required=True, type=click.Path(),
help='Directory to save annotation JSON files')
@click.option('--grid-size', '-g', type=int, default=128,
help='Grid size for patches (default: 128)')
@click.option('--extensions', '-e', multiple=True,
default=['.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif'],
help='Image file extensions to process')
@click.option('--single-image', '-s', type=click.Path(exists=True),
help='Process a single image instead of directory')
@click.option('--is-defective', is_flag=True, default=False,
help='Mark images as defective')
@click.option('--random-sample', '-r', type=int,
help='Randomly sample N files from directory')
@click.option('--random-seed', type=int,
help='Random seed for reproducible sampling')
def preprocess_annotate(input_dir, output_dir, grid_size, extensions, single_image,
is_defective, random_sample, random_seed):
"""Create annotation JSON files for anomaly detection.
Generates JSON annotations with defective patch information for images.
Example:
hae preprocess annotate -i ./images -o ./annotations
hae preprocess annotate -i ./images -o ./annotations --is-defective
hae preprocess annotate -i ./images -o ./annotations -r 100 --random-seed 42
"""
console.print("[bold magenta]📝 Annotation Creation[/bold magenta]\n")
cmd = ["python", "scripts/create_annotations.py"]
if single_image:
cmd.extend(["--single-image", str(single_image)])
else:
cmd.extend(["--input-dir", str(input_dir)])
cmd.extend([
"--output-dir", str(output_dir),
"--grid-size", str(grid_size)
])
if extensions and extensions != ('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif'):
for ext in extensions:
cmd.extend(["--extensions", ext])
if is_defective:
cmd.append("--is-defective")
if random_sample is not None:
cmd.extend(["--random-sample", str(random_sample)])
if random_seed is not None:
cmd.extend(["--random-seed", str(random_seed)])
show_standalone_alternatives(cmd, "Create annotation JSON files")
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@preprocess.group('synthetic-artifacts')
def preprocess_synthetic_artifacts():
"""Generate synthetic defect datasets (scratches, circular artifacts, etc.)."""
pass
@preprocess_synthetic_artifacts.command('scratch')
@click.option('--input-dir', required=True, type=click.Path(exists=True),
help='Directory containing source images (PNG expected)')
@click.option('--output-dir', required=True, type=click.Path(),
help='Directory where synthetic dataset will be written')
@click.option('--severities-json', required=True, type=click.Path(exists=True),
help='JSON file describing severity presets')
@click.option('--seed', type=int,
help='Optional RNG seed for reproducibility')
@click.option('--dry-run/--no-dry-run', default=False,
help='Show commands without executing')
def preprocess_synthetic_artifacts_scratch(input_dir, output_dir, severities_json, seed, dry_run):
"""Generate synthetic scratch defects from clean source images.
Wraps dioodmi's synthetic scratch generator to produce defect variants,
masks, and metadata according to severity presets.
Example:
hae preprocess synthetic-artifacts scratch --input-dir ./clean --output-dir ./scratch_out \\
--severities-json severities_circular_carpet.json
"""
console.print("[bold cyan]🪛 Generating Synthetic Scratches[/bold cyan]\n")
cmd = [
"python", "-m", "dioodmi.data.synthetic_scratch",
"--input_dir", str(input_dir),
"--output_dir", str(output_dir),
"--severities_json", str(severities_json)
]
if seed is not None:
cmd.extend(["--seed", str(seed)])
show_standalone_alternatives(cmd, "Generate synthetic scratch dataset")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@preprocess_synthetic_artifacts.command('circular')
@click.option('--input-dir', required=True, type=click.Path(exists=True),
help='Directory containing source images (PNG expected)')
@click.option('--output-dir', required=True, type=click.Path(),
help='Directory where synthetic dataset will be written')
@click.option('--severities-json', required=True, type=click.Path(exists=True),
help='JSON file describing severity presets')
@click.option('--seed', type=int,
help='Optional RNG seed for reproducibility')
@click.option('--mvtec-format/--no-mvtec-format', default=False,
help='Output using MVTec AD folder structure')
@click.option('--dry-run/--no-dry-run', default=False,
help='Show commands without executing')
def preprocess_synthetic_artifacts_circular(input_dir, output_dir, severities_json, seed, mvtec_format, dry_run):
"""Generate synthetic circular artifacts from clean source images.
Wraps dioodmi's circular artifact generator to create severity-based
datasets with optional MVTec-style folder layout.
Example:
hae preprocess synthetic-artifacts circular --input-dir ./clean --output-dir ./circ_out \\
--severities-json severities_circular_carpet.json --mvtec-format
"""
console.print("[bold cyan]🟢 Generating Synthetic Circular Artifacts[/bold cyan]\n")
cmd = [
"python", "-m", "dioodmi.data.synthetic_circular_artifact",
"--input_dir", str(input_dir),
"--output_dir", str(output_dir),
"--severities_json", str(severities_json)
]
if seed is not None:
cmd.extend(["--seed", str(seed)])
if mvtec_format:
cmd.append("--mvtec_format")
show_standalone_alternatives(cmd, "Generate synthetic circular artifact dataset")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@main.group()
def data():
"""Data manipulation and analysis utilities.
Commands for processing, transforming, and analyzing anomaly detection datasets.
"""
pass
@data.command('gen-masks')
@click.option('--latent-pad', default=None, help='Padding tuple for latent space')
@click.option('--ddpm-checkpoint-epoch', default=None, help='Specific checkpoint epoch')
@click.option('--reconstruct-ids', default=None, help='File with test image path')
@click.option('--output-dir', default='./results', type=click.Path(), help='Output directory')
@click.option('--model-name', default='oneimage', help='Model run name')
@click.option('--augmentation', type=int, default=0, help='Use augmentation')
@click.option('--num-workers', type=int, default=0, help='Number of data loading workers')
@click.option('--cache-data', type=int, default=0, help='Cache dataset in memory')
@click.option('--drop-last', type=int, default=0, help='Drop last incomplete batch')
@click.option('--first-n', default=None, help='Process only first N samples')
@click.option('--is-grayscale', type=int, default=0, help='1 for grayscale images')
@click.option('--spatial-dimension', type=int, default=2, help='2D or 3D data')
@click.option('--image-size', type=int, default=None, help='Target image size')
@click.option('--image-roi', default=None, help='Region of interest')
@click.option('--vqvae-checkpoint', default=None, help='VQVAE checkpoint path')
@click.option('--model-type', default='small', help='Model size variant')
@click.option('--scheduler', default='ddpm', help='Diffusion scheduler')
@click.option('--prediction-type', default='epsilon', help='Prediction type')
@click.option('--beta-schedule', default='linear_beta', help='Beta schedule')
@click.option('--beta-start', type=float, default=1e-4, help='Beta start value')
@click.option('--beta-end', type=float, default=2e-2, help='Beta end value')
@click.option('--b-scale', type=float, default=1.0, help='Beta scale factor')
@click.option('--snr-shift', type=float, default=1.0, help='SNR shift value')
@click.option('--simplex-noise', type=int, default=0, help='Use simplex noise')
@click.option('--show-diff', type=int, default=0, help='Show difference images')
@click.option('--batch-size', type=int, default=10, help='Batch size')
@click.option('--num-train-timesteps', type=int, default=1000, help='Number of diffusion steps')
@click.option('--t-destruct', type=int, default=200, help='Destruction timestep')
@click.option('--t-inpaint', type=int, default=50, help='Inpainting timestep')
@click.option('--resample-steps', type=int, default=4, help='Resampling steps')
@click.option('--binary-threshold', type=int, default=50, help='Binary mask threshold')
@click.option('--detect-threshold', type=float, default=10, help='Detection threshold')
@click.option('--anomalymap-threshold', type=int, default=40, help='Anomaly map threshold')
@click.option('--masking-threshold', type=int, default=-1, help='Masking threshold')
@click.option('--mask-percentile', type=float, default=95, help='Mask percentile')
@click.option('--kernel-size', type=int, default=3, help='Morphological kernel size')
@click.option('--brightness-scale', type=int, default=0, help='Brightness scaling')
@click.option('--mask-color', default='black', help='Mask color')
@click.option('--mask-anomaly-map', type=int, default=0, help='Generate masked anomaly maps')
@click.option('--detect-anomaly-map', type=int, default=0, help='Generate detection maps')
@click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing')
def data_gen_masks(**kwargs):
"""Generate binary masks from anomaly maps.
Uses trained diffusion models to reconstruct images and generate
binary masks highlighting anomalous regions.
Examples:
hae data gen-masks --model-name bottle_model --output-dir ./masks
hae data gen-masks --reconstruct-ids test.txt --binary-threshold 40
"""
console.print("[cyan]🎭 Generating Binary Masks[/cyan]\n")
# Build command - delegate to original script
project_root = Path(__file__).parent.parent.parent.parent.parent
script_path = project_root / "gen_mask_images.py"
cmd = ["python", str(script_path)]
# Convert kwargs to command-line arguments
for key, value in kwargs.items():
if key == 'dry_run':
continue
if value is not None and value != '' and value is not False:
# Convert underscore to dash
arg_name = key.replace('_', '-')
if isinstance(value, bool):
if value:
cmd.append(f"--{arg_name}")
else:
cmd.extend([f"--{arg_name}", str(value)])
show_standalone_alternatives(cmd, "Generate masks from anomaly maps")
if kwargs.get('dry_run'):
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
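# Illustrative sketch (not part of the CLI): the kwargs-to-argv conversion
# used in data_gen_masks above, factored into a standalone helper. The name
# kwargs_to_cli_args is hypothetical; the logic mirrors the loop exactly.
def kwargs_to_cli_args(kwargs: dict, skip: tuple = ('dry_run',)) -> list:
    """Convert Click kwargs to argv tokens (underscores become dashes)."""
    argv = []
    for key, value in kwargs.items():
        if key in skip or value is None or value == '' or value is False:
            continue
        arg_name = key.replace('_', '-')
        if isinstance(value, bool):
            argv.append(f"--{arg_name}")  # True flags become bare switches
        else:
            argv.extend([f"--{arg_name}", str(value)])
    return argv
# e.g. kwargs_to_cli_args({'model_name': 'bottle', 'batch_size': 10, 'dry_run': True})
# -> ['--model-name', 'bottle', '--batch-size', '10']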
@data.command('save', context_settings={"ignore_unknown_options": True})
@click.argument('args', nargs=-1, type=click.UNPROCESSED)
@click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing')
def data_save(args, dry_run):
"""Save reconstructed images from diffusion model.
This command wraps save_images.py with all its original arguments.
Run with --dry-run to preview the equivalent save_images.py command.
Examples:
hae data save --model-name bottle_model --output-dir ./output
hae data save --dry-run # Show equivalent command
"""
console.print("[cyan]💾 Saving Reconstructed Images[/cyan]\n")
project_root = Path(__file__).parent.parent.parent.parent.parent
script_path = project_root / "save_images.py"
cmd = ["python", str(script_path)] + list(args)
show_standalone_alternatives(cmd, "Save reconstructed images")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
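# Note on the pass-through pattern used by data_save and the other script
# wrappers: Click rejects unknown options unless the command opts in via
# context_settings. A minimal sketch of the required shape:
#
#   @click.command(context_settings={"ignore_unknown_options": True})
#   @click.argument('args', nargs=-1, type=click.UNPROCESSED)
#   def wrapper(args):
#       subprocess.run(["python", "wrapped_script.py", *args])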
@data.command('save-pairs', context_settings={"ignore_unknown_options": True})
@click.argument('args', nargs=-1, type=click.UNPROCESSED)
@click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing')
def data_save_pairs(args, dry_run):
"""Save pairs of original and masked images.
This command wraps save_masked_pairs.py with all its original arguments.
Examples:
hae data save-pairs --src-dir ./images --dst ./pairs
hae data save-pairs --dry-run # Show equivalent command
"""
console.print("[cyan]🖼️ Saving Original/Masked Image Pairs[/cyan]\n")
project_root = Path(__file__).parent.parent.parent.parent.parent
script_path = project_root / "save_masked_pairs.py"
cmd = ["python", str(script_path)] + list(args)
show_standalone_alternatives(cmd, "Save original/masked image pairs")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@data.command('compare-contours')
@click.option('--gt', required=True, type=click.Path(exists=True), help='Ground-truth JSONL file')
@click.option('--pred', required=True, type=click.Path(exists=True), help='Prediction JSONL file')
@click.option('--pixth', type=int, default=4, help='Min pixel intersection threshold')
@click.option('--out-dir', required=True, type=click.Path(), help='Output directory')
@click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing')
def data_compare_contours(gt, pred, pixth, out_dir, dry_run):
"""Compare predicted contours against ground truth.
Evaluates contour prediction accuracy by computing pixel-wise
intersection between ground truth and predicted anomaly contours.
Examples:
hae data compare-contours --gt truth.jsonl --pred pred.jsonl --out-dir ./eval
hae data compare-contours --gt truth.jsonl --pred pred.jsonl --pixth 10 --out-dir ./eval
"""
console.print("[cyan]📊 Comparing Contours[/cyan]\n")
project_root = Path(__file__).parent.parent.parent.parent.parent
script_path = project_root / "compare_contours.py"
cmd = [
"python", str(script_path),
"--gt", gt,
"--pred", pred,
"--pixth", str(pixth),
"--out_dir", out_dir
]
show_standalone_alternatives(cmd, "Compare predicted vs ground-truth contours")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@data.command('copy-from-csv')
@click.argument('csv_file', type=click.Path(exists=True))
@click.argument('output_dir', type=click.Path())
@click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing')
def data_copy_from_csv(csv_file, output_dir, dry_run):
"""Copy images listed in CSV to output directory.
Reads a CSV file containing image paths and copies all listed
images to the specified output directory.
Examples:
hae data copy-from-csv data/test.csv ./test_images
hae data copy-from-csv data/train.csv ./train_images --dry-run
"""
console.print("[cyan]📁 Copying Images from CSV[/cyan]\n")
project_root = Path(__file__).parent.parent.parent.parent.parent
script_path = project_root / "copy-images-in-csv.sh"
cmd = ["bash", str(script_path), csv_file, output_dir]
show_standalone_alternatives(cmd, "Copy images listed in CSV")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@main.group()
def model():
"""Model and checkpoint management utilities.
Commands for inspecting checkpoints, managing model files, and checkpoint utilities.
"""
pass
@model.command('info')
@click.argument('checkpoint_path', type=click.Path(exists=True))
@click.option('--json', 'output_json', is_flag=True, help='Output in JSON format')
@click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing')
def model_info(checkpoint_path, output_json, dry_run):
"""Display checkpoint information (epoch, loss, parameters).
Inspects a checkpoint file and displays its metadata including
epoch number, loss values, parameter count, and file size.
Examples:
hae model info ./checkpoints/best.pt
hae model info ./checkpoints/best.pt --json
"""
console.print("[cyan]📦 Checkpoint Information[/cyan]\n")
project_root = Path(__file__).parent.parent.parent.parent.parent
script_path = project_root / "checkpoint_info.py"
cmd = ["python", str(script_path), checkpoint_path]
if output_json:
cmd.append("--json")
show_standalone_alternatives(cmd, "Inspect checkpoint file")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@model.command('utils')
@click.argument('path', type=click.Path(exists=True))
@click.option('--list', 'list_mode', is_flag=True, help='List all checkpoints in directory')
@click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing')
def model_utils(path, list_mode, dry_run):
"""Checkpoint inspection utilities.
Inspect a single checkpoint file or list all checkpoints in a directory
with their epoch, loss, and size information.
Examples:
hae model utils ./checkpoints/model.pt
hae model utils ./checkpoints --list
"""
console.print("[cyan]🔧 Checkpoint Utilities[/cyan]\n")
project_root = Path(__file__).parent.parent.parent.parent.parent
script_path = project_root / "checkpoint_utils.py"
cmd = ["python", str(script_path), path]
if list_mode:
cmd.append("--list")
show_standalone_alternatives(cmd, "Inspect checkpoint utilities")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@model.command('save-ddpm', context_settings={"ignore_unknown_options": True})
@click.argument('args', nargs=-1, type=click.UNPROCESSED)
@click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing')
def model_save_ddpm(args, dry_run):
"""Save original DDPM model format.
Converts and saves DDPM models in original format.
This command wraps orig_ddpm_save.py with all its original arguments.
Examples:
hae model save-ddpm --model-name bottle --output-dir ./models
hae model save-ddpm --dry-run # Show equivalent command
"""
console.print("[cyan]💾 Saving DDPM Model[/cyan]\n")
project_root = Path(__file__).parent.parent.parent.parent.parent
script_path = project_root / "orig_ddpm_save.py"
cmd = ["python", str(script_path)] + list(args)
show_standalone_alternatives(cmd, "Save DDPM model in original format")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@model.command('position-decoder', context_settings={"ignore_unknown_options": True})
@click.argument('args', nargs=-1, type=click.UNPROCESSED)
@click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing')
def model_position_decoder(args, dry_run):
"""Decode position-encoded patches.
Utility for decoding position information from position-encoded image patches.
Provides functions for extracting crop coordinates and image IDs from patches.
Examples:
hae model position-decoder --help # Show available functions
hae model position-decoder --dry-run
"""
console.print("[cyan]🔍 Position Decoder Utility[/cyan]\n")
project_root = Path(__file__).parent.parent.parent.parent.parent
script_path = project_root / "scripts" / "position_decoder.py"
cmd = ["python", str(script_path)] + list(args)
show_standalone_alternatives(cmd, "Decode position-encoded patches")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print("[bold]Running:[/bold]", " ".join(cmd))
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@main.group()
def system():
"""System diagnostics and health checks.
Commands for checking system configuration, dependencies, and environment setup.
"""
pass
@system.command('check')
@click.option('--verbose', '-v', is_flag=True, help='Show detailed information')
@click.option('--json', 'output_json', is_flag=True, help='Output in JSON format')
def system_check(verbose, output_json):
"""Comprehensive system health check.
Checks Python environment, CUDA/GPU setup, dependencies, disk space,
and DiOodMi project configuration.
Examples:
hae system check
hae system check --verbose
hae system check --json
"""
import sys
import platform
import shutil
import json as json_lib
console.print("[bold cyan]🔍 DiOodMi System Health Check[/bold cyan]\n")
checks = {}
all_pass = True
# Python Environment
console.print("[bold]Python Environment:[/bold]")
py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
py_pass = sys.version_info >= (3, 10)
checks['python'] = {'version': py_version, 'pass': py_pass}
console.print(f" {'✓' if py_pass else '✗'} Python {py_version} {'' if py_pass else '(requires >=3.10)'}")
if verbose:
console.print(f" Executable: {sys.executable}")
console.print(f" Platform: {platform.platform()}")
# PyTorch and CUDA
console.print("\n[bold]PyTorch & CUDA:[/bold]")
try:
import torch
torch_version = torch.__version__
cuda_available = torch.cuda.is_available()
cuda_version = torch.version.cuda if cuda_available else "N/A"
num_gpus = torch.cuda.device_count() if cuda_available else 0
checks['pytorch'] = {
'version': torch_version,
'cuda_available': cuda_available,
'cuda_version': cuda_version,
'num_gpus': num_gpus,
'pass': True
}
console.print(f" ✓ PyTorch {torch_version}")
console.print(f" {'✓' if cuda_available else '✗'} CUDA available: {cuda_available}")
if cuda_available:
console.print(f" ✓ CUDA version: {cuda_version}")
console.print(f" ✓ GPUs detected: {num_gpus}")
if verbose:
for i in range(num_gpus):
name = torch.cuda.get_device_name(i)
props = torch.cuda.get_device_properties(i)
console.print(f" GPU {i}: {name}")
console.print(f" Compute: {props.major}.{props.minor}, Memory: {props.total_memory / 1024**3:.1f}GB")
except ImportError:
checks['pytorch'] = {'pass': False, 'error': 'PyTorch not installed'}
console.print(" ✗ PyTorch not installed")
all_pass = False
# Key Dependencies
console.print("\n[bold]Key Dependencies:[/bold]")
deps = [
('monai', 'MONAI'),
('diffusers', 'Diffusers'),
('transformers', 'Transformers'),
('click', 'Click'),
('rich', 'Rich'),
]
checks['dependencies'] = {}
for module, name in deps:
try:
mod = __import__(module)
version = getattr(mod, '__version__', 'unknown')
console.print(f" ✓ {name} {version}")
checks['dependencies'][module] = {'version': version, 'pass': True}
except ImportError:
console.print(f" ✗ {name} not installed")
checks['dependencies'][module] = {'pass': False, 'error': 'not installed'}
all_pass = False
# Disk Space
console.print("\n[bold]Disk Space:[/bold]")
project_root = Path(__file__).parent.parent.parent.parent.parent
total, used, free = shutil.disk_usage(project_root)
free_gb = free / (1024**3)
disk_pass = free_gb > 10 # Warn if less than 10GB free
checks['disk'] = {
'free_gb': round(free_gb, 2),
'total_gb': round(total / (1024**3), 2),
'pass': disk_pass
}
console.print(f" {'✓' if disk_pass else '⚠'} Free space: {free_gb:.1f} GB {'' if disk_pass else '(low disk space)'}")
if verbose:
console.print(f" Total: {total / (1024**3):.1f} GB")
console.print(f" Used: {used / (1024**3):.1f} GB")
# Project Structure
console.print("\n[bold]Project Structure:[/bold]")
critical_paths = [
('project/', 'Core codebase'),
('workspaces/hae-py/', 'HAE automation'),
('tests/', 'Test suite'),
('.git/', 'Git repository'),
]
checks['project_structure'] = {}
for path, desc in critical_paths:
full_path = project_root / path
exists = full_path.exists()
console.print(f" {'✓' if exists else '✗'} {desc}: {path}")
checks['project_structure'][path] = {'exists': exists, 'pass': exists}
if not exists:
all_pass = False
# Summary
console.print()
if all_pass:
console.print("[bold green]✓ All checks passed![/bold green]")
else:
console.print("[bold yellow]⚠ Some checks failed - see details above[/bold yellow]")
checks['summary'] = {'all_pass': all_pass}
if output_json:
console.print("\n[bold]JSON Output:[/bold]")
print(json_lib.dumps(checks, indent=2))
sys.exit(0 if all_pass else 1)
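# Shape of the --json payload assembled by system_check (values here are
# illustrative; the keys match the checks dict built above):
# {
#   "python": {"version": "3.11.9", "pass": true},
#   "pytorch": {"version": "2.3.0", "cuda_available": true, "cuda_version": "12.1", "num_gpus": 1, "pass": true},
#   "dependencies": {"monai": {"version": "1.3.0", "pass": true}, ...},
#   "disk": {"free_gb": 123.45, "total_gb": 512.0, "pass": true},
#   "project_structure": {"project/": {"exists": true, "pass": true}, ...},
#   "summary": {"all_pass": true}
# }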
@main.group()
def verify():
"""Regression testing and baseline management.
Commands for managing golden reference baselines, comparing evaluation
outputs, and running regression tests to detect when code changes affect results.
"""
pass
@verify.command('create-baseline')
@click.option('--version', required=True, help='Baseline version (e.g., v1.0, v1.1)')
@click.option('--method', type=click.Choice(['method1', 'method2', 'both']),
default='both', help='Which evaluation method(s) to run')
@click.option('--description', default='', help='Description of this baseline')
@click.option('--dataset', default='mvtec', help='Dataset name')
@click.option('--data-dir', default='./mvtec-dataset/', help='Data directory path')
@click.option('--object-class', default='wood', help='Object class')
@click.option('--anomaly-class', default='hole', help='Anomaly class')
@click.option('--model-size', default='UNet_L', help='Model size')
@click.option('--patch-size', type=int, default=1024, help='Patch size')
@click.option('--pretrained', default='./results/decodiff_mvtec/checkpoints/MVTEC-AD-model.pt',
help='Pretrained model path')
@click.option('--annotation-dir', default='./annotations_wood_test_hole',
help='Annotation directory')
@click.option('--reverse-steps', type=int, default=5, help='Reverse diffusion steps')
@click.option('--batch-size', type=int, default=1, help='Batch size')
@click.option('--vae-type', default='ema', help='VAE type')
@click.option('--image-size', type=int, default=1024, help='Image size')
@click.option('--center-size', type=int, default=1024, help='Center crop size')
@click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing')
def verify_create_baseline(version, method, description, dataset, data_dir, object_class,
anomaly_class, model_size, patch_size, pretrained, annotation_dir,
reverse_steps, batch_size, vae_type, image_size, center_size,
dry_run):
"""Generate golden reference baseline for regression testing.
This command runs evaluation methods (decodiff_evaluator and/or
decodiff_evaluate_process), collects their outputs (NPY, JSON, confusion
matrices), creates a versioned baseline directory, and prepares it for DVC tracking.
Examples:
hae verify create-baseline --version v1.0 --method both
hae verify create-baseline --version v1.1 --method method1
hae verify create-baseline --version v1.1 --object-class bottle --anomaly-class crack
"""
console.print("[cyan]📦 Generating Golden Reference Baseline[/cyan]\n")
project_root = Path(__file__).parent.parent.parent.parent.parent
script_path = project_root / "scripts" / "generate_baseline.py"
cmd = [
"python", str(script_path),
"--version", version,
"--method", method,
"--description", description,
"--dataset", dataset,
"--data-dir", data_dir,
"--object-class", object_class,
"--anomaly-class", anomaly_class,
"--model-size", model_size,
"--patch-size", str(patch_size),
"--pretrained", pretrained,
"--annotation-dir", annotation_dir,
"--reverse-steps", str(reverse_steps),
"--batch-size", str(batch_size),
"--vae-type", vae_type,
"--image-size", str(image_size),
"--center-size", str(center_size)
]
show_standalone_alternatives(cmd, "Generate golden reference baseline")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
result = subprocess.run(cmd)
sys.exit(result.returncode)
@verify.command('approve-baseline')
@click.option('--version', required=True, help='Baseline version to approve (e.g., v1.1)')
@click.option('--auto-approve', is_flag=True, help='Skip approval confirmation')
@click.option('--skip-dvc', is_flag=True, help='Skip DVC tracking (for testing)')
@click.option('--skip-git', is_flag=True, help='Skip git commit (for testing)')
@click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing')
def verify_approve_baseline(version, auto_approve, skip_dvc, skip_git, dry_run):
"""Approve and finalize a new golden reference baseline.
This command compares the candidate baseline with the current baseline,
shows differences and statistics, prompts for approval, updates the
'current' symlink, tracks the baseline with DVC, and commits changes to git.
Examples:
hae verify approve-baseline --version v1.1
hae verify approve-baseline --version v1.1 --auto-approve
hae verify approve-baseline --version v1.1 --skip-dvc --skip-git
"""
console.print("[cyan]✅ Approving Golden Reference Baseline[/cyan]\n")
project_root = Path(__file__).parent.parent.parent.parent.parent
script_path = project_root / "scripts" / "approve_baseline.py"
cmd = ["python", str(script_path), "--version", version]
if auto_approve:
cmd.append("--auto-approve")
if skip_dvc:
cmd.append("--skip-dvc")
if skip_git:
cmd.append("--skip-git")
show_standalone_alternatives(cmd, "Approve and finalize baseline")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
result = subprocess.run(cmd)
sys.exit(result.returncode)
@verify.command('compare')
@click.argument('baseline_dir', type=click.Path(exists=True))
@click.argument('candidate_dir', type=click.Path(exists=True))
@click.option('--method', type=click.Choice(['method1', 'method2', 'both']),
default='both', help='Which method outputs to compare')
@click.option('--verbose', '-v', is_flag=True, help='Show detailed comparison results')
@click.option('--json', 'output_json', is_flag=True, help='Output results in JSON format')
@click.option('--tolerance', type=float, help='Custom tolerance for numerical comparisons')
def verify_compare(baseline_dir, candidate_dir, method, verbose, output_json, tolerance):
"""Compare two baseline directories.
Compares outputs (NPY arrays, JSON metrics, confusion matrices) between
two baseline directories and shows detailed differences. Useful for
reviewing changes before approving a new baseline.
Examples:
hae verify compare tests/golden/v1.0 tests/golden/v1.1
hae verify compare tests/golden/v1.0 tests/golden/v1.1 --method method1 --verbose
hae verify compare tests/golden/v1.0 tests/golden/v1.1 --json
"""
console.print("[cyan]🔍 Comparing Baselines[/cyan]\n")
baseline_path = Path(baseline_dir)
candidate_path = Path(candidate_dir)
console.print(f"[bold]Baseline:[/bold] {baseline_path}")
console.print(f"[bold]Candidate:[/bold] {candidate_path}\n")
# Import comparison utilities
project_root = Path(__file__).parent.parent.parent.parent.parent
sys.path.insert(0, str(project_root))
try:
from project.utils.comparison import (
compare_directories,
summarize_comparison_results,
ToleranceConfig
)
except ImportError as e:
console.print(f"[red]✗ Failed to import comparison utilities: {e}[/red]")
sys.exit(1)
# Configure tolerance
if tolerance is not None:
tolerance_config = ToleranceConfig(
npy_atol=tolerance,
npy_rtol=tolerance,
metrics_atol=tolerance,
metrics_rtol=tolerance
)
else:
tolerance_config = ToleranceConfig()
all_pass = True
# Compare Method 1
if method in ['method1', 'both']:
console.print("[bold]Method 1 Comparison:[/bold]")
console.print("-" * 80)
method1_baseline = baseline_path / "method1"
method1_candidate = candidate_path / "method1"
if method1_baseline.exists() and method1_candidate.exists():
results = compare_directories(
method1_baseline,
method1_candidate,
tolerance_config
)
summary = summarize_comparison_results(results, verbose=verbose)
console.print(summary)
# Check if any comparison failed
for result in results:
if not result.passed:
all_pass = False
else:
console.print("[yellow]⚠ Method 1 directories not found in one or both baselines[/yellow]")
all_pass = False
console.print()
# Compare Method 2
if method in ['method2', 'both']:
console.print("[bold]Method 2 Comparison:[/bold]")
console.print("-" * 80)
method2_baseline = baseline_path / "method2"
method2_candidate = candidate_path / "method2"
if method2_baseline.exists() and method2_candidate.exists():
results = compare_directories(
method2_baseline,
method2_candidate,
tolerance_config
)
summary = summarize_comparison_results(results, verbose=verbose)
console.print(summary)
# Check if any comparison failed
for result in results:
if not result.passed:
all_pass = False
else:
console.print("[yellow]⚠ Method 2 directories not found in one or both baselines[/yellow]")
all_pass = False
console.print()
# Summary
if all_pass:
console.print("[bold green]✓ All comparisons passed![/bold green]")
sys.exit(0)
else:
console.print("[bold yellow]⚠ Some comparisons failed - see details above[/bold yellow]")
sys.exit(1)
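# Programmatic use of the comparison utilities imported above (a sketch,
# assuming the same API surface this command relies on):
#   from project.utils.comparison import compare_directories, ToleranceConfig
#   cfg = ToleranceConfig(npy_atol=1e-6, npy_rtol=1e-6,
#                         metrics_atol=1e-6, metrics_rtol=1e-6)
#   results = compare_directories(Path("tests/golden/v1.0/method1"),
#                                 Path("tests/golden/v1.1/method1"), cfg)
#   all_pass = all(r.passed for r in results)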
@verify.command('list')
@click.option('--verbose', '-v', is_flag=True, help='Show detailed information')
@click.option('--json', 'output_json', is_flag=True, help='Output in JSON format')
def verify_list(verbose, output_json):
"""List available golden reference baselines.
Shows all versioned baselines in tests/golden/ directory with metadata
information (version, date, description, git commit).
Examples:
hae verify list
hae verify list --verbose
hae verify list --json
"""
console.print("[cyan]📋 Available Golden Reference Baselines[/cyan]\n")
project_root = Path(__file__).parent.parent.parent.parent.parent
golden_dir = project_root / "tests" / "golden"
if not golden_dir.exists():
console.print("[yellow]⚠ Golden references directory not found[/yellow]")
console.print(f"Expected: {golden_dir}")
return
# Find version directories
version_dirs = []
for item in golden_dir.iterdir():
if item.is_dir() and item.name.startswith('v'):
version_dirs.append(item)
if not version_dirs:
console.print("[yellow]⚠ No baselines found[/yellow]")
console.print(f"\nGenerate a baseline with: hae verify create-baseline --version v1.0")
return
# Sort by version number
def parse_version(name):
try:
parts = name[1:].split('.')
return tuple(int(p) for p in parts)
except ValueError:
return (0, 0, 0)
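# e.g. parse_version("v1.10") -> (1, 10) and parse_version("v1.2") -> (1, 2),
# so the numeric tuple sort below ranks v1.10 above v1.2, unlike a plain
# string sort, which would order "v1.10" before "v1.2".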
version_dirs.sort(key=lambda p: parse_version(p.name), reverse=True)
# Find current baseline
current_link = golden_dir / "current"
current_version = None
if current_link.exists():
if current_link.is_symlink():
current_version = current_link.resolve().name
elif current_link.is_dir():
current_version = "current"
# Display baselines
import json as json_lib
baselines = []
for version_dir in version_dirs:
metadata_path = version_dir / "metadata.json"
baseline_info = {
'version': version_dir.name,
'current': version_dir.name == current_version,
'path': str(version_dir)
}
if metadata_path.exists():
try:
with open(metadata_path, 'r') as f:
metadata = json_lib.load(f)
baseline_info['date'] = metadata.get('date', 'unknown')
baseline_info['description'] = metadata.get('description', '')
baseline_info['git_commit'] = metadata.get('git_commit', 'unknown')[:8]
baseline_info['git_branch'] = metadata.get('git_branch', 'unknown')
except Exception as e:
baseline_info['error'] = str(e)
baselines.append(baseline_info)
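# Expected shape of each baseline's metadata.json (illustrative values;
# the keys match the .get() calls above):
#   {"date": "2025-01-15T10:30:00", "description": "initial baseline",
#    "git_commit": "a1b2c3d4e5f60708", "git_branch": "main"}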
# Output
if output_json:
console.print(json_lib.dumps(baselines, indent=2))
else:
table = Table(title="Golden Reference Baselines")
table.add_column("Version", style="cyan")
table.add_column("Current", style="green")
table.add_column("Date", style="yellow")
table.add_column("Description")
table.add_column("Git Commit", style="magenta")
for baseline in baselines:
current_marker = "✓" if baseline['current'] else ""
date = baseline.get('date', 'unknown')
if date != 'unknown' and 'T' in date:
date = date.split('T')[0] # Show only date part
table.add_row(
baseline['version'],
current_marker,
date,
baseline.get('description', ''),
baseline.get('git_commit', 'unknown')
)
console.print(table)
if verbose:
console.print("\n[bold]Details:[/bold]")
for baseline in baselines:
console.print(f"\n[cyan]{baseline['version']}[/cyan]")
console.print(f" Path: {baseline['path']}")
if 'git_branch' in baseline:
console.print(f" Branch: {baseline['git_branch']}")
if 'error' in baseline:
console.print(f" [red]Error: {baseline['error']}[/red]")
console.print(f"\n[dim]Use 'hae verify compare' to compare baselines[/dim]")
@verify.command('run')
@click.option('--baseline', required=True, help='Baseline version to test against (e.g., v1.0)')
@click.option('--threshold', type=float, default=1e-5,
help='MAE threshold for pass/fail (default: 1e-5)')
@click.option('--full', is_flag=True,
help='Run full comparison including cross-method validation')
@click.option('--visualize/--no-visualize', default=True,
help='Generate visualizations (default: enabled)')
@click.option('--dry-run/--no-dry-run', default=False,
help='Show command without executing')
@click.option('--use-script', is_flag=True,
help='Use legacy run_regression_test.py script (fallback mode)')
def verify_run(baseline, threshold, full, visualize, dry_run, use_script):
"""Run regression test suite comparing against golden baseline.
By default, runs both decodiff_evaluator and anomaly_detector vs baseline (quick mode).
Use --full for comprehensive testing including cross-method validation.
Uses pipeline-based experiment system for declarative, reproducible testing.
Test Flow (Quick Mode - Default):
[1/6] Create test environment
[2/6] Run decodiff_evaluator evaluation (hae eval)
[3/6] Run anomaly_detector evaluation (hae anomaly-detect)
[4/6] Compare decodiff_evaluator vs baseline
[5/6] Compare anomaly_detector vs baseline
[6/6] Generate quick mode summary
Test Flow (Full Mode - --full):
[1/7] Create test environment
[2/7] Run decodiff_evaluator evaluation
[3/7] Run anomaly_detector evaluation
[4/7] Cross-validate decodiff_evaluator vs anomaly_detector
[5/7] Compare decodiff_evaluator vs baseline
[6/7] Compare anomaly_detector vs baseline
[7/7] Generate full mode summary
Examples:
# Quick test (both evaluators vs baseline, ~2-3 minutes)
hae verify run --baseline v1.0
# Full test (includes cross-validation, ~5-10 minutes)
hae verify run --baseline v1.0 --full
# Custom threshold
hae verify run --baseline v1.0 --threshold 1e-6
# Dry run to see what will execute
hae verify run --baseline v1.0 --dry-run
# Use legacy script (fallback)
hae verify run --baseline v1.0 --use-script
"""
console.print("[cyan]🧪 Running Regression Test Suite[/cyan]\n")
project_root = Path(__file__).parent.parent.parent.parent.parent
# Determine mode and experiment name
if use_script:
# Legacy mode: use run_regression_test.py script
tests_dir = project_root / "workspaces" / "hae-py" / "tests"
script_path = tests_dir / "run_regression_test.py"
if not script_path.exists():
console.print(f"[red]✗ Regression test script not found: {script_path}[/red]")
sys.exit(1)
# Build command
cmd = [
"py", "-3.11", str(script_path),
"--baseline", baseline,
"--threshold", str(threshold)
]
if full:
cmd.append("--full")
if visualize:
cmd.append("--visualize")
# Show standalone command
show_standalone_alternatives(cmd, "Run regression test suite (legacy script)")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
# Execute regression test
console.print(f"[bold]Mode:[/bold] {'Full Comparison' if full else 'Quick Test'} (legacy script)")
console.print(f"[bold]Baseline:[/bold] {baseline}")
console.print(f"[bold]Threshold:[/bold] {threshold:.0e}")
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
else:
# New mode: use pipeline-based experiments
experiment_name = 'regression-test-full' if full else 'regression-test-quick'
# Build hae experiment run command
cmd = [
"hae", "experiment", "run", experiment_name,
"--variables", f"baseline_dir=./fixtures/golden/{baseline}",
"--variables", f"threshold={threshold}",
"--variables", f"visualize={'true' if visualize else 'false'}"
]
# Show standalone command
show_standalone_alternatives(cmd, "Run regression test suite (pipeline-based)")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
console.print(f"\n[bold]Would run experiment:[/bold] {experiment_name}")
console.print(f"[bold]Variables:[/bold]")
console.print(f" baseline_version: {baseline}")
console.print(f" threshold: {threshold}")
console.print(f" visualize: {visualize}")
return
# Execute experiment
console.print(f"[bold]Mode:[/bold] {'Full Comparison' if full else 'Quick Test'} (pipeline-based)")
console.print(f"[bold]Experiment:[/bold] {experiment_name}")
console.print(f"[bold]Baseline:[/bold] {baseline}")
console.print(f"[bold]Threshold:[/bold] {threshold:.0e}")
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@verify.command('compare-fundamentals')
@click.option('--test-dir', required=True, help='Test output directory')
@click.option('--baseline-dir', required=True, help='Baseline reference directory')
@click.option('--output-dir', required=True, help='Comparison results output directory')
@click.option('--threshold', type=float, default=1e-5, help='MAE threshold for pass/fail')
@click.option('--visualize/--no-visualize', default=True, help='Generate visualization plots')
@click.option('--dry-run/--no-dry-run', default=False, help='Preview command without executing')
def verify_compare_fundamentals(test_dir, baseline_dir, output_dir, threshold, visualize, dry_run):
"""Compare fundamental NPY variables against baseline.
Compares x0, encoded, image_samples, and latent_samples between
test outputs and golden baseline, computing MAE, MSE, RMSE, max_diff,
and Pearson correlation metrics.
Examples:
hae verify compare-fundamentals \\
--test-dir ./outputs/decodiff_evaluator \\
--baseline-dir ./fixtures/golden/v1.0/decodiff_evaluator \\
--output-dir ./comparison/decodiff_evaluator
"""
console.print("[bold cyan]📊 Comparing Fundamental NPY Variables[/bold cyan]\n")
# Build command
cmd = [
"python", "workspaces/hae-py/tests/compare_fundamentals.py",
"--test_dir", test_dir,
"--baseline_dir", baseline_dir,
"--output_dir", output_dir,
"--threshold", str(threshold),
]
if visualize:
cmd.append("--visualize")
# Show standalone alternatives
show_standalone_alternatives(cmd, "Compare fundamental NPY variables")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print(f"[bold]Test directory:[/bold] {test_dir}")
console.print(f"[bold]Baseline directory:[/bold] {baseline_dir}")
console.print(f"[bold]Output directory:[/bold] {output_dir}")
console.print(f"[bold]Threshold:[/bold] MAE < {threshold:.2e}")
console.print(f"[bold]Visualize:[/bold] {visualize}")
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@verify.command('compare-cross-methods')
@click.option('--decodiff-evaluator-dir', 'decodiff_evaluator_dir', required=True,
help='decodiff_evaluator output directory')
@click.option('--anomaly-detector-dir', 'anomaly_detector_dir', required=True,
help='anomaly_detector output directory')
@click.option('--output-dir', required=True, help='Comparison results output directory')
@click.option('--threshold', type=float, default=1e-5, help='MAE threshold for pass/fail')
@click.option('--visualize/--no-visualize', default=True, help='Generate visualization plots')
@click.option('--dry-run/--no-dry-run', default=False, help='Preview command without executing')
def verify_compare_cross_methods(decodiff_evaluator_dir, anomaly_detector_dir, output_dir,
threshold, visualize, dry_run):
"""Cross-validate decodiff_evaluator vs anomaly_detector outputs.
Compares fundamental variables and derived anomaly maps between
both evaluation methods to ensure algorithmic equivalence.
Examples:
hae verify compare-cross-methods \\
--decodiff-evaluator-dir ./outputs/decodiff_evaluator \\
--anomaly-detector-dir ./outputs/anomaly_detector \\
--output-dir ./comparison/cross_method
"""
console.print("[bold cyan]🔄 Cross-Validating Evaluation Methods[/bold cyan]\n")
# Build command
cmd = [
"python", "workspaces/hae-py/tests/compare_cross_methods.py",
"--decodiff_evaluator_dir", decodiff_evaluator_dir,
"--anomaly_detector_dir", anomaly_detector_dir,
"--output_dir", output_dir,
"--threshold", str(threshold),
]
if visualize:
cmd.append("--visualize")
# Show standalone alternatives
show_standalone_alternatives(cmd, "Cross-validate evaluation methods")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print(f"[bold]decodiff_evaluator directory:[/bold] {decodiff_evaluator_dir}")
console.print(f"[bold]anomaly_detector directory:[/bold] {anomaly_detector_dir}")
console.print(f"[bold]Output directory:[/bold] {output_dir}")
console.print(f"[bold]Threshold:[/bold] MAE < {threshold:.2e}")
console.print(f"[bold]Visualize:[/bold] {visualize}")
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@verify.command('summary')
@click.option('--mode', type=click.Choice(['quick', 'full']), required=True,
help='Summary mode: quick (baseline only) or full (with cross-validation)')
@click.option('--output-dir', required=True, help='Output directory for summary file')
@click.option('--baseline-version', required=True, help='Baseline version tested against')
@click.option('--decodiff-evaluator-results', 'decodiff_evaluator_results', required=True,
help='Path to decodiff_evaluator comparison results JSON')
@click.option('--anomaly-detector-results', 'anomaly_detector_results', required=True,
help='Path to anomaly_detector comparison results JSON')
@click.option('--cross-method-results', 'cross_method_results', default=None,
help='Path to cross-method comparison results JSON (full mode only)')
@click.option('--dry-run/--no-dry-run', default=False, help='Preview command without executing')
def verify_summary(mode, output_dir, baseline_version, decodiff_evaluator_results,
anomaly_detector_results, cross_method_results, dry_run):
"""Generate regression test summary report.
Creates a human-readable summary of regression test results,
including pass/fail status, timing information, and key metrics.
Examples:
# Quick mode summary
hae verify summary --mode quick \\
--output-dir ./outputs \\
--baseline-version v1.0 \\
--decodiff-evaluator-results ./comparison/decodiff_evaluator/comparison_results.json \\
--anomaly-detector-results ./comparison/anomaly_detector/comparison_results.json
# Full mode summary (with cross-validation)
hae verify summary --mode full \\
--output-dir ./outputs \\
--baseline-version v1.0 \\
--decodiff-evaluator-results ./comparison/decodiff_evaluator/comparison_results.json \\
--anomaly-detector-results ./comparison/anomaly_detector/comparison_results.json \\
--cross-method-results ./comparison/cross_method/cross_method_results.json
"""
console.print("[bold cyan]📋 Generating Regression Test Summary[/bold cyan]\n")
# Build command
cmd = [
"python", "workspaces/hae-py/tests/generate_regression_summary.py",
"--mode", mode,
"--output-dir", output_dir,
"--baseline-version", baseline_version,
"--decodiff-evaluator-results", decodiff_evaluator_results,
"--anomaly-detector-results", anomaly_detector_results,
]
if cross_method_results:
cmd.extend(["--cross-method-results", cross_method_results])
# Show standalone alternatives
show_standalone_alternatives(cmd, "Generate regression test summary")
if dry_run:
console.print("[yellow]Dry run - not executing[/yellow]")
return
console.print(f"[bold]Mode:[/bold] {mode}")
console.print(f"[bold]Output directory:[/bold] {output_dir}")
console.print(f"[bold]Baseline version:[/bold] {baseline_version}")
console.print(f"[bold]decodiff_evaluator results:[/bold] {decodiff_evaluator_results}")
console.print(f"[bold]anomaly_detector results:[/bold] {anomaly_detector_results}")
if cross_method_results:
console.print(f"[bold]Cross-method results:[/bold] {cross_method_results}")
console.print()
result = subprocess.run(cmd)
sys.exit(result.returncode)
@main.command('info')
def info():
"""Display DiOodMi project information and status.
Shows workspace structure, available models, datasets, and configuration.
"""
console.print("[bold]DiOodMi - HAEDOSA Automation Engine[/bold]\n")
# Project info
table = Table(title="Project Information")
table.add_column("Property", style="cyan")
table.add_column("Value", style="green")
table.add_row("Version", "0.3.0")
table.add_row("Architecture", "Export-based (hybrid)")
table.add_row("Core Verbs", "dioodmi-train, dioodmi-eval, dioodmi-reconstruct")
table.add_row("Python Package", "dioodmi")
table.add_row("Standards", "HAEDOSA MS04, MS07, MS12")
console.print(table)
# Available models
console.print("\n[bold]Available Models:[/bold]")
console.print(" • DDPM (Pixel-space diffusion)")
console.print(" • LDM (Latent diffusion with VQVAE)")
console.print(" • DeCo-Diff (Deviation prediction)")
console.print(" • INFD (Image Neural Field Diffusion)")
# Available datasets
console.print("\n[bold]Supported Datasets:[/bold]")
console.print(" • MVTec-AD (5,346 images, 15 categories)")
console.print(" • VisA (10,822 images, 12 categories)")
console.print(" • Custom (CSV-based format)")
# Quick commands
console.print("\n[bold]Quick Commands:[/bold]")
console.print(" [cyan]# Training & Evaluation[/cyan]")
console.print(" hae train --model decodiff --dataset mvtec --category bottle")
console.print(" hae eval --checkpoint models/bottle.pth --dataset mvtec")
console.print()
console.print(" [cyan]# OOD Detection[/cyan]")
console.print(" hae ood-detection --output-dir ./results --model-name decodiff_mvtec")
console.print(" hae ood-results --output-dir ./results --model-name decodiff_mvtec")
console.print(" hae ood-graph --in-csv results/in.csv --out-csv results/out.csv")
console.print()
console.print(" [cyan]# Data Preprocessing[/cyan]")
console.print(" hae preprocess crop --input-dir ./data --output-dir ./crops --patch-size 128 --patches-per-image 2")
console.print(" hae preprocess split ./data/images --valid-ratio 0.1")
console.print(" hae preprocess annotate -i ./images -o ./annotations")
console.print()
console.print(" [cyan]# Data Processing & Analysis[/cyan]")
console.print(" hae data gen-masks --model-name bottle --output-dir ./masks")
console.print(" hae data compare-contours --gt truth.jsonl --pred pred.jsonl --out-dir ./eval")
console.print(" hae data copy-from-csv data/test.csv ./test_images")
console.print()
console.print(" [cyan]# Model & Checkpoint Management[/cyan]")
console.print(" hae model info ./checkpoints/best.pt")
console.print(" hae model utils ./checkpoints --list")
console.print(" hae model save-ddpm --model-name bottle --output-dir ./models")
console.print()
console.print(" [cyan]# Development Tools[/cyan]")
console.print(" hae dev check-cuda")
console.print(" hae dev sort-interactive --pred-jsonl pred.jsonl --out-dir ./sorted")
console.print(" hae dev jupyter")
console.print()
console.print(" [cyan]# System Diagnostics[/cyan]")
console.print(" hae system check")
console.print(" hae system check --verbose")
console.print()
console.print(" [cyan]# Regression Testing & Baselines[/cyan]")
console.print(" hae verify list")
console.print(" hae verify run --baseline v1.0")
console.print(" hae verify create-baseline --version v1.0 --method both")
console.print(" hae verify compare tests/golden/v1.0 tests/golden/v1.1")
console.print(" hae verify approve-baseline --version v1.1")
@main.command()
@click.option('--full', is_flag=True, help='Full rebuild (clears and re-indexes everything)')
@click.option('--verify', is_flag=True, help='Verify database integrity')
def sync(full, verify):
"""Sync experiment database with YAML files and execution history.
Convenience shortcut for 'hae experiment sync'.
By default, performs smart incremental sync (only changed files).
Safe to run frequently (e.g., after git pull or running experiments).
The database can always be rebuilt from YAML files (ground truth).
Feel free to delete .hae/ directory - sync will recreate it.
Examples:
hae sync # Smart sync (fast, incremental)
hae sync --full # Full rebuild (slower)
hae sync --verify # Check database integrity
For more experiment management commands, see 'hae experiment --help'.
"""
# Call the experiment sync function directly
from click import Context
ctx = Context(experiment_sync)
ctx.invoke(experiment_sync, full=full, verify=verify)
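# Equivalent idiomatic form (a sketch): take the live context via
# @click.pass_context instead of constructing one, e.g.
#   @main.command()
#   @click.option('--full', is_flag=True)
#   @click.option('--verify', is_flag=True)
#   @click.pass_context
#   def sync(ctx, full, verify):
#       ctx.invoke(experiment_sync, full=full, verify=verify)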
@main.command()
@click.option('--migration-only', is_flag=True, help='Check only migration status (quick)')
@click.option('--experiments', is_flag=True, help='Check only experiment system health')
@click.option('--system', is_flag=True, help='Check only system environment (CUDA, datasets)')
@click.option('--fix', is_flag=True, help='Auto-fix issues where possible')
@click.option('--verbose', is_flag=True, help='Show detailed diagnostics')
def doctor(migration_only, experiments, system, fix, verbose):
"""Comprehensive health check for DiOodMi project.
Performs a complete diagnostic of the DiOodMi system including:
- Experiment system migration status and validation
- System environment (CUDA availability, datasets)
- Database integrity and synchronization
- File references and dependencies
This is the recommended command to run after pulling new changes.
Examples:
hae doctor # Full health check
hae doctor --migration-only # Quick migration check
hae doctor --experiments # Only experiment system
hae doctor --system # Only system environment
hae doctor --fix # Auto-repair issues
hae doctor --verbose # Detailed diagnostics
"""
import sys
from pathlib import Path
from rich.table import Table
console.print("[bold cyan]🏥 DiOodMi System Health Check[/bold cyan]\n")
# Track overall status
total_errors = 0
total_warnings = 0
# Determine which checks to run
run_all = not (migration_only or experiments or system)
run_experiments = run_all or migration_only or experiments
run_system_check = run_all or system
# ========================================================================
# 1. Experiment System Health Check
# ========================================================================
if run_experiments:
console.print("[bold]Experiment System:[/bold]")
# Import experiment doctor functionality
from hae.migration import (
check_migration_status,
find_old_path_references,
validate_yaml_syntax,
validate_dependencies,
validate_file_references,
fix_yaml_paths
)
from hae.db import ExperimentDB
errors = 0
warnings = 0
# 1.1 Migration Status
status = check_migration_status()
if status.old_definitions_exists:
console.print(f" [red]❌ OLD:[/red] experiments/definitions/ ({status.old_definitions_count} files)")
errors += 1
if status.old_history_exists:
console.print(f" [red]❌ OLD:[/red] experiments/history/ ({status.old_history_count} records)")
errors += 1
if status.new_definitions_exists:
console.print(f" [green]✅ NEW:[/green] definitions/ ({status.new_definitions_count} files)")
else:
console.print(f" [yellow]⚠️ NEW:[/yellow] definitions/ not found")
warnings += 1
if status.new_history_exists:
console.print(f" [dim] NEW: experiment-history/ ({status.new_history_count} records)[/dim]")
if status.needs_migration:
console.print(f" [cyan]💡 Action:[/cyan] Run 'hae experiment migrate'\n")
else:
console.print(f" [green]✓ Migration complete[/green]\n")
# If migration-only flag, stop here
if migration_only:
summary = f"Migration {'needed' if status.needs_migration else 'complete'}"
console.print(f"[dim]{summary}[/dim]")
sys.exit(1 if status.needs_migration else 0)
# 1.2 Experiment Definitions Validation
yaml_dir = Path("definitions/experiments")
if yaml_dir.exists():
# YAML syntax validation
yaml_files = list(yaml_dir.glob("*.yaml"))
valid_count = 0
invalid_files = []
for yaml_file in yaml_files:
result = validate_yaml_syntax(yaml_file)
if result.valid:
valid_count += 1
else:
invalid_files.append((yaml_file.name, result.error))
if invalid_files:
console.print(f" [red]❌ YAML Syntax ({valid_count}/{len(yaml_files)} valid)[/red]")
errors += len(invalid_files)
else:
console.print(f" [green]✅ YAML Syntax ({len(yaml_files)}/{len(yaml_files)} valid)[/green]")
# Dependency validation
dep_issues = validate_dependencies(yaml_dir)
if dep_issues:
console.print(f" [red]❌ Dependencies ({len(dep_issues)} broken refs)[/red]")
errors += len(dep_issues)
else:
console.print(f" [green]✅ Dependencies (no broken references)[/green]")
# File reference validation
file_issues = validate_file_references(yaml_dir)
if file_issues:
console.print(f" [yellow]⚠️ Missing Files ({len(file_issues)} warnings)[/yellow]")
warnings += len(file_issues)
else:
console.print(f" [green]✅ File References (all files found)[/green]")
# Old path references
path_issues = find_old_path_references(yaml_dir)
if path_issues:
console.print(f" [yellow]⚠️ Old Paths ({len(path_issues)} found)[/yellow]")
warnings += len(path_issues)
if fix:
console.print(f" [cyan]🔧 Fixing path references...[/cyan]")
fixed_count, actions = fix_yaml_paths(yaml_dir, dry_run=False)
console.print(f" [green]✓ Fixed {fixed_count} files[/green]")
warnings -= len(path_issues) # Remove warnings since we fixed them
# 1.3 Database Status
db_path = Path(".hae/experiments.db")
if not db_path.exists():
console.print(" [yellow]⚠️ .hae/experiments.db not found[/yellow]")
console.print(" [cyan]💡 Action:[/cyan] Run 'hae experiment sync'")
warnings += 1
else:
try:
db = ExperimentDB()
def_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_definitions").fetchone()[0]
exec_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_executions").fetchone()[0]
if def_count == 0 and status.new_definitions_exists:
console.print(f" [yellow]⚠️ Database empty (0 definitions)[/yellow]")
console.print(" [cyan]💡 Action:[/cyan] Run 'hae experiment sync'")
warnings += 1
else:
console.print(f" [green]✅ Database OK ({def_count} definitions, {exec_count} executions)[/green]")
db.close()
except Exception as e:
console.print(f" [red]❌ Database error: {e}[/red]")
errors += 1
# 1.4 Format Validation (experiment.org must be .org, not .md)
history_dir = Path("experiment-history")
if history_dir.exists():
# Check for .md files in experiment-history/ (should be .org)
md_files = list(history_dir.glob("**/*.md"))
# Keep only experiment record files (experiment.md / notes.md), not other docs
experiment_md_files = [f for f in md_files if f.name in ["experiment.md", "notes.md"]]
if experiment_md_files:
console.print(f" [red]❌ Format Violation ({len(experiment_md_files)} .md files in experiment-history/)[/red]")
for md_file in experiment_md_files[:3]: # Show first 3
console.print(f" • {md_file}")
if len(experiment_md_files) > 3:
console.print(f" ... and {len(experiment_md_files) - 3} more")
console.print(" [cyan]💡 Action:[/cyan] Experiment records MUST be .org (org-roam :ID: required)")
console.print(" [cyan]💡 See:[/cyan] docs/DOCUMENTATION-STANDARDS.org")
errors += len(experiment_md_files)
else:
console.print(f" [green]✅ Format Compliance (experiment records are .org)[/green]")
console.print()
total_errors += errors
total_warnings += warnings
# ========================================================================
# 2. System Environment Check
# ========================================================================
if run_system_check:
console.print("[bold]System Environment:[/bold]")
env_errors = 0
env_warnings = 0
# 2.1 CUDA availability
try:
import torch
if torch.cuda.is_available():
cuda_version = torch.version.cuda
device_count = torch.cuda.device_count()
device_name = torch.cuda.get_device_name(0) if device_count > 0 else "N/A"
console.print(f" [green]✅ CUDA {cuda_version} ({device_count} device(s): {device_name})[/green]")
else:
console.print(f" [yellow]⚠️ CUDA not available (CPU-only mode)[/yellow]")
env_warnings += 1
except ImportError:
console.print(f" [red]❌ PyTorch not installed[/red]")
env_errors += 1
# 2.2 Dataset availability
datasets_dir = Path("datasets")
if datasets_dir.exists():
# Count both raw and processed subdirectories
raw_dir = datasets_dir / "raw"
processed_dir = datasets_dir / "processed"
dataset_count = 0
if raw_dir.exists():
dataset_count += sum(1 for d in raw_dir.iterdir() if d.is_dir())
if processed_dir.exists():
dataset_count += sum(1 for d in processed_dir.iterdir() if d.is_dir())
console.print(f" [green]✅ Datasets ({dataset_count} found in datasets/)[/green]")
else:
console.print(f" [yellow]⚠️ datasets/ not found[/yellow]")
env_warnings += 1
# 2.3 Core dependencies
try:
import dioodmi
console.print(f" [green]✅ DiOodMi core package installed[/green]")
except ImportError:
console.print(f" [red]❌ DiOodMi core package not installed[/red]")
console.print(f" [cyan]💡 Action:[/cyan] Run 'uv pip install -e workspaces/dioodmi-py'")
env_errors += 1
console.print()
total_errors += env_errors
total_warnings += env_warnings
# ========================================================================
# 3. Summary
# ========================================================================
console.print("[bold]Summary:[/bold]")
if total_errors == 0 and total_warnings == 0:
console.print("[green]✅ All healthy - no issues found![/green]")
sys.exit(0)
elif total_errors == 0:
console.print(f"[yellow]⚠️ {total_warnings} warning(s) found[/yellow]")
sys.exit(0)
else:
console.print(f"[red]❌ {total_errors} error(s), {total_warnings} warning(s) found[/red]")
if not fix:
console.print("[dim]Run with --fix to attempt automatic repairs[/dim]")
sys.exit(1)
@main.group()
def vision():
"""Vision processing tools.
Commands for image segmentation (SAM) and visual analysis (Ollama).
"""
pass
@vision.command('segment')
@click.argument('image', type=click.Path(exists=True))
@click.option('--model', type=click.Choice(['vit_b', 'vit_l', 'vit_h']),
default='vit_b', help='SAM model size (default: vit_b)')
@click.option('--checkpoint-dir', type=click.Path(),
help='Directory containing SAM checkpoints')
@click.option('--threshold', type=float, default=0.5,
help='IoU threshold for mask filtering (default: 0.5)')
@click.option('--points-per-side', type=int, default=32,
help='Points per side for segmentation (default: 32, higher=better quality)')
@click.option('--combine', type=click.Choice(['union', 'largest']), default='union',
help='Combine all masks or only largest (default: union)')
@click.option('--output-dir', type=click.Path(), default='./output',
help='Output directory for masks (default: ./output)')
@click.option('--save-binary/--no-binary', default=True,
help='Save binary mask (default: True)')
@click.option('--save-instance/--no-instance', default=False,
help='Save instance mask (default: False)')
@click.option('--max-coverage', type=float, default=95.0,
help='Exclude objects covering more than this % of image (default: 95)')
@click.option('--save-individual/--no-individual', default=False,
help='Save each object mask as separate file for debugging (default: False)')
@click.option('--save-overlay/--no-overlay', default=False,
help='Save colored masks overlaid on original image (default: False)')
def vision_segment(image, model, checkpoint_dir, threshold, points_per_side,
combine, output_dir, save_binary, save_instance, max_coverage, save_individual, save_overlay):
"""Segment objects in an image using SAM (Segment Anything Model).
Examples:
hae vision segment image.bmp
hae vision segment image.bmp --model vit_l --threshold 0.9
hae vision segment image.bmp --save-instance --output-dir ./masks
"""
from pathlib import Path
from .vision import segment_image
console.print("[bold cyan]🎯 SAM Image Segmentation[/bold cyan]\n")
try:
image_path = Path(image)
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
console.print(f" Image: {image_path.name}")
console.print(f" Model: {model}")
console.print(f" Threshold: {threshold}")
console.print(f" Points per side: {points_per_side}\n")
with console.status("[bold green]Segmenting image..."):
result = segment_image(
image_path,
model_type=model,
checkpoint_dir=Path(checkpoint_dir) if checkpoint_dir else None,
points_per_side=points_per_side,
threshold=threshold,
output_binary=save_binary,
output_instance=save_instance,
combine=combine,
)
# Save masks
import cv2
image_name = image_path.stem
if save_binary and result.get('binary_mask') is not None:
binary_file = output_path / f"{image_name}_binary.png"
cv2.imwrite(str(binary_file), result['binary_mask'])
console.print(f" ✓ Binary mask: {binary_file}")
if save_instance and result.get('instance_mask') is not None:
# Filter out objects that cover too much of the image
total_pixels = result['image_shape'][0] * result['image_shape'][1]
filtered_masks = [
m for m in result['raw_masks']
if (m['area'] / total_pixels * 100) <= max_coverage
]
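# Worked example of the coverage filter: on a 1024x1024 image
# (1,048,576 px) a mask of area 1,000,000 px covers ~95.4% and is
# dropped at the default max_coverage of 95.0.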
# Create colored version for visualization
from .vision.segment import create_colored_instance_mask
colored_mask = create_colored_instance_mask(filtered_masks, threshold)
instance_file = output_path / f"{image_name}_instance.png"
cv2.imwrite(str(instance_file), colored_mask)
console.print(f" ✓ Instance mask: {instance_file}")
excluded_count = len(result['raw_masks']) - len(filtered_masks)
if excluded_count > 0:
console.print(f" ℹ Excluded {excluded_count} object(s) covering >{max_coverage}% of image")
# Save individual masks for debugging
if save_individual and result['num_objects'] > 0:
import numpy as np
total_pixels = result['image_shape'][0] * result['image_shape'][1]
sorted_masks = sorted(result['raw_masks'], key=lambda x: x['area'], reverse=True)
console.print(f"\n [bold]Saving individual object masks...[/bold]")
for idx, mask in enumerate(sorted_masks):
seg = mask['segmentation']
# Create white mask on black background
individual_mask = np.zeros(seg.shape, dtype=np.uint8)
individual_mask[seg] = 255
coverage = (mask['area'] / total_pixels) * 100
individual_file = output_path / f"{image_name}_object{idx+1}_{coverage:.1f}pct.png"
cv2.imwrite(str(individual_file), individual_mask)
console.print(f" Object {idx+1}: {individual_file.name}")
# Save overlay visualization
if save_overlay and result['num_objects'] > 0:
import numpy as np
from PIL import Image
# Load original image
original = Image.open(image_path)
original_rgb = np.array(original.convert("RGB"))
# Filter masks by coverage
total_pixels = result['image_shape'][0] * result['image_shape'][1]
filtered_masks = [
m for m in result['raw_masks']
if (m['area'] / total_pixels * 100) <= max_coverage
]
# Create colored mask overlay
from .vision.segment import create_colored_instance_mask
colored_mask = create_colored_instance_mask(filtered_masks, threshold)
# Blend with original image (50% transparency)
alpha = 0.5
overlay = cv2.addWeighted(original_rgb, 1 - alpha, colored_mask, alpha, 0)
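# addWeighted computes (1 - alpha) * original + alpha * colored_mask per
# pixel, so alpha = 0.5 gives an even 50/50 blend.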
overlay_file = output_path / f"{image_name}_overlay.png"
cv2.imwrite(str(overlay_file), cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
console.print(f" ✓ Overlay: {overlay_file}")
# Show object sizes for debugging
console.print(f"\n[bold green]✓ Found {result['num_objects']} objects[/bold green]")
if result['num_objects'] > 0:
total_pixels = result['image_shape'][0] * result['image_shape'][1]
console.print("\n[bold]Object details:[/bold]")
sorted_masks = sorted(result['raw_masks'], key=lambda x: x['area'], reverse=True)
for idx, mask in enumerate(sorted_masks[:10]): # Show top 10
area = mask['area']
coverage = (area / total_pixels) * 100
iou = mask.get('predicted_iou', 0)
excluded = " [dim](excluded)[/dim]" if coverage > max_coverage else ""
console.print(f" Object {idx+1}: {area:,} pixels ({coverage:.1f}% of image) - IoU: {iou:.3f}{excluded}")
except ImportError as e:
console.print(f"[bold red]✗ SAM not available: {e}[/bold red]")
console.print(" Install with: pip install 'hae[vision]'")
sys.exit(1)
except Exception as e:
console.print(f"[bold red]✗ Segmentation failed: {e}[/bold red]")
sys.exit(1)
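# A minimal sketch of calling segment_image programmatically, mirroring the
# command above. The import path is assumed from the sibling commands
# (`from .vision import ...`); the keyword arguments and result keys
# ('binary_mask', 'instance_mask', 'raw_masks', 'image_shape', 'num_objects',
# and per-mask 'segmentation'/'area'/'predicted_iou') are exactly those the
# command consumes:
#
#   from pathlib import Path
#   from .vision import segment_image
#
#   result = segment_image(Path("part.png"), model_type="vit_b",
#                          checkpoint_dir=None, points_per_side=32,
#                          threshold=0.5, output_binary=True,
#                          output_instance=False, combine="union")
#   h, w = result['image_shape'][:2]
#   for m in result['raw_masks']:
#       print(m['area'] / (h * w) * 100, m.get('predicted_iou', 0))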
@vision.command('batch')
@click.argument('input_dir', type=click.Path(exists=True))
@click.option('--output-dir', type=click.Path(), default='./output',
help='Output directory for masks')
@click.option('--model', type=click.Choice(['vit_b', 'vit_l', 'vit_h']),
default='vit_b', help='SAM model size')
@click.option('--checkpoint-dir', type=click.Path(),
help='Directory containing SAM checkpoints')
@click.option('--threshold', type=float, default=0.5,
help='IoU threshold for mask filtering')
@click.option('--points-per-side', type=int, default=32,
help='Points per side for segmentation')
@click.option('--combine', type=click.Choice(['union', 'largest']), default='union',
help='Combine all masks or only largest')
@click.option('--save-binary/--no-binary', default=True,
help='Save binary masks')
@click.option('--save-instance/--no-instance', default=False,
help='Save instance masks')
@click.option('--recursive/--no-recursive', default=False,
help='Process subdirectories')
@click.option('--limit', type=int, help='Limit number of images to process')
def vision_batch(input_dir, output_dir, model, checkpoint_dir, threshold, points_per_side,
combine, save_binary, save_instance, recursive, limit):
"""Batch process images with SAM segmentation.
Fast GPU-optimized processing of multiple images.
Examples:
hae vision batch ./dataset
hae vision batch ./dataset --limit 100 --points-per-side 16
hae vision batch ./dataset --save-instance --recursive
"""
from pathlib import Path
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
from .vision import batch_segment_images
console.print("[bold cyan]🎯 SAM Batch Segmentation[/bold cyan]\n")
try:
input_path = Path(input_dir)
output_path = Path(output_dir)
console.print(f" Input: {input_path}")
console.print(f" Output: {output_path}")
console.print(f" Model: {model}")
console.print(f" Threshold: {threshold}")
console.print(f" Points per side: {points_per_side}")
if limit:
console.print(f" Limit: {limit} images")
console.print()
# Get progress bar columns
with Progress(
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
TimeRemainingColumn(),
console=console
) as progress:
# total=None renders an indeterminate bar; batch_segment_images only
# reports aggregate stats on return, so the task is never advanced
task = progress.add_task("Processing images...", total=None)
stats = batch_segment_images(
input_path,
output_path,
model_type=model,
checkpoint_dir=Path(checkpoint_dir) if checkpoint_dir else None,
points_per_side=points_per_side,
threshold=threshold,
combine=combine,
save_binary=save_binary,
save_instance=save_instance,
recursive=recursive,
limit=limit,
)
# Print summary
console.print()
console.print(f"[bold green]✓ Batch processing complete[/bold green]")
console.print(f" Total: {stats['total']}")
console.print(f" Successful: {stats['successful']}")
if stats['failed'] > 0:
console.print(f" Failed: {stats['failed']}")
console.print(f" Total objects: {stats['total_objects']}")
console.print(f" Output: {output_path}")
except ImportError:
console.print("[bold red]✗ SAM not available[/bold red]")
console.print(" Install with: pip install 'hae[vision]'")
sys.exit(1)
except Exception as e:
console.print(f"[bold red]✗ Batch processing failed: {e}[/bold red]")
import traceback
traceback.print_exc()
sys.exit(1)
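# The only contract this command relies on is the stats dict returned by
# batch_segment_images: 'total', 'successful', 'failed', and 'total_objects'.
# A hedged programmatic sketch using the same arguments as above:
#
#   from pathlib import Path
#   from .vision import batch_segment_images
#
#   stats = batch_segment_images(Path("./dataset"), Path("./output"),
#                                model_type="vit_b", checkpoint_dir=None,
#                                points_per_side=16, threshold=0.5,
#                                combine="union", save_binary=True,
#                                save_instance=False, recursive=False,
#                                limit=100)
#   print(f"{stats['successful']}/{stats['total']} ok, "
#         f"{stats['total_objects']} objects")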
@vision.command('analyze')
@click.argument('image', type=click.Path(exists=True))
@click.option('--prompt', default='Describe all objects in this image in detail',
help='Prompt for the vision model')
@click.option('--model', default='llama3.2-vision:11b',
help='Ollama vision model name')
@click.option('--host', default='localhost',
help='Ollama server host')
@click.option('--port', type=int, default=11434,
help='Ollama server port')
def vision_analyze(image, prompt, model, host, port):
"""Analyze an image using Ollama vision models.
Examples:
hae vision analyze image.bmp
hae vision analyze image.bmp --prompt "Describe any abnormalities"
hae vision analyze image.bmp --model llava:13b
"""
from pathlib import Path
from .vision import analyze_image_ollama
console.print("[bold cyan]👁️ Ollama Vision Analysis[/bold cyan]\n")
try:
image_path = Path(image)
console.print(f" Image: {image_path.name}")
console.print(f" Model: {model}")
console.print(f" Prompt: {prompt}\n")
with console.status("[bold green]Analyzing image..."):
result = analyze_image_ollama(
image_path,
prompt=prompt,
model=model,
host=host,
port=port,
)
if result['success']:
console.print(f"[bold green]✓ Analysis complete[/bold green]\n")
console.print(result['response'])
else:
console.print(f"[bold red]✗ Analysis failed[/bold red]")
console.print(f" {result['error']}")
sys.exit(1)
except Exception as e:
console.print(f"[bold red]✗ Analysis failed: {e}[/bold red]")
sys.exit(1)
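# analyze_image_ollama returns a plain dict; the command above uses only
# 'success', 'response', and 'error'. A minimal sketch with the same defaults
# as the options above:
#
#   from pathlib import Path
#   from .vision import analyze_image_ollama
#
#   result = analyze_image_ollama(Path("image.bmp"),
#                                 prompt="Describe any abnormalities",
#                                 model="llama3.2-vision:11b",
#                                 host="localhost", port=11434)
#   text = result['response'] if result['success'] else result['error']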
@vision.command('download-models')
@click.option('--model', type=click.Choice(['vit_b', 'vit_l', 'vit_h', 'all']),
default='vit_b', help='Model to download (default: vit_b)')
@click.option('--checkpoint-dir', type=click.Path(),
help='Directory to save checkpoints (default: fixtures/models/)')
@click.option('--force', is_flag=True,
help='Force re-download even if checkpoint exists')
@click.option('--no-verify', is_flag=True,
help='Skip SHA256 verification (not recommended)')
def vision_download_models(model, checkpoint_dir, force, no_verify):
"""Download SAM model checkpoints.
Downloads SAM (Segment Anything Model) checkpoints to fixtures/models/
for DVC tracking and reproducibility.
Model sizes:
vit_b: 358 MB - Base model (fastest)
vit_l: 1.2 GB - Large model (balanced)
vit_h: 2.4 GB - Huge model (best quality)
Examples:
hae vision download-models
hae vision download-models --model vit_l
hae vision download-models --model all
hae vision download-models --checkpoint-dir ~/.cache/sam
"""
from pathlib import Path
from .vision import (
download_checkpoint,
download_all_checkpoints,
list_available_models,
get_default_checkpoint_dir
)
console.print("[bold cyan]📥 SAM Model Download[/bold cyan]\n")
# Get target directory
if checkpoint_dir:
target_dir = Path(checkpoint_dir)
else:
target_dir = get_default_checkpoint_dir()
# Check if in project for DVC instructions
from .vision.download import get_project_root
in_project = get_project_root() is not None
# Show available models
available_models = list_available_models()
console.print("[bold]Available models:[/bold]")
for model_type, info in available_models.items():
console.print(f" {model_type}: {info['size_mb']} MB")
console.print(f"\n[bold]Download location:[/bold] {target_dir}")
console.print()
try:
if model == 'all':
console.print("[bold]Downloading all models...[/bold]\n")
results = download_all_checkpoints(
checkpoint_dir=target_dir if checkpoint_dir else None,
force=force,
verify=not no_verify
)
# Show results
console.print()
for model_type, result in results.items():
if result['success']:
status = result.get('status', 'downloaded')
if status == 'already_exists':
console.print(f"[green]✓[/green] {model_type}: Already exists")
else:
console.print(f"[green]✓[/green] {model_type}: Downloaded successfully")
else:
console.print(f"[red]✗[/red] {model_type}: {result.get('error', 'Failed')}")
else:
console.print(f"[bold]Downloading {model}...[/bold]\n")
result = download_checkpoint(
model_type=model,
checkpoint_dir=target_dir if checkpoint_dir else None,
force=force,
verify=not no_verify
)
if result['success']:
console.print(f"\n[bold green]✓ Download complete[/bold green]")
console.print(f" Path: {result['path']}")
status = result.get('status', 'downloaded')
if status == 'already_exists':
console.print(f" Status: Already exists and verified")
else:
console.print(f" Status: Downloaded and verified")
else:
console.print(f"[bold red]✗ Download failed[/bold red]")
console.print(f" {result.get('error', 'Unknown error')}")
sys.exit(1)
# Show DVC tracking instructions if in project
if in_project and not checkpoint_dir:
console.print("\n[bold cyan]📦 DVC Tracking[/bold cyan]")
console.print("To track checkpoints with DVC, run:")
console.print(" [bold]dvc add fixtures/models/[/bold]")
console.print(" [bold]git add fixtures/models.dvc .gitignore[/bold]")
console.print(" [bold]git commit -m 'Add SAM model checkpoints'[/bold]")
console.print(" [bold]dvc push[/bold]")
console.print("\nTeam members can then download with:")
console.print(" [bold]dvc pull[/bold]")
except Exception as e:
console.print(f"[bold red]✗ Download failed: {e}[/bold red]")
sys.exit(1)
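# download_checkpoint and download_all_checkpoints report through plain dicts;
# the handling above depends only on 'success', 'path', 'status' ('downloaded'
# or 'already_exists'), and 'error'. A hedged non-interactive sketch:
#
#   result = download_checkpoint(model_type="vit_b", checkpoint_dir=None,
#                                force=False, verify=True)
#   if result['success']:
#       print(f"Checkpoint ready at {result['path']} "
#             f"({result.get('status', 'downloaded')})")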
# ============================================================================
# Pipeline and Experiment Commands
# ============================================================================
@main.group()
def pipeline():
"""Pipeline orchestration commands.
Execute multi-stage workflows defined in YAML files.
"""
pass
@pipeline.command('run')
@click.argument('pipeline_id_or_path')
@click.option('--vars', multiple=True,
help='Variable overrides in key=value format (can be repeated)')
@click.option('--dry-run', is_flag=True,
help='Show commands without executing them')
def pipeline_run(pipeline_id_or_path, vars, dry_run):
"""Run a pipeline by ID or YAML file path.
Accepts either:
- Pipeline ID (looks in definitions/pipelines/)
- Full path to pipeline YAML file
Execute multi-stage workflows with variable substitution and
sequential stage execution.
Examples:
# By pipeline ID
hae pipeline run smoke-test
hae pipeline run eval-ad --vars model_path=./checkpoints/best.pt
# By full path
hae pipeline run definitions/pipelines/smoke-test.yaml
hae pipeline run pipelines/smoke-test.yaml --vars model_size=UNet_S
hae pipeline run pipelines/smoke-test.yaml --vars model_size=UNet_M --vars epochs=5
hae pipeline run pipelines/smoke-test.yaml --dry-run
"""
from pathlib import Path
from .commands.pipeline import run_pipeline, parse_vars_overrides
# Check if it's a path that exists
input_path = Path(pipeline_id_or_path)
if input_path.exists() and input_path.is_file():
yaml_path = input_path
else:
# Treat as pipeline ID - look in definitions directory
definitions_dir = Path("definitions/pipelines")
yaml_path = definitions_dir / f"{pipeline_id_or_path}.yaml"
if not yaml_path.exists():
console.print(f"[red]❌ Pipeline not found: {pipeline_id_or_path}[/red]")
console.print(f"\n[yellow]Tried:[/yellow]")
console.print(f" - {yaml_path}")
console.print(f"\n[cyan]Available pipelines:[/cyan]")
# Show available pipelines
from rich.table import Table
table = Table(show_header=False, box=None)
if definitions_dir.exists():
for pipeline_file in sorted(definitions_dir.glob("*.yaml")):
pipeline_id = pipeline_file.stem
table.add_row(f" • {pipeline_id}")
console.print(table)
console.print(f"\n[dim]Usage: hae pipeline run <id> or hae pipeline run <path>[/dim]")
sys.exit(1)
vars_dict = parse_vars_overrides(list(vars)) if vars else None
success = run_pipeline(yaml_path, vars_dict, dry_run=dry_run)
sys.exit(0 if success else 1)
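# A hedged example of a pipeline YAML this command would resolve. Only
# 'name', 'description', and 'stages' are keys this module demonstrably reads
# (see pipeline_list below); the 'vars' block and per-stage fields are
# assumptions about run_pipeline's schema based on the --vars behaviour:
#
#   name: smoke-test
#   description: Quick end-to-end check
#   vars:
#     model_size: UNet_S
#   stages:
#     - name: train
#       cmd: "hae train --model {model_size}"
#     - name: evaluate
#       cmd: "hae eval --checkpoint ./checkpoints/best.pt"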
@pipeline.command('list')
@click.option('--definitions', is_flag=True, help='List only pipelines from definitions/')
def pipeline_list(definitions: bool):
"""List all available pipelines.
By default, lists all pipelines from definitions/pipelines/*.yaml.
Examples:
hae pipeline list # List all pipelines
hae pipeline list --definitions # List only definitions directory pipelines
"""
from pathlib import Path
import yaml
from rich.table import Table
all_pipelines = []
definitions_dir = Path("definitions/pipelines")
# List pipelines from definitions directory (default behavior)
if definitions_dir.exists():
for yaml_file in sorted(definitions_dir.glob("*.yaml")):
try:
with open(yaml_file, encoding='utf-8') as f:
config = yaml.safe_load(f)
pipeline_id = yaml_file.stem
name = config.get('name', pipeline_id)
description = config.get('description', 'No description')
# Get relative path
try:
rel_path = str(yaml_file.relative_to(Path.cwd()))
except ValueError:
rel_path = str(yaml_file)
# Count stages
stages = config.get('stages', [])
num_stages = len(stages) if isinstance(stages, list) else 0
all_pipelines.append({
'id': pipeline_id,
'name': name,
'description': description,
'path': rel_path,
'stages': num_stages
})
except Exception as e:
console.print(f"[yellow]Warning: Could not load {yaml_file}: {e}[/yellow]")
if not all_pipelines:
console.print("[yellow]No pipelines found[/yellow]")
console.print(f"[dim]Looking in: {definitions_dir if definitions_dir.exists() else 'definitions/pipelines'}[/dim]")
return
# Display table
table = Table(title="Available Pipelines", show_header=True, header_style="bold", box=None, padding=(0, 1))
table.add_column("ID", style="cyan", no_wrap=True, width=30)
table.add_column("Name", style="green", no_wrap=True, width=40)
table.add_column("Stages", style="yellow", width=8, justify="right")
table.add_column("Description", style="dim", width=50)
for pipeline in sorted(all_pipelines, key=lambda x: x['id']):
# Truncate name and description
name = pipeline['name'].replace('\n', ' ').replace('\r', ' ').strip()
if len(name) > 38:
name = name[:35] + "..."
desc = pipeline['description'].replace('\n', ' ').replace('\r', ' ').strip()
if len(desc) > 48:
desc = desc[:45] + "..."
pipeline_id = pipeline['id']
if len(pipeline_id) > 28:
pipeline_id = pipeline_id[:25] + "..."
table.add_row(
pipeline_id,
name,
str(pipeline['stages']),
desc
)
console.print()
console.print(table)
console.print()
@main.group()
def experiment():
"""Experiment management commands.
Run parameter sweeps and manage experimental workflows.
"""
pass
@experiment.command('list')
@click.option('--tag', multiple=True, help='Filter by tag (can specify multiple)')
@click.option('--exclude-tag', multiple=True, help='Exclude experiments with tag')
@click.option('--search', help='Search in ID and description')
@click.option('--numbered/--no-numbered', default=True, help='Show numbered list for selection')
@click.option('--rebuild', is_flag=True, help='Rebuild database before listing')
@click.option('--no-sync', is_flag=True, help='Skip auto-sync (use cached database)')
def experiment_list(tag, exclude_tag, search, numbered, rebuild, no_sync):
"""List all available experiments from database.
Uses the experiment database for fast filtering and search.
Database is auto-synced incrementally unless --no-sync is specified.
Examples:
hae experiment list # Auto-syncs, then lists all
hae experiment list --tag synthetic # Only synthetic experiments
hae experiment list --tag training --tag carpet # Multiple tags (AND)
hae experiment list --exclude-tag test # Exclude test experiments
hae experiment list --search mvtec # Search in ID/description
hae experiment list --rebuild # Full rebuild (slow)
hae experiment list --no-sync # Skip sync (use cached data)
"""
from hae.db import ExperimentDB
from hae.migration import sync_database_incremental
import re
import yaml
from pathlib import Path
from rich.table import Table
# Initialize database
db = ExperimentDB()
# Sync strategy: rebuild, auto-sync, or skip
if rebuild:
# Full rebuild - clear and re-index everything
console.print("[dim]Rebuilding experiment database...[/dim]")
db.index_definitions()
def_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_definitions").fetchone()[0]
console.print(f"[dim]Indexed {def_count} experiments[/dim]\n")
elif not no_sync:
# Smart auto-sync: only process changed/new files (fast!)
new, updated, unchanged = sync_database_incremental(db)
if new + updated > 0:
console.print(f"[dim]↻ Synced: {new} new, {updated} updated[/dim]")
# else: no_sync is True, skip sync entirely
# Query experiments with filters
experiments = db.query_definitions(
tags=list(tag) if tag else None,
exclude_tags=list(exclude_tag) if exclude_tag else None,
search=search
)
db.close()
if not experiments:
console.print("[yellow]No experiments found[/yellow]")
if tag:
console.print(f"[dim]Try removing tag filters: {', '.join(tag)}[/dim]")
if search:
console.print(f"[dim]Try removing search filter: {search}[/dim]")
return
# Build ID → number mapping for dependency resolution
id_to_number = {exp.id: idx + 1 for idx, exp in enumerate(experiments)}
# Parse dependencies from YAML files
dependencies = {}
for exp in experiments:
yaml_path = Path(exp.yaml_path)
try:
with open(yaml_path, encoding='utf-8') as f:
config = yaml.safe_load(f)
deps = []
# Extract depends_on from all experiment definitions
for exp_def in config.get('experiments', []):
if 'depends_on' in exp_def:
dep_list = exp_def['depends_on']
if isinstance(dep_list, str):
dep_list = [dep_list]
deps.extend(dep_list)
# Deduplicate and resolve to numbers (only for experiments in current view)
dep_ids = list(set(deps))
dep_numbers = []
for dep_id in dep_ids:
if dep_id in id_to_number:
dep_numbers.append(id_to_number[dep_id])
dependencies[exp.id] = sorted(dep_numbers) # Sort for consistency
except Exception:
dependencies[exp.id] = []
# Display table (compact format)
table = Table(title="Available Experiments", show_header=True, header_style="bold", box=None, padding=(0, 1))
if numbered:
table.add_column("#", style="dim", width=3, no_wrap=True)
table.add_column("ID", style="cyan", no_wrap=True, width=28)
table.add_column("Tags", style="yellow", width=25)
table.add_column("Last Run", style="dim", width=16, no_wrap=True)
table.add_column("Deps", style="magenta", width=12, no_wrap=True)
for idx, exp in enumerate(experiments, start=1):
# Format last run timestamp
if exp.last_run_timestamp:
# Parse ISO timestamp: 2025-11-15T23:21:12Z
match = re.search(r'(\d{4}-\d{2}-\d{2})T(\d{2}:\d{2}):\d{2}', exp.last_run_timestamp)
if match:
last_run_str = f"{match.group(1)} {match.group(2)}"
else:
last_run_str = exp.last_run_timestamp[:16]
else:
last_run_str = "-"
# Format tags (compact)
if exp.tags:
# Show first 3 tags, or indicate more
tags_display = ', '.join(exp.tags[:3])
if len(exp.tags) > 3:
tags_display += f" +{len(exp.tags) - 3}"
# Truncate to fit column
if len(tags_display) > 23:
tags_display = tags_display[:20] + "..."
else:
tags_display = "-"
# Show the name alongside the ID only when they differ
if exp.name == exp.id:
id_display = exp.id
else:
id_display = f"{exp.id} ({exp.name})"
# Truncate ID to fit column
if len(id_display) > 26:
id_display = id_display[:23] + "..."
# Format dependencies (show numbers only)
dep_numbers = dependencies.get(exp.id, [])
if dep_numbers:
# Limit to first 3 dependencies
deps_display = ",".join(f"#{n}" for n in dep_numbers[:3])
if len(dep_numbers) > 3:
deps_display += ",..."
else:
deps_display = "-"
row_data = [id_display, tags_display, last_run_str, deps_display]
if numbered:
row_data.insert(0, str(idx))
table.add_row(*row_data)
console.print(table)
# Show summary with filter info
summary_parts = [f"Total: {len(experiments)} experiments"]
if tag:
summary_parts.append(f"tags: {', '.join(tag)}")
if exclude_tag:
summary_parts.append(f"excluded: {', '.join(exclude_tag)}")
if search:
summary_parts.append(f"search: {search}")
console.print(f"\n[dim]{' | '.join(summary_parts)}[/dim]")
if numbered:
console.print(f"[dim]Run: hae experiment run <#> or hae experiment run <id>[/dim]")
else:
console.print(f"[dim]Run: hae experiment run <id>[/dim]")
@experiment.command('run')
@click.argument('experiment_id_or_path')
@click.option('--dry-run/--no-dry-run', default=False, help='Show execution plan without running')
@click.option("-v", "--variables", multiple=True, help="Variables as key=value")
def experiment_run(experiment_id_or_path, dry_run, variables):
"""Run an experiment by ID, number, fuzzy match, or YAML file path.
Accepts:
- Number from `hae experiment list`: hae experiment run 3
- Experiment ID: hae experiment run synthetic-carpet-train
- Fuzzy match (substring): hae experiment run carpet
- Full path: hae experiment run definitions/experiments/synthetic-carpet-train.yaml
Supports multiple formats:
1. Orchestration format (definitions/experiments/*.yaml) - with `experiments:` key
2. Sweep format (definitions/experiments/*.yaml) - with `sweep:` and `cmd:` keys
Examples:
hae experiment run 3 # Run by number (from list)
hae experiment run carpet # Fuzzy match
hae experiment run synthetic-carpet-train # By exact ID
hae experiment run definitions/experiments/synthetic-carpet-train.yaml # By path
"""
from pathlib import Path
import yaml
from .commands.experiment import run_experiment_sweep, run_experiment_orchestration
from hae.db import ExperimentDB
# Check if it's a path that exists
input_path = Path(experiment_id_or_path)
if input_path.exists() and input_path.is_file():
yaml_path = input_path
else:
# Try to resolve using database
db = ExperimentDB()
# Auto-rebuild if database is empty
def_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_definitions").fetchone()[0]
if def_count == 0:
console.print("[dim]Rebuilding experiment database...[/dim]")
db.index_definitions()
# Try numbered selection (e.g., "3")
if experiment_id_or_path.isdigit():
num = int(experiment_id_or_path)
all_experiments = db.query_definitions()
if 1 <= num <= len(all_experiments):
exp = all_experiments[num - 1]
console.print(f"[dim]Selected #{num}: {exp.id}[/dim]")
yaml_path = Path(exp.yaml_path)
else:
console.print(f"[red]❌ Number out of range: {num} (valid: 1-{len(all_experiments)})[/red]")
console.print(f"[dim]Run 'hae experiment list' to see numbered experiments[/dim]")
db.close()
sys.exit(1)
else:
# Try exact ID match first
definitions_dir = Path("definitions/experiments")
yaml_path = definitions_dir / f"{experiment_id_or_path}.yaml"
if not yaml_path.exists():
# Try fuzzy matching
matches = db.query_definitions(search=experiment_id_or_path)
if len(matches) == 0:
console.print(f"[red]❌ No experiments found matching: {experiment_id_or_path}[/red]")
console.print(f"\n[yellow]Tried:[/yellow]")
console.print(f" - Exact ID: {experiment_id_or_path}")
console.print(f" - Fuzzy search: {experiment_id_or_path}")
console.print(f"\n[dim]Run 'hae experiment list' to see available experiments[/dim]")
db.close()
sys.exit(1)
elif len(matches) == 1:
exp = matches[0]
console.print(f"[dim]Matched: {exp.id}[/dim]")
yaml_path = Path(exp.yaml_path)
else:
# Multiple matches - show user
console.print(f"[yellow]Multiple experiments match '{experiment_id_or_path}':[/yellow]\n")
from rich.table import Table
table = Table(show_header=True, box=None)
table.add_column("#", style="dim", width=3)
table.add_column("ID", style="cyan")
table.add_column("Tags", style="yellow")
for idx, exp in enumerate(matches, 1):
tags_str = ', '.join(exp.tags[:2]) if exp.tags else "-"
if exp.tags and len(exp.tags) > 2:
tags_str += f" +{len(exp.tags) - 2}"
table.add_row(str(idx), exp.id, tags_str)
console.print(table)
console.print(f"\n[dim]Please specify a more specific ID or use the number[/dim]")
db.close()
sys.exit(1)
db.close()
# Detect format: sweep+experiments (Format 3) > experiments (Format 1) > sweep+cmd (Format 2)
try:
with open(yaml_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
except Exception as e:
console.print(f"[red]Error loading YAML: {str(e)}[/red]")
sys.exit(1)
from .commands.pipeline import parse_vars_overrides
var_overrides = {}
if variables:
var_overrides = parse_vars_overrides(list(variables))
if 'sweep' in config and 'experiments' in config:
# Format 3: Sweep-of-Orchestrations (sweep at top level, runs entire workflow per trial)
console.print("[cyan]Detected sweep-of-orchestrations format (sweep + experiments)[/cyan]")
from .commands.experiment import run_experiment_sweep_orchestration
success = run_experiment_sweep_orchestration(yaml_path, dry_run=dry_run, var_overrides=var_overrides)
elif 'experiments' in config:
# Format 1: Orchestration (multi-experiment workflow)
console.print("[cyan]Detected orchestration experiment format[/cyan]")
success = run_experiment_orchestration(yaml_path, dry_run=dry_run, var_overrides=var_overrides)
elif 'sweep' in config and 'cmd' in config:
# Format 2: Sweep (parameter sweep with single command)
console.print("[cyan]Detected sweep experiment format[/cyan]")
success = run_experiment_sweep(yaml_path, dry_run=dry_run, var_overrides=var_overrides)
else:
console.print("[red]Error: Unrecognized experiment format[/red]")
console.print("[yellow]Expected one of:[/yellow]")
console.print(" 1. 'experiments' key (orchestration format)")
console.print(" 2. 'sweep' + 'cmd' keys (sweep format)")
console.print(" 3. 'sweep' + 'experiments' keys (sweep-of-orchestrations format)")
sys.exit(1)
sys.exit(0 if success else 1)
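# Skeleton YAML for the three formats detected above. Only the top-level keys
# ('experiments', 'sweep', 'cmd') are what the detection checks; the nested
# fields are illustrative assumptions:
#
#   # Format 1: orchestration
#   experiments:
#     - id: train
#     - id: eval
#
#   # Format 2: sweep
#   sweep:
#     lr: [0.001, 0.0001]
#   cmd: "hae train --lr {lr}"
#
#   # Format 3: sweep-of-orchestrations (whole workflow per trial)
#   sweep:
#     model_size: [UNet_S, UNet_M]
#   experiments:
#     - id: train
#     - id: eval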
def _run_experiment_from_yaml(yaml_path: Path, dry_run: bool = False) -> bool:
"""Internal function to run an experiment from a YAML file.
Args:
yaml_path: Path to experiment YAML file
dry_run: If True, show what would be executed without running
Returns:
True if experiment succeeded, False otherwise
"""
import yaml
from .commands.experiment import run_experiment_sweep, run_experiment_orchestration, run_experiment_sweep_orchestration
# Load config to detect format
try:
with open(yaml_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
except Exception as e:
console.print(f"[red]Error loading YAML: {str(e)}[/red]")
return False
# Detect format: sweep+experiments (Format 3) > experiments (Format 1) > sweep+cmd (Format 2)
if 'sweep' in config and 'experiments' in config:
# Format 3: Sweep-of-Orchestrations
console.print("[cyan]Detected sweep-of-orchestrations format (sweep + experiments)[/cyan]")
return run_experiment_sweep_orchestration(yaml_path, dry_run=dry_run)
elif 'experiments' in config:
# Format 1: Orchestration
console.print("[cyan]Detected orchestration experiment format[/cyan]")
return run_experiment_orchestration(yaml_path, dry_run=dry_run)
elif 'sweep' in config and 'cmd' in config:
# Format 2: Sweep
console.print("[cyan]Detected sweep experiment format[/cyan]")
# Forward dry_run rather than warning and then executing anyway;
# run_experiment_sweep accepts it (see experiment_run above)
return run_experiment_sweep(yaml_path, dry_run=dry_run)
else:
console.print("[red]Error: Unrecognized experiment format[/red]")
console.print("[yellow]Expected one of:[/yellow]")
console.print(" 1. 'experiments' key (orchestration format)")
console.print(" 2. 'sweep' + 'cmd' keys (sweep format)")
console.print(" 3. 'sweep' + 'experiments' keys (sweep-of-orchestrations format)")
return False
@experiment.command('rerun')
@click.argument('execution_path')
@click.option('--execution', help='Select specific execution by index (1=newest, 2=2nd newest) or execution ID')
@click.option('--dry-run/--no-dry-run', default=False, help='Show what would be executed without running')
def experiment_rerun(execution_path, execution, dry_run):
"""Rerun an experiment from a previous execution record.
You can rerun an experiment using:
1. Execution directory path
2. Experiment ID (finds most recent execution)
3. Snapshot YAML file path
Use --execution to select a specific execution when using experiment ID:
- By index: --execution 2 (2nd most recent)
- By execution ID: --execution 232056_train-and-compare-evaluators or --execution 232056
Examples:
# Rerun from execution directory
hae experiment rerun experiment-history/2025/01/22/232056_train-and-compare-evaluators
# Rerun from snapshot YAML
hae experiment rerun experiment-history/2025/01/22/232056_train-and-compare-evaluators/attachments/20250122232056_train-and-compare-evaluators.yaml
# Rerun most recent execution of an experiment (by ID)
hae experiment rerun train-and-compare-evaluators
# Rerun 2nd most recent execution
hae experiment rerun train-and-compare-evaluators --execution 2
# Rerun specific execution by timestamp
hae experiment rerun train-and-compare-evaluators --execution 232056
"""
exec_path = Path(execution_path)
# Case 1: It's a snapshot YAML file
if exec_path.exists() and exec_path.is_file() and exec_path.suffix in ['.yaml', '.yml']:
snapshot_yaml = exec_path
console.print(f"[cyan]Rerunning from snapshot: {snapshot_yaml}[/cyan]\n")
success = _run_experiment_from_yaml(snapshot_yaml, dry_run=dry_run)
sys.exit(0 if success else 1)
# Case 2: It's an execution directory
elif exec_path.exists() and exec_path.is_dir():
# Look for snapshot YAML in attachments/
attachments_dir = exec_path / "attachments"
if attachments_dir.exists():
# Find the experiment snapshot (usually named with timestamp_experiment-name.yaml)
snapshot_files = list(attachments_dir.glob("*_*.yaml")) + list(attachments_dir.glob("*_*.yml"))
if snapshot_files:
# Use the most recent snapshot
snapshot_yaml = max(snapshot_files, key=lambda p: p.stat().st_mtime)
console.print(f"[cyan]Rerunning from execution record: {exec_path.name}[/cyan]")
console.print(f"[dim]Using snapshot: {snapshot_yaml.name}[/dim]\n")
success = _run_experiment_from_yaml(snapshot_yaml, dry_run=dry_run)
sys.exit(0 if success else 1)
else:
console.print(f"[red]No snapshot YAML found in {attachments_dir}[/red]")
sys.exit(1)
else:
console.print(f"[red]No attachments directory found in {exec_path}[/red]")
sys.exit(1)
# Case 3: It's an experiment ID - find execution(s)
else:
# Try to find execution(s) with this ID
executions_dir = Path("experiment-history")
if not executions_dir.exists():
console.print(f"[red]No history directory found[/red]")
sys.exit(1)
# Search for execution directories matching this ID
matching_execs = []
exp_id = str(execution_path) # Use the original string
for exec_dir in executions_dir.rglob("*"):
if exec_dir.is_dir() and "_" in exec_dir.name:
# Extract experiment name from directory (after first _)
parts = exec_dir.name.split("_", 1)
if len(parts) == 2:
exec_name = parts[1]
# Normalize names for comparison
exec_name_norm = exec_name.replace("-", "_").replace(" ", "_")
exp_id_norm = exp_id.replace("-", "_").replace(" ", "_")
if exec_name_norm == exp_id_norm or exec_name == exp_id:
matching_execs.append(exec_dir)
if not matching_execs:
console.print(f"[red]No execution found for experiment: {exp_id}[/red]")
console.print(f"[yellow]Try: hae experiment list (to see available experiments)[/yellow]")
sys.exit(1)
# Sort by modification time (most recent first)
matching_execs.sort(key=lambda p: p.stat().st_mtime, reverse=True)
# Select specific execution if --execution provided
if execution:
# Try to parse as integer index (1-based)
try:
exec_index = int(execution) - 1 # Convert to 0-based
# Check if it's a valid index range
if 0 <= exec_index < len(matching_execs):
# Valid index - use it
selected_exec = matching_execs[exec_index]
console.print(f"[cyan]Found {len(matching_execs)} execution(s) for '{exp_id}'[/cyan]")
console.print(f"[cyan]Rerunning execution #{execution}: {selected_exec.name}[/cyan]\n")
else:
# Integer parsed but out of range - treat as pattern instead
raise ValueError("Index out of range, try pattern matching")
except ValueError:
# Not a valid integer index, treat as execution ID pattern
# Match execution directories containing this pattern
filtered_execs = []
for exec_dir in matching_execs:
# Check if execution pattern is in the directory name
if execution in exec_dir.name or execution.replace("-", "_") in exec_dir.name:
filtered_execs.append(exec_dir)
if not filtered_execs:
console.print(f"[red]No execution matching '{execution}' found for experiment '{exp_id}'[/red]")
console.print(f"[yellow]Available executions:[/yellow]")
for i, e in enumerate(matching_execs[:5], 1):
console.print(f" {i}. {e.name}")
if len(matching_execs) > 5:
console.print(f" ... and {len(matching_execs) - 5} more")
sys.exit(1)
selected_exec = filtered_execs[0] # Use most recent matching
console.print(f"[cyan]Found {len(filtered_execs)} execution(s) matching '{execution}'[/cyan]")
console.print(f"[cyan]Rerunning: {selected_exec.name}[/cyan]\n")
else:
# No --execution flag: use most recent
selected_exec = matching_execs[0]
console.print(f"[cyan]Found {len(matching_execs)} execution(s) for '{exp_id}'[/cyan]")
console.print(f"[cyan]Rerunning most recent: {selected_exec.name}[/cyan]\n")
most_recent = selected_exec # Rename for compatibility with code below
# Process the execution directory
attachments_dir = most_recent / "attachments"
if attachments_dir.exists():
# Find the experiment snapshot (prefer experiment YAML, fallback to any YAML)
snapshot_files = list(attachments_dir.glob("*_*.yaml")) + list(attachments_dir.glob("*_*.yml"))
if snapshot_files:
# Prefer experiment snapshot (usually named with experiment name)
experiment_snapshots = [f for f in snapshot_files if 'experiment' in f.stem.lower() or exp_id.replace('-', '_') in f.stem.lower()]
if experiment_snapshots:
snapshot_yaml = max(experiment_snapshots, key=lambda p: p.stat().st_mtime)
else:
# Fallback to most recent snapshot
snapshot_yaml = max(snapshot_files, key=lambda p: p.stat().st_mtime)
console.print(f"[dim]Using snapshot: {snapshot_yaml.name}[/dim]\n")
success = _run_experiment_from_yaml(snapshot_yaml, dry_run=dry_run)
sys.exit(0 if success else 1)
else:
console.print(f"[red]No snapshot YAML found in {attachments_dir}[/red]")
sys.exit(1)
else:
console.print(f"[red]No attachments directory found in {most_recent}[/red]")
sys.exit(1)
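# Execution-record layout assumed by the rerun resolution above, pieced
# together from the docstring examples and the globbing logic (dates, times,
# and names are illustrative):
#
#   experiment-history/
#     2025/01/22/
#       232056_train-and-compare-evaluators/
#         experiment.org                 # status and metadata properties
#         attachments/
#           20250122232056_train-and-compare-evaluators.yaml   # snapshot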
@experiment.command('history')
@click.argument('experiment_id', required=False)
@click.option('--limit', type=int, default=10, help='Maximum number of executions to show')
@click.option('--status', help='Filter by status (completed, failed, running)')
@click.option('--rebuild', is_flag=True, help='Rebuild database before querying')
@click.option('--no-sync', is_flag=True, help='Skip auto-sync (use cached database)')
def experiment_history(experiment_id, limit, status, rebuild, no_sync):
"""List execution history for experiments.
Shows execution history with timestamps, status, and paths.
Database is auto-synced incrementally unless --no-sync is specified.
Examples:
hae experiment history # Auto-syncs, show all recent
hae experiment history synthetic-carpet-train # Show specific experiment
hae experiment history --status completed --limit 20 # Show recent completed runs
hae experiment history --rebuild # Full rebuild (slow)
hae experiment history --no-sync # Skip sync (use cached data)
"""
from hae.db import ExperimentDB
from hae.migration import sync_database_incremental
import re
from rich.table import Table
# Initialize database
db = ExperimentDB()
# Sync strategy: rebuild, auto-sync, or skip
if rebuild:
# Full rebuild - clear and re-index everything
console.print("[dim]Rebuilding experiment database...[/dim]")
db.index_definitions()
db.index_executions()
exec_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_executions").fetchone()[0]
console.print(f"[dim]Indexed {exec_count} executions[/dim]\n")
elif not no_sync:
# Smart auto-sync: only process changed/new files (fast!)
new, updated, unchanged = sync_database_incremental(db)
# Also sync executions (always incremental)
db.index_executions()
if new + updated > 0:
console.print(f"[dim]↻ Synced: {new} new, {updated} updated definitions[/dim]")
# else: no_sync is True, skip sync entirely
# Query executions with filters
executions = db.query_executions(
definition_id=experiment_id,
status=status,
limit=limit
)
db.close()
if not executions:
if experiment_id:
console.print(f"[yellow]No executions found for experiment: {experiment_id}[/yellow]")
else:
console.print("[yellow]No executions found[/yellow]")
return
# Display table
title = f"Executions for '{experiment_id}'" if experiment_id else "Recent Executions"
table = Table(title=title)
table.add_column("#", style="dim", width=4)
table.add_column("Execution ID", style="cyan", no_wrap=True, max_width=28)
table.add_column("Status", style="yellow", width=10)
table.add_column("Timestamp", style="dim", width=16)
table.add_column("Path", style="dim", max_width=35, overflow="ellipsis")
for idx, exec_rec in enumerate(executions, 1):  # exec_rec avoids shadowing the builtin exec()
# Status with icon
status_icon = "✅" if exec_rec.status == "completed" else "❌" if exec_rec.status == "failed" else "⏳"
status_str = f"{status_icon} {exec_rec.status}"
# Format timestamp (ISO format: 2025-11-15T23:21:12Z)
match = re.search(r'(\d{4}-\d{2}-\d{2})T(\d{2}:\d{2}):\d{2}', exec_rec.timestamp)
if match:
timestamp_str = f"{match.group(1)} {match.group(2)}"
else:
timestamp_str = exec_rec.timestamp[:16]
# Format execution ID (truncate if needed)
exec_id = exec_rec.id
if len(exec_id) > 26:
exec_id = exec_id[:23] + "..."
# Format path (relative)
from pathlib import Path
try:
rel_path = str(Path(exec_rec.execution_path).relative_to(Path.cwd()))
except ValueError:
rel_path = exec_rec.execution_path
if len(rel_path) > 33:
rel_path = "..." + rel_path[-30:]
table.add_row(
str(idx),
exec_id,
status_str,
timestamp_str,
rel_path
)
console.print(table)
# Summary
summary_parts = [f"Total: {len(executions)} execution(s)"]
if experiment_id:
summary_parts.append(f"experiment: {experiment_id}")
if status:
summary_parts.append(f"status: {status}")
console.print(f"\n[dim]{' | '.join(summary_parts)}[/dim]")
console.print(f"[dim]Rerun: hae experiment rerun <execution_id> or hae experiment rerun <path>[/dim]")
@experiment.command('show')
@click.argument('experiment_id')
@click.option('--execution', help='Select specific execution by index (1=newest) or execution ID')
def experiment_show(experiment_id, execution):
"""Show detailed information about a specific execution.
Examples:
hae experiment show train-and-compare-evaluators # Show newest execution
hae experiment show train-and-compare-evaluators --execution 2 # Show 2nd newest
hae experiment show train-and-compare-evaluators --execution 232056
"""
import re
executions_dir = Path("experiment-history")
if not executions_dir.exists():
console.print("[yellow]No history directory found[/yellow]")
return
# Find all executions matching this experiment ID
matching_execs = []
exp_id_norm = experiment_id.replace("-", "_").replace(" ", "_")
for exec_dir in executions_dir.rglob("*"):
if exec_dir.is_dir() and "_" in exec_dir.name:
parts = exec_dir.name.split("_", 1)
if len(parts) == 2:
exec_name = parts[1]
exec_name_norm = exec_name.replace("-", "_").replace(" ", "_")
if exec_name_norm == exp_id_norm or exec_name == experiment_id:
matching_execs.append(exec_dir)
if not matching_execs:
console.print(f"[yellow]No executions found for experiment: {experiment_id}[/yellow]")
return
# Sort by modification time (newest first)
matching_execs.sort(key=lambda p: p.stat().st_mtime, reverse=True)
# Select specific execution if --execution provided, otherwise use newest
if execution:
# Try to parse as integer index (1-based)
try:
exec_index = int(execution) - 1 # Convert to 0-based
if 0 <= exec_index < len(matching_execs):
selected_exec = matching_execs[exec_index]
else:
raise ValueError("Index out of range, try pattern matching")
except ValueError:
# Not a valid integer index, treat as execution ID pattern
filtered_execs = []
for exec_dir in matching_execs:
if execution in exec_dir.name or execution.replace("-", "_") in exec_dir.name:
filtered_execs.append(exec_dir)
if not filtered_execs:
console.print(f"[red]No execution found matching: {execution}[/red]")
console.print(f"[dim]Found {len(matching_execs)} execution(s) for '{experiment_id}'[/dim]")
console.print(f"[dim]Try: hae experiment history {experiment_id}[/dim]")
sys.exit(1)
if len(filtered_execs) > 1:
console.print(f"[yellow]Multiple executions match '{execution}':[/yellow]")
for i, e in enumerate(filtered_execs, 1):
console.print(f" {i}. {e.name}")
console.print(f"[dim]Use --execution <number> to select a specific one[/dim]")
sys.exit(1)
selected_exec = filtered_execs[0]
else:
# No execution specified, use newest
selected_exec = matching_execs[0]
# Read experiment.org file
org_file = selected_exec / "experiment.org"
if not org_file.exists():
console.print(f"[red]experiment.org not found in {selected_exec}[/red]")
return
with open(org_file, encoding='utf-8') as f:
org_content = f.read()
# Parse metadata from org file
def extract_property(name):
match = re.search(rf':{name}: (.+)', org_content)
return match.group(1) if match else "N/A"
# Extract basic info
exp_name = extract_property("EXPERIMENT_NAME")
status = extract_property("STATUS")
started = extract_property("STARTED")
completed = extract_property("COMPLETED")
duration = extract_property("DURATION")
# Extract system info
git_commit = extract_property("GIT_COMMIT")
git_branch = extract_property("GIT_BRANCH")
git_dirty = extract_property("GIT_DIRTY")
python_version = extract_property("PYTHON_VERSION")
cuda_version = extract_property("CUDA_VERSION")
pytorch_version = extract_property("PYTORCH_VERSION")
# Display information
console.print(f"\n[bold cyan]Execution: {selected_exec.name}[/bold cyan]")
console.print(f"[bold]Experiment:[/bold] {exp_name}")
console.print(f"[bold]Status:[/bold] {status}")
console.print(f"[bold]Started:[/bold] {started}")
console.print(f"[bold]Completed:[/bold] {completed}")
console.print(f"[bold]Duration:[/bold] {duration}")
console.print(f"\n[bold]System Information:[/bold]")
console.print(f" Git Commit: {git_commit}")
console.print(f" Git Branch: {git_branch}")
console.print(f" Git Dirty: {git_dirty}")
console.print(f" Python: {python_version}")
console.print(f" PyTorch: {pytorch_version}")
console.print(f" CUDA: {cuda_version}")
console.print(f"\n[bold]Path:[/bold] {selected_exec}")
console.print(f"[dim]View full details: cat {org_file}[/dim]")
@experiment.command('status')
@click.argument('experiment_id')
def experiment_status(experiment_id):
"""Show detailed status for an experiment.
Shows all execution history, including running, completed, and failed experiments.
Example:
hae experiment status train-and-compare-evaluators
"""
from .commands.experiment_zettelkasten import is_experiment_running
import re
from rich.table import Table
# Find executions for this experiment
executions_dir = Path("experiment-history")
if not executions_dir.exists():
console.print(f"[yellow]No experiment history found[/yellow]")
sys.exit(1)
# Find matching executions
matching_execs = []
for exec_dir in executions_dir.rglob("*"):
if exec_dir.is_dir() and "_" in exec_dir.name:
parts = exec_dir.name.split("_", 1)
if len(parts) == 2:
exec_name = parts[1]
exec_name_norm = exec_name.replace("-", "_").replace(" ", "_")
exp_id_norm = experiment_id.replace("-", "_").replace(" ", "_")
if exec_name_norm == exp_id_norm or exec_name == experiment_id:
matching_execs.append(exec_dir)
if not matching_execs:
console.print(f"[yellow]No executions found for experiment: {experiment_id}[/yellow]")
sys.exit(1)
# Sort by modification time (most recent first)
matching_execs.sort(key=lambda p: p.stat().st_mtime, reverse=True)
# Count statuses
running_count = 0
succeeded_count = 0
failed_count = 0
unknown_count = 0
# Collect execution details
exec_details = []
for exec_dir in matching_execs:
try:
# Check if running
if is_experiment_running(exec_dir):
status = "running"
running_count += 1
started = "N/A"
completed = "N/A"
else:
# Read status from experiment.org
org_file = exec_dir / "experiment.org"
if org_file.exists():
with open(org_file, encoding='utf-8') as f:
content = f.read()
status_match = re.search(r':STATUS: (\w+)', content)
if status_match:
status = status_match.group(1)
if status == "completed":
succeeded_count += 1
elif status == "failed":
failed_count += 1
else:
unknown_count += 1
else:
status = "unknown"
unknown_count += 1
# Extract timing info
started_match = re.search(r':STARTED: \[([^\]]+)\]', content)
completed_match = re.search(r':COMPLETED: \[([^\]]+)\]', content)
started = started_match.group(1) if started_match else "N/A"
completed = completed_match.group(1) if completed_match else "N/A"
else:
status = "unknown"
unknown_count += 1
started = "N/A"
completed = "N/A"
# Get relative path
try:
rel_path = str(exec_dir.relative_to(Path.cwd()))
except ValueError:
rel_path = str(exec_dir)
exec_details.append({
'path': rel_path,
'status': status,
'started': started,
'completed': completed
})
except Exception as e:
console.print(f"[yellow]Warning: Error reading {exec_dir}: {e}[/yellow]")
# Display summary
total = len(matching_execs)
console.print(f"\n[bold cyan]Experiment: {experiment_id}[/bold cyan]")
console.print(f"\n[bold]Summary:[/bold]")
console.print(f" Total executions: {total}")
console.print(f" Running: [yellow]{running_count}[/yellow]")
console.print(f" Succeeded: [green]{succeeded_count}[/green]")
console.print(f" Failed: [red]{failed_count}[/red]")
if unknown_count > 0:
console.print(f" Unknown: [dim]{unknown_count}[/dim]")
# Display table of recent executions
console.print(f"\n[bold]Recent Executions:[/bold]")
table = Table(show_header=True, header_style="bold", box=None, padding=(0, 1))
table.add_column("Status", style="yellow", width=12)
table.add_column("Started", style="cyan", width=22)
table.add_column("Completed", style="green", width=22)
table.add_column("Path", style="dim")
# Show up to 10 most recent
for exec_info in exec_details[:10]:
status = exec_info['status']
# Color code status
if status == "running":
status_str = "[yellow]⚙ running[/yellow]"
elif status == "completed":
status_str = "[green]✓ completed[/green]"
elif status == "failed":
status_str = "[red]✗ failed[/red]"
else:
status_str = "[dim]? unknown[/dim]"
table.add_row(
status_str,
exec_info['started'],
exec_info['completed'],
exec_info['path']
)
console.print(table)
if len(exec_details) > 10:
console.print(f"\n[dim]... and {len(exec_details) - 10} more executions[/dim]")
console.print(f"\n[dim]View execution: cat <path>/experiment.org[/dim]")
@main.group()
def dataset():
"""Dataset management commands.
List and inspect dataset statistics, splits, and configurations.
"""
pass
@dataset.command('list')
@click.option('--dataset', type=str, help='Filter by dataset type (mvtec, visa, etc.)')
@click.option('--object-category', type=str, help='Filter by object category')
@click.option('--data-dir', type=str, help='Data directory path (default: ./datasets/raw/MVTec/mvtec_ad)')
@click.option('--csv-split-file', type=str, help='CSV split file path')
def datasets_list(dataset, object_category, data_dir, csv_split_file):
"""List available datasets with statistics and split information.
Shows dataset statistics including:
- Dataset type and object categories
- Split counts (train/val/test)
- CSV split file information
- Data directory paths
Examples:
hae dataset list # List all datasets
hae dataset list --dataset mvtec # List MVTec datasets only
hae dataset list --object-category bottle # Show bottle category stats
hae dataset list --data-dir ./my-data # Use custom data directory
"""
from pathlib import Path
import sys
# Import dataset utilities
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src"))
try:
from dioodmi.data.dataset_configs import BASE_DATASET_REGISTRY, get_base_dataset_config
from dioodmi.data.MVTECDataLoader import MVTECDataset
from dioodmi.data.anomaly_core import DatasetConfig, ProcessingParams
except ImportError as e:
console.print(f"[red]Error importing dataset modules: {e}[/red]")
console.print("[yellow]Make sure dioodmi-py is properly installed[/yellow]")
sys.exit(1)
# Determine which datasets to show
if dataset:
if dataset not in BASE_DATASET_REGISTRY:
console.print(f"[red]Unknown dataset: {dataset}[/red]")
console.print(f"[yellow]Available datasets: {', '.join(BASE_DATASET_REGISTRY.keys())}[/yellow]")
sys.exit(1)
datasets_to_check = [dataset]
else:
datasets_to_check = list(BASE_DATASET_REGISTRY.keys())
# Default data directory
if not data_dir:
data_dir = './datasets/raw/MVTec/mvtec_ad'
all_dataset_stats = []
for dataset_name in sorted(datasets_to_check):
base_config = BASE_DATASET_REGISTRY[dataset_name]
# Determine object categories to check
if object_category:
if object_category not in base_config.object_classes:
continue # Skip this dataset if category not found
categories_to_check = [object_category]
else:
categories_to_check = list(base_config.object_classes.keys())
for category in categories_to_check:
# Try to load dataset for each split to get counts
stats = {
'dataset': dataset_name,
'category': category,
'train_count': None,
'val_count': None,
'test_count': None,
'data_dir': data_dir,
'csv_split_file': csv_split_file or base_config.csv_split_file,
'error': None
}
# Try to get counts for each split
for split in ['train', 'val', 'test']:
try:
# Create dataset config
config = DatasetConfig(
name=base_config.name,
object_classes=base_config.object_classes,
csv_columns=base_config.csv_columns,
default_image_size=base_config.default_image_size,
default_crop_size=base_config.default_crop_size,
csv_split_file=csv_split_file if csv_split_file else base_config.csv_split_file,
strategies=base_config.strategies,
temporal_strategies=base_config.temporal_strategies,
mode=split,
object_class=category,
rootdir=data_dir,
anomaly_class='good' if split == 'train' else 'all'
)
# Create processing params
params = ProcessingParams(
image_size=base_config.default_image_size,
crop_size=base_config.default_crop_size
)
# Try to load dataset (without actually loading images)
from dioodmi.data.anomaly_core import parse_csv_data
data_df = parse_csv_data(config, params)
count = len(data_df)
stats[f'{split}_count'] = count
except Exception as e:
# Dataset might not exist or split might not be available
stats['error'] = str(e)
continue
all_dataset_stats.append(stats)
if not all_dataset_stats:
console.print("[yellow]No datasets found matching criteria[/yellow]")
return
# Display table
table = Table(title="Dataset Statistics", show_header=True, header_style="bold", box=None, padding=(0, 1))
table.add_column("Dataset", style="cyan", no_wrap=True, width=15)
table.add_column("Category", style="green", no_wrap=True, width=15)
table.add_column("Train", style="yellow", justify="right", width=8)
table.add_column("Val", style="yellow", justify="right", width=8)
table.add_column("Test", style="yellow", justify="right", width=8)
table.add_column("CSV Split", style="dim", width=20)
for stats in all_dataset_stats:
train_str = str(stats['train_count']) if stats['train_count'] is not None else "-"
val_str = str(stats['val_count']) if stats['val_count'] is not None else "-"
test_str = str(stats['test_count']) if stats['test_count'] is not None else "-"
csv_file = stats['csv_split_file']
if len(csv_file) > 18:
csv_file = csv_file[:15] + "..."
table.add_row(
stats['dataset'],
stats['category'],
train_str,
val_str,
test_str,
csv_file
)
console.print(table)
# Show summary
total_datasets = len(set(s['dataset'] for s in all_dataset_stats))
total_categories = len(all_dataset_stats)
console.print(f"\n[dim]Total: {total_categories} category/dataset combinations ({total_datasets} unique datasets)[/dim]")
if data_dir:
console.print(f"[dim]Data directory: {data_dir}[/dim]")
@dataset.group()
def split():
"""Dataset split management commands.
Create and manage dataset split CSV files.
"""
pass
@split.command('create')
@click.option('--dataset', type=str, help='Dataset type (mvtec, visa, etc.)')
@click.option('--object-category', type=str, help='Object category to process')
@click.option('--data-dir', type=str, required=True, help='Data directory path')
@click.option('--output-csv', type=str, required=True,
help='Output CSV file path (default: definitions/splits/ if relative)')
@click.option('--strategy', type=click.Choice(['static', 'pool']), default='static',
help='Split strategy: static (preserve directory structure) or pool (all samples as pool)')
@click.option('--include-index', is_flag=True, help='Include index column in CSV')
@click.option('--no-masks', is_flag=True, help='Do not search for mask files')
def split_create(dataset, object_category, data_dir, output_csv, strategy, include_index, no_masks):
"""Create split CSV file from dataset directory structure.
Discovers files from directory structure and generates CSV split file.
Supports both static splits (preserve directory structure) and pool mode
(all samples marked as 'pool' for dynamic splitting).
Examples:
# Create pool CSV (all samples as 'pool')
hae dataset split create \\
--data-dir ./datasets/raw/MVTec/mvtec_ad \\
--object-category bottle \\
--output-csv splits/mvtec-bottle-pool.csv \\
--strategy pool
# Create static split CSV (preserve directory structure)
hae dataset split create \\
--data-dir ./datasets/raw/MVTec/mvtec_ad \\
--object-category bottle \\
--output-csv splits/mvtec-bottle-static.csv \\
--strategy static
"""
from pathlib import Path
import sys
import csv
# Import file discovery utilities
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src"))
try:
from dioodmi.data.file_discovery import (
discover_files_from_directory,
detect_structure
)
except ImportError as e:
console.print(f"[red]Error importing file discovery modules: {e}[/red]")
console.print("[yellow]Make sure dioodmi-py is properly installed[/yellow]")
sys.exit(1)
# Validate data directory
data_dir_path = Path(data_dir)
if not data_dir_path.exists():
console.print(f"[red]Data directory not found: {data_dir}[/red]")
sys.exit(1)
# Determine object categories to process
if object_category:
object_categories = [object_category]
else:
# Auto-detect all object categories
object_categories = []
for item in sorted(data_dir_path.iterdir()):
if item.is_dir() and not item.name.startswith('.'):
detected_structure = detect_structure(item)
if detected_structure != "unknown":
object_categories.append(item.name)
if not object_categories:
console.print(f"[red]No valid object categories found in {data_dir}[/red]")
console.print("[yellow]Please specify --object-category explicitly[/yellow]")
sys.exit(1)
# Process each object category
all_rows = []
current_idx = 0
console.print(f"[bold]Creating split CSV file[/bold]")
console.print(f" Data directory: {data_dir}")
console.print(f" Output CSV: {output_csv}")
console.print(f" Strategy: {strategy}")
console.print(f" Object categories: {', '.join(object_categories)}\n")
for obj_cat in object_categories:
try:
df = discover_files_from_directory(
data_dir=data_dir,
object_category=obj_cat,
structure=None, # Auto-detect
include_masks=not no_masks,
split_map=None,
pool_mode=(strategy == 'pool')
)
# Convert DataFrame to rows
for _, row in df.iterrows():
row_dict = row.to_dict()
if include_index:
row_dict[''] = current_idx
current_idx += 1
all_rows.append(row_dict)
# Show summary for this category
train_count = len(df[df['split'] == 'train']) if 'train' in df['split'].values else 0
test_count = len(df[df['split'] == 'test']) if 'test' in df['split'].values else 0
val_count = len(df[df['split'] == 'val']) if 'val' in df['split'].values else 0
pool_count = len(df[df['split'] == 'pool']) if 'pool' in df['split'].values else 0
console.print(f" [{obj_cat}] Found {len(df)} images")
if strategy == 'pool':
console.print(f" All {pool_count} marked as 'pool' for dynamic splitting")
else:
if train_count > 0:
console.print(f" Train: {train_count}")
if val_count > 0:
console.print(f" Val: {val_count}")
if test_count > 0:
console.print(f" Test: {test_count}")
except Exception as e:
console.print(f"[red]Error processing {obj_cat}: {e}[/red]")
continue
if not all_rows:
console.print("[red]No data found to write to CSV[/red]")
sys.exit(1)
# Write CSV file
output_path = Path(output_csv)
# If relative path and doesn't exist, default to definitions/splits/
if not output_path.is_absolute() and not output_path.parent.exists():
# Check if it's just a filename
if output_path.parent == Path('.'):
output_path = Path('definitions/splits') / output_path.name
output_path.parent.mkdir(parents=True, exist_ok=True)
# Determine fieldnames
fieldnames = []
if include_index:
fieldnames.append("")
fieldnames.extend(["object", "split", "label", "image", "mask", "category"])
with open(output_path, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(all_rows)
console.print(f"\n[green]✓ Created split CSV: {output_path}[/green]")
console.print(f" Total entries: {len(all_rows)}")
@split.command('add')
@click.option('--csv', 'csv_file', type=str, required=True, help='Existing CSV file to update')
@click.option('--from-dir', type=str, help='Discover files from directory')
@click.option('--from-csv', type=str, help='Add samples from another CSV file')
@click.option('--object-category', type=str, help='Only add samples for this object category')
@click.option('--split', type=click.Choice(['train', 'val', 'test', 'pool']), help='Override split assignment (for directory input)')
@click.option('--pool-mode', is_flag=True, help='Mark all new samples as pool (for directory input)')
@click.option('--duplicate-strategy', type=click.Choice(['skip', 'error']), default='error',
help='How to handle duplicates: skip (keep existing) or error (fail)')
@click.option('--output-csv', type=str, help='Write to different file (default: update in-place)')
@click.option('--backup', is_flag=True, help='Create backup before updating')
@click.option('--dry-run', is_flag=True, help='Show what would be added without modifying file')
@click.option('--structure', type=click.Choice(['ok_ng', 'category']), help='Force structure type (default: auto-detect)')
@click.option('--no-masks', is_flag=True, help='Do not search for mask files')
@click.option('--include-index', is_flag=True, help='Include index column in output')
def split_add(csv_file, from_dir, from_csv, object_category, split, pool_mode, duplicate_strategy,
output_csv, backup, dry_run, structure, no_masks, include_index):
"""Add new samples to an existing split CSV file.
Supports adding samples from directory structure or from another CSV file.
Only adds NEW samples - use 'update' command to modify existing entries.
Examples:
# Add new samples from directory
hae dataset split add \\
--csv splits/mvtec-bottle-pool.csv \\
--from-dir ./datasets/raw/MVTec/mvtec_ad/bottle \\
--object-category bottle \\
--split pool \\
--duplicate-strategy skip
# Merge from another CSV
hae dataset split add \\
--csv splits/mvtec-bottle-full.csv \\
--from-csv splits/mvtec-bottle-val-only.csv \\
--duplicate-strategy skip
"""
from pathlib import Path
import sys
# Import utilities
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src"))
try:
from dioodmi.data.split_utils import (
load_split_csv, save_split_csv, resolve_split_csv_path,
find_duplicates, create_backup
)
from dioodmi.data.file_discovery import discover_files_from_directory
except ImportError as e:
console.print(f"[red]Error importing modules: {e}[/red]")
sys.exit(1)
# Validate input source
if not from_dir and not from_csv:
console.print("[red]Error: Must specify either --from-dir or --from-csv[/red]")
sys.exit(1)
if from_dir and from_csv:
console.print("[red]Error: Cannot specify both --from-dir and --from-csv[/red]")
sys.exit(1)
# Load existing CSV
try:
existing_df = load_split_csv(csv_file)
console.print(f"[dim]Loaded existing CSV: {len(existing_df)} samples[/dim]")
except Exception as e:
console.print(f"[red]Error loading CSV: {e}[/red]")
sys.exit(1)
# Discover or load new samples
if from_dir:
# Discover from directory
if not object_category:
console.print("[red]Error: --object-category required when using --from-dir[/red]")
sys.exit(1)
try:
new_df = discover_files_from_directory(
data_dir=from_dir,
object_category=object_category,
structure=structure,
include_masks=not no_masks,
split_map=None,
pool_mode=pool_mode or (split == 'pool')
)
# Override split if specified
if split and not pool_mode:
new_df['split'] = split
console.print(f"[dim]Discovered {len(new_df)} new samples from directory[/dim]")
except Exception as e:
console.print(f"[red]Error discovering files: {e}[/red]")
sys.exit(1)
else: # from_csv
# Load from another CSV
try:
new_df = load_split_csv(from_csv)
console.print(f"[dim]Loaded {len(new_df)} samples from source CSV[/dim]")
# Override split if specified
if split:
new_df['split'] = split
# Filter by object category if specified
if object_category:
new_df = new_df[new_df['object'] == object_category].copy()
console.print(f"[dim]Filtered to {len(new_df)} samples for object '{object_category}'[/dim]")
except Exception as e:
console.print(f"[red]Error loading source CSV: {e}[/red]")
sys.exit(1)
if len(new_df) == 0:
console.print("[yellow]No new samples to add[/yellow]")
return
# Check for duplicates
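# A sample is keyed by its (object, image) pair, so identical relative image
# paths under different object categories do not collide.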
existing_keys = set(zip(existing_df['object'], existing_df['image']))
new_keys = set(zip(new_df['object'], new_df['image']))
duplicates = existing_keys & new_keys
if duplicates:
if duplicate_strategy == 'error':
console.print(f"[red]Error: Found {len(duplicates)} duplicate entries[/red]")
console.print("[yellow]Use --duplicate-strategy skip to ignore duplicates[/yellow]")
sys.exit(1)
else: # skip
console.print(f"[yellow]Found {len(duplicates)} duplicate entries, skipping[/yellow]")
# Filter out duplicates
new_df = new_df[~new_df.apply(lambda row: (row['object'], row['image']) in duplicates, axis=1)].copy()
console.print(f"[dim]Adding {len(new_df)} new samples (skipped {len(duplicates)} duplicates)[/dim]")
if len(new_df) == 0:
console.print("[yellow]No new samples to add after filtering duplicates[/yellow]")
return
# Combine DataFrames
combined_df = pd.concat([existing_df, new_df], ignore_index=True)
# Statistics
stats = {
'existing': len(existing_df),
'new': len(new_df),
'duplicates_skipped': len(duplicates),
'total': len(combined_df)
}
if dry_run:
console.print("\n[bold]Dry run - no changes made[/bold]")
console.print(f" Existing samples: {stats['existing']}")
console.print(f" New samples to add: {stats['new']}")
if stats['duplicates_skipped'] > 0:
console.print(f" Duplicates skipped: {stats['duplicates_skipped']}")
console.print(f" Total after add: {stats['total']}")
return
# Create backup if requested
if backup:
backup_path = create_backup(csv_file)
console.print(f"[dim]Created backup: {backup_path}[/dim]")
# Determine output path
output_path = Path(output_csv) if output_csv else resolve_split_csv_path(csv_file)
# Save
try:
save_split_csv(combined_df, str(output_path), include_index=include_index)
console.print(f"\n[green]✓ Added {stats['new']} samples to CSV[/green]")
console.print(f" Total samples: {stats['total']}")
if stats['duplicates_skipped'] > 0:
console.print(f" Duplicates skipped: {stats['duplicates_skipped']}")
console.print(f" Saved to: {output_path}")
except Exception as e:
console.print(f"[red]Error saving CSV: {e}[/red]")
sys.exit(1)
@split.command('list')
@click.option('--splits-dir', type=str, default='./definitions/splits/',
help='Directory to search for split CSV files (default: definitions/splits/)')
@click.option('--all', 'list_all', is_flag=True, help='List from both definitions/splits/ and fixtures/splits/')
@click.option('--dataset', type=str, help='Filter by dataset type (e.g., mvtec, sdp, visa)')
@click.option('--object-category', type=str, help='Filter by object category')
@click.option('--split-type', type=click.Choice(['train', 'val', 'test', 'pool']), help='Filter by split type')
@click.option('--pattern', type=str, help='Filter by filename pattern (e.g., *pool*.csv)')
@click.option('--format', type=click.Choice(['table', 'json']), default='table', help='Output format')
@click.option('--detailed', is_flag=True, help='Show detailed statistics')
@click.option('--validate', is_flag=True, help='Validate CSV files')
@click.option('--sort-by', type=click.Choice(['name', 'size', 'date', 'samples']), default='name', help='Sort results')
def split_list(splits_dir, list_all, dataset, object_category, split_type, pattern, format, detailed, validate, sort_by):
"""List all available split CSV files with statistics.
Examples:
# List all split files
hae dataset split list
# List MVTec splits only
hae dataset split list --dataset mvtec
# Show detailed statistics
hae dataset split list --detailed
# Validate all splits
hae dataset split list --validate
"""
from pathlib import Path
import sys
import glob
from datetime import datetime
# Import utilities
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src"))
try:
from dioodmi.data.split_utils import (
load_split_csv, compute_statistics, compute_per_object_stats,
validate_structure, resolve_split_csv_path
)
except ImportError as e:
console.print(f"[red]Error importing modules: {e}[/red]")
sys.exit(1)
# Determine directories to search
if list_all:
# Search both locations
search_dirs = [
Path('definitions/splits'),
Path('fixtures/splits')
]
else:
# Search specified directory
search_dirs = [Path(splits_dir)]
# Find all CSV files from all search directories
csv_files = []
for search_dir in search_dirs:
if not search_dir.exists():
if not list_all: # Only warn if not using --all
console.print(f"[yellow]Splits directory not found: {search_dir}[/yellow]")
continue
if pattern:
csv_files.extend(search_dir.glob(pattern))
else:
csv_files.extend(search_dir.glob("*.csv"))
# De-duplicate identical path entries (a safeguard in case the same directory is searched more than once)
csv_files = list(set(csv_files))
if not csv_files:
if list_all:
console.print(f"[yellow]No CSV files found in definitions/splits/ or fixtures/splits/[/yellow]")
else:
console.print(f"[yellow]No CSV files found in {splits_dir}[/yellow]")
return
# Analyze each file
results = []
for csv_file in sorted(csv_files):
try:
df = load_split_csv(str(csv_file))
stats = compute_statistics(df)
# Apply filters
if dataset and dataset.lower() not in csv_file.stem.lower():
continue
if object_category and object_category not in stats['objects']:
continue
if split_type and split_type not in stats['splits']:
continue
# Get file metadata
file_stat = csv_file.stat()
result = {
'filename': csv_file.name,
'path': str(csv_file),
'file_size': file_stat.st_size,
'modified': datetime.fromtimestamp(file_stat.st_mtime),
'total_samples': stats['total_samples'],
'objects': stats['objects'],
'splits': stats['splits'],
'labels': stats['labels'],
'categories': stats['categories'],
'mask_coverage': stats['mask_coverage'],
}
if detailed:
result['per_object'] = compute_per_object_stats(df)
if validate:
validation = validate_structure(df)
result['validation'] = validation
results.append(result)
except Exception as e:
console.print(f"[yellow]Error processing {csv_file.name}: {e}[/yellow]")
continue
if not results:
console.print("[yellow]No matching split files found[/yellow]")
return
# Sort results
if sort_by == 'name':
results.sort(key=lambda x: x['filename'])
elif sort_by == 'size':
results.sort(key=lambda x: x['file_size'], reverse=True)
elif sort_by == 'date':
results.sort(key=lambda x: x['modified'], reverse=True)
elif sort_by == 'samples':
results.sort(key=lambda x: x['total_samples'], reverse=True)
# Format output
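# With --format json, each record is the result dict assembled above; roughly
# this shape (values illustrative):
#   {"filename": "mvtec-bottle-pool.csv", "total_samples": 292,
#    "splits": {"pool": 292}, "labels": {"normal": 229, "anomaly": 63},
#    "objects": ["bottle"], "mask_coverage": 0.22, ...}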
if format == 'json':
import json
console.print(json.dumps(results, indent=2, default=str))
else: # table
from rich.table import Table
table = Table(title="Split CSV Files", show_header=True, header_style="bold magenta")
table.add_column("Filename", style="cyan")
table.add_column("Samples", justify="right")
table.add_column("Objects", justify="right")
table.add_column("Splits", style="green")
table.add_column("Labels", style="yellow")
if validate:
table.add_column("Validation", style="red")
for result in results:
# Format splits
splits_str = ", ".join(f"{k}:{v}" for k, v in sorted(result['splits'].items()) if v > 0)
# Format labels
labels_str = ", ".join(f"{k}:{v}" for k, v in sorted(result['labels'].items()))
# Format objects
objects_str = str(len(result['objects'])) if not detailed else ", ".join(result['objects'][:3])
if detailed and len(result['objects']) > 3:
objects_str += f" (+{len(result['objects'])-3})"
row = [
result['filename'],
str(result['total_samples']),
objects_str,
splits_str,
labels_str,
]
if validate:
val_status = "✓" if result.get('validation', {}).get('is_valid', True) else "✗"
row.append(val_status)
table.add_row(*row)
console.print(table)
if detailed:
console.print("\n[bold]Detailed Statistics:[/bold]")
for result in results:
console.print(f"\n[cyan]{result['filename']}[/cyan]")
if 'per_object' in result:
for obj, obj_stats in result['per_object'].items():
console.print(f" {obj}: {obj_stats['total']} samples")
console.print(f" Splits: {obj_stats['splits']}")
console.print(f" Labels: {obj_stats['labels']}")
@split.command('stats')
@click.option('--csv', 'csv_file', type=str, required=True, help='CSV file to analyze')
@click.option('--format', type=click.Choice(['table', 'json']), default='table', help='Output format')
@click.option('--per-object', is_flag=True, help='Show per-object breakdown')
@click.option('--per-category', is_flag=True, help='Show per-category breakdown')
@click.option('--per-split', is_flag=True, help='Show per-split breakdown')
def split_stats(csv_file, format, per_object, per_category, per_split):
"""Show detailed statistics for a single split CSV file.
Examples:
# Basic stats
hae dataset split stats --csv splits/mvtec-bottle-pool.csv
# Detailed per-object breakdown
hae dataset split stats --csv splits/mvtec-split.csv --per-object
"""
from pathlib import Path
import sys
# Import utilities
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src"))
try:
from dioodmi.data.split_utils import (
load_split_csv, compute_statistics, compute_per_object_stats,
compute_per_category_stats, compute_per_split_stats, resolve_split_csv_path
)
except ImportError as e:
console.print(f"[red]Error importing modules: {e}[/red]")
sys.exit(1)
# Load CSV
try:
df = load_split_csv(csv_file)
except Exception as e:
console.print(f"[red]Error loading CSV: {e}[/red]")
sys.exit(1)
# Compute statistics
stats = compute_statistics(df)
result = {
'filename': Path(csv_file).name,
'path': str(resolve_split_csv_path(csv_file)),
'total_samples': stats['total_samples'],
'objects': stats['objects'],
'splits': stats['splits'],
'labels': stats['labels'],
'categories': stats['categories'],
'mask_coverage': stats['mask_coverage'],
}
if per_object:
result['per_object'] = compute_per_object_stats(df)
if per_category:
result['per_category'] = compute_per_category_stats(df)
if per_split:
result['per_split'] = compute_per_split_stats(df)
# Format output
if format == 'json':
import json
console.print(json.dumps(result, indent=2, default=str))
else: # table
console.print(f"\n[bold]Statistics for: {result['filename']}[/bold]")
console.print(f" Total samples: {result['total_samples']}")
console.print(f" Objects: {', '.join(result['objects'])}")
console.print(f" Mask coverage: {result['mask_coverage']:.1%}")
console.print(f"\n[bold]Splits:[/bold]")
for split, count in sorted(result['splits'].items()):
console.print(f" {split}: {count}")
console.print(f"\n[bold]Labels:[/bold]")
for label, count in sorted(result['labels'].items()):
console.print(f" {label}: {count}")
if per_object:
console.print(f"\n[bold]Per-Object Breakdown:[/bold]")
for obj, obj_stats in result['per_object'].items():
console.print(f" {obj}: {obj_stats['total']} samples")
console.print(f" Splits: {obj_stats['splits']}")
console.print(f" Labels: {obj_stats['labels']}")
if per_category:
console.print(f"\n[bold]Per-Category Breakdown:[/bold]")
for cat, cat_stats in sorted(result['per_category'].items()):
console.print(f" {cat}: {cat_stats['total']} samples")
console.print(f" Splits: {cat_stats['splits']}")
if per_split:
console.print(f"\n[bold]Per-Split Breakdown:[/bold]")
for split_name, s_stats in sorted(result['per_split'].items()):
console.print(f"  {split_name}: {s_stats['total']} samples")
console.print(f"    Objects: {s_stats['objects']}")
@split.command('update')
@click.option('--csv', 'csv_file', type=str, required=True, help='CSV file to update')
@click.option('--filter-object', type=str, help='Filter by object category')
@click.option('--filter-split', type=str, help='Filter by split')
@click.option('--filter-label', type=str, help='Filter by label')
@click.option('--filter-category', type=str, help='Filter by category')
@click.option('--filter-image', type=str, help='Filter by image path')
@click.option('--set-split', type=click.Choice(['train', 'val', 'test', 'pool']), help='Set new split value')
@click.option('--set-label', type=click.Choice(['normal', 'anomaly']), help='Set new label value')
@click.option('--set-category', type=str, help='Set new category value')
@click.option('--output-csv', type=str, help='Write to different file (default: update in-place)')
@click.option('--backup', is_flag=True, help='Create backup before updating')
@click.option('--dry-run', is_flag=True, help='Show what would be updated without modifying file')
@click.option('--interactive', is_flag=True, help='Interactive mode: confirm each update')
def split_update(csv_file, filter_object, filter_split, filter_label, filter_category, filter_image,
set_split, set_label, set_category, output_csv, backup, dry_run, interactive):
"""Update existing entries in split CSV file.
Modifies EXISTING samples only - use 'add' command to add new samples.
Requires at least one filter and one set operation.
Examples:
# Move all pool samples to train
hae dataset split update \\
--csv splits/mvtec-bottle-pool.csv \\
--filter-split pool \\
--set-split train
# Update specific object category
hae dataset split update \\
--csv splits/mvtec-split.csv \\
--filter-object bottle \\
--filter-split pool \\
--set-split train
"""
from pathlib import Path
import sys
# Import utilities
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src"))
try:
from dioodmi.data.split_utils import (
load_split_csv, save_split_csv, resolve_split_csv_path,
apply_filters, create_backup
)
except ImportError as e:
console.print(f"[red]Error importing modules: {e}[/red]")
sys.exit(1)
# Validate filters and set operations
filters = {
'filter_object': filter_object,
'filter_split': filter_split,
'filter_label': filter_label,
'filter_category': filter_category,
'filter_image': filter_image,
}
filters = {k: v for k, v in filters.items() if v is not None}
set_ops = {
'set_split': set_split,
'set_label': set_label,
'set_category': set_category,
}
set_ops = {k: v for k, v in set_ops.items() if v is not None}
if not filters:
console.print("[red]Error: At least one filter is required[/red]")
sys.exit(1)
if not set_ops:
console.print("[red]Error: At least one set operation is required[/red]")
sys.exit(1)
# Load CSV
try:
df = load_split_csv(csv_file)
console.print(f"[dim]Loaded CSV: {len(df)} samples[/dim]")
except Exception as e:
console.print(f"[red]Error loading CSV: {e}[/red]")
sys.exit(1)
# Apply filters
filtered_df = apply_filters(df, **filters)
if len(filtered_df) == 0:
console.print("[yellow]No samples match the filter criteria[/yellow]")
return
console.print(f"[dim]Found {len(filtered_df)} samples matching filters[/dim]")
# Apply updates
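# apply_filters() returns the matching rows; its pandas Index (not a boolean
# mask) is what drives the .loc assignments below.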
updated_df = df.copy()
update_mask = apply_filters(updated_df, **filters).index
if 'set_split' in set_ops:
updated_df.loc[update_mask, 'split'] = set_ops['set_split']
if 'set_label' in set_ops:
updated_df.loc[update_mask, 'label'] = set_ops['set_label']
if 'set_category' in set_ops:
updated_df.loc[update_mask, 'category'] = set_ops['set_category']
# Show what will be updated
if dry_run or interactive:
console.print("\n[bold]Samples to be updated:[/bold]")
for idx in update_mask[:10]: # Show first 10
row = updated_df.loc[idx]
console.print(f" {row['object']}/{row['image']}")
if 'set_split' in set_ops:
console.print(f" split: {df.loc[idx, 'split']} → {set_ops['set_split']}")
if 'set_label' in set_ops:
console.print(f" label: {df.loc[idx, 'label']} → {set_ops['set_label']}")
if 'set_category' in set_ops:
console.print(f" category: {df.loc[idx, 'category']} → {set_ops['set_category']}")
if len(update_mask) > 10:
console.print(f" ... and {len(update_mask) - 10} more")
if dry_run:
console.print("\n[bold]Dry run - no changes made[/bold]")
return
if interactive:
if not click.confirm(f"\nUpdate {len(update_mask)} samples?"):
console.print("[yellow]Cancelled[/yellow]")
return
# Create backup if requested
if backup:
backup_path = create_backup(csv_file)
console.print(f"[dim]Created backup: {backup_path}[/dim]")
# Determine output path
output_path = Path(output_csv) if output_csv else resolve_split_csv_path(csv_file)
# Save
try:
save_split_csv(updated_df, str(output_path))
console.print(f"\n[green]✓ Updated {len(update_mask)} samples[/green]")
console.print(f" Saved to: {output_path}")
except Exception as e:
console.print(f"[red]Error saving CSV: {e}[/red]")
sys.exit(1)
@split.command('remove')
@click.option('--csv', 'csv_file', type=str, required=True, help='CSV file to update')
@click.option('--filter-object', type=str, help='Filter by object category')
@click.option('--filter-split', type=str, help='Filter by split')
@click.option('--filter-label', type=str, help='Filter by label')
@click.option('--filter-category', type=str, help='Filter by category')
@click.option('--filter-image', type=str, help='Filter by image path (exact match or pattern)')
@click.option('--from-csv', type=str, help='Remove samples matching entries in another CSV')
@click.option('--output-csv', type=str, help='Write to different file (default: update in-place)')
@click.option('--backup', is_flag=True, help='Create backup before removing')
@click.option('--dry-run', is_flag=True, help='Show what would be removed without modifying file')
@click.option('--interactive', is_flag=True, help='Interactive mode: confirm removal')
def split_remove(csv_file, filter_object, filter_split, filter_label, filter_category, filter_image,
from_csv, output_csv, backup, dry_run, interactive):
"""Remove samples from split CSV file.
Examples:
# Remove specific image
hae dataset split remove \\
--csv splits/mvtec-split.csv \\
--filter-image bottle/train/good/001.png
# Remove all samples from specific split
hae dataset split remove \\
--csv splits/mvtec-split.csv \\
--filter-split pool
"""
from pathlib import Path
import sys
# Import utilities
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src"))
try:
from dioodmi.data.split_utils import (
load_split_csv, save_split_csv, resolve_split_csv_path,
apply_filters, create_backup
)
except ImportError as e:
console.print(f"[red]Error importing modules: {e}[/red]")
sys.exit(1)
# Load CSV
try:
df = load_split_csv(csv_file)
console.print(f"[dim]Loaded CSV: {len(df)} samples[/dim]")
except Exception as e:
console.print(f"[red]Error loading CSV: {e}[/red]")
sys.exit(1)
# Determine what to remove
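# Two removal modes: match (object, image) keys against another CSV, or build
# a selection from the same filter options the 'update' command uses.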
if from_csv:
# Remove samples matching entries in another CSV
try:
remove_df = load_split_csv(from_csv)
remove_keys = set(zip(remove_df['object'], remove_df['image']))
remove_mask = df.apply(lambda row: (row['object'], row['image']) in remove_keys, axis=1)
except Exception as e:
console.print(f"[red]Error loading source CSV: {e}[/red]")
sys.exit(1)
else:
# Apply filters
filters = {
'filter_object': filter_object,
'filter_split': filter_split,
'filter_label': filter_label,
'filter_category': filter_category,
'filter_image': filter_image,
}
filters = {k: v for k, v in filters.items() if v is not None}
if not filters:
console.print("[red]Error: At least one filter or --from-csv is required[/red]")
sys.exit(1)
filtered_df = apply_filters(df, **filters)
remove_mask = df.index.isin(filtered_df.index)
if not remove_mask.any():
console.print("[yellow]No samples match the removal criteria[/yellow]")
return
num_to_remove = remove_mask.sum()
console.print(f"[dim]Found {num_to_remove} samples to remove[/dim]")
if num_to_remove == len(df):
console.print("[yellow]Warning: This will remove ALL samples from the CSV[/yellow]")
if not interactive and not dry_run:
if not click.confirm("Continue?"):
console.print("[yellow]Cancelled[/yellow]")
return
# Show what will be removed
if dry_run or interactive:
console.print("\n[bold]Samples to be removed:[/bold]")
to_remove = df[remove_mask]
for idx, row in to_remove.head(10).iterrows():
console.print(f" {row['object']}/{row['image']} ({row['split']}, {row['label']})")
if num_to_remove > 10:
console.print(f" ... and {num_to_remove - 10} more")
if dry_run:
console.print("\n[bold]Dry run - no changes made[/bold]")
return
if interactive:
if not click.confirm(f"\nRemove {num_to_remove} samples?"):
console.print("[yellow]Cancelled[/yellow]")
return
# Create backup if requested
if backup:
backup_path = create_backup(csv_file)
console.print(f"[dim]Created backup: {backup_path}[/dim]")
# Remove samples
updated_df = df[~remove_mask].copy()
# Determine output path
output_path = Path(output_csv) if output_csv else resolve_split_csv_path(csv_file)
# Save
try:
save_split_csv(updated_df, str(output_path))
console.print(f"\n[green]✓ Removed {num_to_remove} samples[/green]")
console.print(f" Remaining samples: {len(updated_df)}")
console.print(f" Saved to: {output_path}")
except Exception as e:
console.print(f"[red]Error saving CSV: {e}[/red]")
sys.exit(1)
@split.command('validate')
@click.option('--csv', 'csv_file', type=str, required=True, help='CSV file to validate')
@click.option('--check-duplicates', is_flag=True, help='Check for duplicate entries')
@click.option('--check-paths', is_flag=True, help='Check path formats')
@click.option('--check-structure', is_flag=True, help='Check CSV structure')
@click.option('--fix', is_flag=True, help='Auto-fix issues where possible')
@click.option('--output-csv', type=str, help='Write fixed CSV to different file')
@click.option('--output-report', type=str, help='Write validation report to file')
@click.option('--format', type=click.Choice(['table', 'json']), default='table', help='Output format')
def split_validate(csv_file, check_duplicates, check_paths, check_structure, fix,
output_csv, output_report, format):
"""Validate split CSV file structure.
Examples:
# Full validation
hae dataset split validate \\
--csv splits/mvtec-split.csv \\
--check-duplicates --check-paths --check-structure
# Auto-fix issues
hae dataset split validate \\
--csv splits/mvtec-split.csv \\
--fix \\
--output-csv splits/mvtec-split-fixed.csv
"""
from pathlib import Path
import sys
# Import utilities
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src"))
try:
from dioodmi.data.split_utils import (
load_split_csv, save_split_csv, validate_structure,
find_duplicates, remove_duplicates
)
except ImportError as e:
console.print(f"[red]Error importing modules: {e}[/red]")
sys.exit(1)
# Load CSV (this already validates schema)
try:
df = load_split_csv(csv_file)
except Exception as e:
console.print(f"[red]Error loading CSV: {e}[/red]")
sys.exit(1)
# Run validations
results = {
'is_valid': True,
'errors': [],
'warnings': [],
'fixes_applied': [],
}
# Structure validation (runs when explicitly requested, or by default when no other checks are selected)
if check_structure or not (check_duplicates or check_paths):
structure = validate_structure(df)
results['is_valid'] = structure['is_valid']
results['errors'].extend(structure.get('errors', []))
results['warnings'].extend(structure.get('warnings', []))
# Duplicate check
if check_duplicates:
duplicates = find_duplicates(df)
if len(duplicates) > 0:
results['is_valid'] = False
results['warnings'].append(f"Found {len(duplicates)} duplicate entries")
if fix:
df = remove_duplicates(df, strategy='keep-first')
results['fixes_applied'].append(f"Removed {len(duplicates)} duplicate entries")
# Path check
if check_paths:
invalid_paths = df['image'].isnull().sum() + (df['image'] == '').sum()
if invalid_paths > 0:
results['is_valid'] = False
results['errors'].append(f"Found {invalid_paths} invalid image paths")
if fix:
# Remove rows with invalid paths
df = df[df['image'].notna() & (df['image'] != '')].copy()
results['fixes_applied'].append(f"Removed {invalid_paths} rows with invalid paths")
# Format output
if format == 'json':
import json
console.print(json.dumps(results, indent=2, default=str))
else: # table
if results['is_valid']:
console.print("[green]✓ CSV file is valid[/green]")
else:
console.print("[red]✗ CSV file has issues[/red]")
if results['errors']:
console.print("\n[bold red]Errors:[/bold red]")
for error in results['errors']:
console.print(f" - {error}")
if results['warnings']:
console.print("\n[bold yellow]Warnings:[/bold yellow]")
for warning in results['warnings']:
console.print(f" - {warning}")
if results['fixes_applied']:
console.print("\n[bold green]Fixes Applied:[/bold green]")
for fix_applied in results['fixes_applied']:
console.print(f" - {fix_applied}")
# Save fixed CSV if requested
if fix and results['fixes_applied'] and output_csv:
try:
save_split_csv(df, output_csv)
console.print(f"\n[green]✓ Saved fixed CSV to: {output_csv}[/green]")
except Exception as e:
console.print(f"[red]Error saving fixed CSV: {e}[/red]")
elif fix and results['fixes_applied']:
console.print("[yellow]Fixes were applied in memory only - pass --output-csv to persist them[/yellow]")
# Save report if requested
if output_report:
import json
with open(output_report, 'w') as f:
json.dump(results, f, indent=2, default=str)
console.print(f"\n[dim]Validation report saved to: {output_report}[/dim]")
# Exit with error code if invalid
if not results['is_valid']:
sys.exit(1)
@split.command('apply')
@click.option('--csv', 'csv_file', type=str, required=True, help='Pool CSV file to split')
@click.option('--output', 'output_csv', type=str, required=True, help='Output CSV file path')
@click.option('--strategy', type=click.Choice(['random', 'ratio', 'leave-k-out', 'k-fold']), required=True,
help='Splitting strategy')
@click.option('--train-ratio', type=float, help='Train ratio (for random/ratio strategy)')
@click.option('--val-ratio', type=float, help='Validation ratio (for random/ratio strategy)')
@click.option('--test-ratio', type=float, help='Test ratio (for random/ratio strategy)')
@click.option('--k', type=int, help='Number of samples to leave out (for leave-k-out)')
@click.option('--n-folds', type=int, help='Number of folds (for k-fold)')
@click.option('--fold-idx', type=int, default=0, help='Fold index to use as test (for k-fold)')
@click.option('--from-split', type=str, default='pool', help='Source split (for leave-k-out)')
@click.option('--to-split', type=str, default='val', help='Target split (for leave-k-out)')
@click.option('--seed', type=int, default=42, help='Random seed')
@click.option('--dry-run', is_flag=True, help='Show what would be created without modifying files')
def split_apply(csv_file, output_csv, strategy, train_ratio, val_ratio, test_ratio,
k, n_folds, fold_idx, from_split, to_split, seed, dry_run):
"""Apply splitting strategy to convert pool CSV into train/val/test splits.
Examples:
# Random split with ratios
hae dataset split apply \\
--csv splits/mvtec-bottle-pool.csv \\
--output splits/mvtec-bottle-split.csv \\
--strategy random \\
--train-ratio 0.7 --val-ratio 0.15 --test-ratio 0.15 \\
--seed 42
# Leave-k-out
hae dataset split apply \\
--csv splits/mvtec-bottle-pool.csv \\
--output splits/mvtec-bottle-split.csv \\
--strategy leave-k-out \\
--k 20 \\
--from-split pool \\
--to-split val
"""
from pathlib import Path
import sys
# Import utilities
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src"))
try:
from dioodmi.data.split_utils import load_split_csv, save_split_csv
from dioodmi.data.split_strategies import (
apply_random_split, apply_leave_k_out_split, apply_k_fold_split
)
except ImportError as e:
console.print(f"[red]Error importing modules: {e}[/red]")
sys.exit(1)
# Load CSV
try:
df = load_split_csv(csv_file)
console.print(f"[dim]Loaded pool CSV: {len(df)} samples[/dim]")
except Exception as e:
console.print(f"[red]Error loading CSV: {e}[/red]")
sys.exit(1)
# Apply strategy
if strategy == 'random' or strategy == 'ratio':
if train_ratio is None or val_ratio is None or test_ratio is None:
console.print("[red]Error: --train-ratio, --val-ratio, and --test-ratio required for random/ratio strategy[/red]")
sys.exit(1)
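# Compare with a small tolerance rather than exact equality: 0.7 + 0.15 + 0.15
# is not exactly 1.0 in floating point.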
if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
console.print("[red]Error: Ratios must sum to 1.0[/red]")
sys.exit(1)
ratios = {'train': train_ratio, 'val': val_ratio, 'test': test_ratio}
result_df = apply_random_split(df, ratios, seed=seed, respect_existing=False)
elif strategy == 'leave-k-out':
if k is None:
console.print("[red]Error: --k required for leave-k-out strategy[/red]")
sys.exit(1)
result_df = apply_leave_k_out_split(
df, k=k, from_split=from_split, to_split=to_split,
seed=seed, respect_existing=False
)
elif strategy == 'k-fold':
if n_folds is None:
console.print("[red]Error: --n-folds required for k-fold strategy[/red]")
sys.exit(1)
result_df = apply_k_fold_split(
df, n_folds=n_folds, fold_idx=fold_idx, seed=seed,
respect_existing=False
)
# Show results
splits = result_df['split'].value_counts().to_dict()
console.print("\n[bold]Split distribution:[/bold]")
for split, count in sorted(splits.items()):
console.print(f" {split}: {count}")
if dry_run:
console.print("\n[bold]Dry run - no changes made[/bold]")
return
# Save
try:
save_split_csv(result_df, output_csv)
console.print(f"\n[green]✓ Created split CSV: {output_csv}[/green]")
console.print(f" Total samples: {len(result_df)}")
except Exception as e:
console.print(f"[red]Error saving CSV: {e}[/red]")
sys.exit(1)
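# A sketch of the leave-k-out semantics assumed above (the actual logic lives
# in dioodmi.data.split_strategies.apply_leave_k_out_split): pick k rows whose
# split equals from_split and reassign them to to_split, e.g. in pandas:
#
#   picked = df[df['split'] == from_split].sample(n=k, random_state=seed).index
#   df.loc[picked, 'split'] = to_split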
@dataset.command('show')
@click.argument('dataset_name')
@click.option('--object-category', type=str, required=True, help='Object category to show details for')
@click.option('--data-dir', type=str, help='Data directory path (default: ./datasets/raw/MVTec/mvtec_ad)')
@click.option('--csv-split-file', type=str, help='CSV split file path')
def dataset_show(dataset_name, object_category, data_dir, csv_split_file):
"""Show detailed information about a specific dataset and category.
Displays comprehensive statistics including:
- Dataset configuration
- Split breakdowns (train/val/test)
- CSV split file details
- Sample paths and counts
- Configuration parameters
Examples:
hae dataset show mvtec --object-category bottle
hae dataset show mvtec --object-category bottle --data-dir ./my-data
hae dataset show mvtec --object-category bottle --csv-split-file splits/custom.csv
"""
from pathlib import Path
import sys
import os
# Import dataset utilities
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src"))
try:
from dioodmi.data.dataset_configs import BASE_DATASET_REGISTRY, get_base_dataset_config
from dioodmi.data.anomaly_core import DatasetConfig, ProcessingParams, parse_csv_data
except ImportError as e:
console.print(f"[red]Error importing dataset modules: {e}[/red]")
console.print("[yellow]Make sure dioodmi-py is properly installed[/yellow]")
sys.exit(1)
# Validate dataset
if dataset_name not in BASE_DATASET_REGISTRY:
console.print(f"[red]Unknown dataset: {dataset_name}[/red]")
console.print(f"[yellow]Available datasets: {', '.join(BASE_DATASET_REGISTRY.keys())}[/yellow]")
sys.exit(1)
base_config = BASE_DATASET_REGISTRY[dataset_name]
# Validate object category
if object_category not in base_config.object_classes:
console.print(f"[red]Unknown category '{object_category}' for dataset '{dataset_name}'[/red]")
console.print(f"[yellow]Available categories: {', '.join(base_config.object_classes.keys())}[/yellow]")
sys.exit(1)
# Default data directory
if not data_dir:
data_dir = './datasets/raw/MVTec/mvtec_ad'
# Use provided CSV split file or default
csv_file = csv_split_file if csv_split_file else base_config.csv_split_file
# Display header
console.print(f"\n[bold cyan]Dataset: {dataset_name} / Category: {object_category}[/bold cyan]")
# Show configuration
console.print(f"\n[bold]Configuration:[/bold]")
console.print(f" Dataset Name: {base_config.name}")
console.print(f" Object Category: {object_category}")
console.print(f" Category Index: {base_config.object_classes[object_category]}")
console.print(f" Data Directory: {data_dir}")
console.print(f" CSV Split File: {csv_file}")
# Get original image dimensions from actual files
original_dims = None
try:
# Try to load a sample image to get dimensions
config = DatasetConfig(
name=base_config.name,
object_classes=base_config.object_classes,
csv_columns=base_config.csv_columns,
default_image_size=base_config.default_image_size,
default_crop_size=base_config.default_crop_size,
csv_split_file=csv_file,
strategies=base_config.strategies,
temporal_strategies=base_config.temporal_strategies,
mode='train',
object_class=object_category,
rootdir=data_dir,
anomaly_class='good'
)
params = ProcessingParams(
image_size=base_config.default_image_size,
crop_size=base_config.default_crop_size
)
data_df = parse_csv_data(config, params)
if len(data_df) > 0:
# Get image column name
image_col = None
for col in ['image', 'image_path', 'filepath', 'path']:
if col in data_df.columns:
image_col = col
break
if image_col:
# Try to load first image to get dimensions
first_image_path = data_df[image_col].iloc[0]
# Resolve path relative to data directory
if not os.path.isabs(first_image_path):
# Path in CSV is relative, try multiple locations
possible_paths = [
os.path.join(data_dir, first_image_path),
os.path.join(data_dir, base_config.name, first_image_path),
first_image_path # Try as-is
]
else:
possible_paths = [first_image_path]
for full_path in possible_paths:
full_path_abs = os.path.abspath(full_path)
if os.path.exists(full_path_abs):
try:
from PIL import Image
with Image.open(full_path_abs) as img:
original_dims = img.size # (width, height)
break
except Exception:
# Silently continue to next path
continue
except Exception:
# Silently fail - original dimensions are optional
pass
# Show image dimensions
console.print(f"\n[bold]Image Dimensions:[/bold]")
if original_dims:
console.print(f" Original Size: {original_dims[0]}×{original_dims[1]} (width×height)")
else:
console.print(f" Original Size: [dim]Unable to determine (check data directory)[/dim]")
console.print(f" Training Config:")
console.print(f" - Image Size: {base_config.default_image_size} (resize target)")
console.print(f" - Crop Size: {base_config.default_crop_size} (crop target)")
# Show split statistics
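# Counts below come from instantiating a DatasetConfig per split and parsing
# the CSV through the same path the data loaders use, so per-split filtering
# (e.g. anomaly_class='good' for train) is reflected in the numbers.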
console.print(f"\n[bold]Split Statistics:[/bold]")
split_stats = {}
for split in ['train', 'val', 'test']:
try:
# Create dataset config
config = DatasetConfig(
name=base_config.name,
object_classes=base_config.object_classes,
csv_columns=base_config.csv_columns,
default_image_size=base_config.default_image_size,
default_crop_size=base_config.default_crop_size,
csv_split_file=csv_file,
strategies=base_config.strategies,
temporal_strategies=base_config.temporal_strategies,
mode=split,
object_class=object_category,
rootdir=data_dir,
anomaly_class='good' if split == 'train' else 'all'
)
# Create processing params
params = ProcessingParams(
image_size=base_config.default_image_size,
crop_size=base_config.default_crop_size
)
# Load dataset metadata
data_df = parse_csv_data(config, params)
count = len(data_df)
split_stats[split] = {
'count': count,
'dataframe': data_df
}
# Show anomaly class breakdown for test split
if split == 'test' and 'category' in data_df.columns:
anomaly_counts = data_df['category'].value_counts().to_dict()
console.print(f" {split.capitalize()}: {count} images")
for anomaly_class, anomaly_count in sorted(anomaly_counts.items()):
console.print(f" - {anomaly_class}: {anomaly_count}")
else:
console.print(f" {split.capitalize()}: {count} images")
except Exception as e:
console.print(f" {split.capitalize()}: [red]Error - {str(e)}[/red]")
split_stats[split] = {'count': None, 'error': str(e)}
# Show sample paths (first few from train split)
if 'train' in split_stats and split_stats['train']['count'] and split_stats['train']['count'] > 0:
train_df = split_stats['train']['dataframe']
# Check for common column names for image paths
image_col = None
for col in ['image', 'image_path', 'filepath', 'path']:
if col in train_df.columns:
image_col = col
break
if image_col:
console.print(f"\n[bold]Sample Image Paths (train, first 5):[/bold]")
for i, path in enumerate(train_df[image_col].head(5), 1):
console.print(f" {i}. {path}")
if len(train_df) > 5:
console.print(f" ... and {len(train_df) - 5} more")
# Show total statistics
total_images = sum(s['count'] for s in split_stats.values() if s.get('count') is not None)
console.print(f"\n[bold]Total Images:[/bold] {total_images}")
# Show CSV file location
csv_path = Path(csv_file)
if not csv_path.is_absolute():
# Try relative to splits directory
splits_path = Path("splits") / csv_file
if splits_path.exists():
csv_path = splits_path
else:
csv_path = Path(csv_file)
if csv_path.exists():
console.print(f"\n[bold]CSV Split File:[/bold] {csv_path.absolute()}")
console.print(f"[dim]File exists: Yes[/dim]")
else:
console.print(f"\n[bold]CSV Split File:[/bold] {csv_file}")
console.print(f"[yellow]File exists: No (may use default dataset structure)[/yellow]")
@experiment.command('clean')
@click.argument('experiment_id')
@click.option('--all', is_flag=True, help='Clean all results (WARNING: destructive)')
def experiment_clean(experiment_id, all):
"""Clean temporary files or all results.
By default, cleans temporary files (__pycache__, *.tmp, .ipynb_checkpoints).
Use --all to delete all execution results (WARNING: destructive).
Examples:
hae experiment clean train-and-compare-evaluators
hae experiment clean train-and-compare-evaluators --all
"""
import shutil
# For orchestration format experiments, clean execution directories
executions_dir = Path("experiment-history")
if executions_dir.exists():
# Find all executions matching this experiment ID
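# Execution directories are assumed to be named '<timestamp>_<experiment-name>';
# match on the name part, treating '-', '_' and ' ' as equivalent.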
matching_execs = []
exp_id_norm = experiment_id.replace("-", "_").replace(" ", "_")
for exec_dir in executions_dir.rglob("*"):
if exec_dir.is_dir() and "_" in exec_dir.name:
parts = exec_dir.name.split("_", 1)
if len(parts) == 2:
exec_name = parts[1]
exec_name_norm = exec_name.replace("-", "_").replace(" ", "_")
if exec_name_norm == exp_id_norm or exec_name == experiment_id:
matching_execs.append(exec_dir)
if matching_execs:
if all:
if click.confirm(f"⚠️ Delete ALL {len(matching_execs)} execution(s) for {experiment_id}?", abort=True):
for exec_dir in matching_execs:
shutil.rmtree(exec_dir)
console.print(f"[green]✅ Deleted {len(matching_execs)} execution(s)[/green]")
return
else:
# Clean temporary files only
cleaned_count = 0
for exec_dir in matching_execs:
for pattern in ["**/__pycache__", "**/*.tmp", "**/.ipynb_checkpoints"]:
for path in exec_dir.glob(pattern):
if path.is_dir():
shutil.rmtree(path)
cleaned_count += 1
else:
path.unlink()
cleaned_count += 1
if cleaned_count > 0:
console.print(f"[green]✅ Cleaned {cleaned_count} temporary file(s) from {len(matching_execs)} execution(s)[/green]")
else:
console.print(f"[dim]No temporary files found in {len(matching_execs)} execution(s)[/dim]")
return
# Fallback: try alternative experiment directory structure
exp_dir = Path("experiments") / experiment_id
if exp_dir.exists():
if all:
if click.confirm(f"⚠️ Delete ALL results for {experiment_id}?", abort=True):
results_dir = exp_dir / "results"
if results_dir.exists():
shutil.rmtree(results_dir)
console.print("[green]✅ All results deleted[/green]")
else:
# Clean temporary files only
cleaned_count = 0
for pattern in ["**/__pycache__", "**/*.tmp", "**/.ipynb_checkpoints"]:
for path in exp_dir.glob(pattern):
if path.is_dir():
shutil.rmtree(path)
cleaned_count += 1
else:
path.unlink()
cleaned_count += 1
if cleaned_count > 0:
console.print(f"[green]✅ Cleaned {cleaned_count} temporary file(s)[/green]")
else:
console.print("[dim]No temporary files found[/dim]")
else:
console.print(f"[yellow]No experiment found: {experiment_id}[/yellow]")
console.print(f"[dim]Tried:[/dim]")
console.print(f" - Execution directories in experiment-history/")
console.print(f" - Alternative experiment directory: {exp_dir}")
@experiment.command('sync')
@click.option('--full', is_flag=True, help='Full rebuild (clears and re-indexes everything)')
@click.option('--verify', is_flag=True, help='Verify database integrity')
def experiment_sync(full, verify):
"""Sync experiment database with YAML files and execution history.
By default, performs smart incremental sync (only changed files).
Safe to run frequently (e.g., after git pull).
The database can always be rebuilt from YAML files (ground truth).
Feel free to delete .hae/ directory - sync will recreate it.
Examples:
hae experiment sync # Smart sync (fast, incremental)
hae experiment sync --full # Full rebuild (slower)
hae experiment sync --verify # Check database integrity
"""
from hae.db import ExperimentDB
from hae.migration import sync_database_incremental
db = ExperimentDB()
if verify:
console.print("[bold]Verifying database integrity...[/bold]")
db.verify()
db.close()
return
if full:
# Full rebuild - clear and re-index everything
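# Wiping indexed state is safe: the YAML files remain the ground truth and
# are re-read in full below.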
console.print("[bold]Full database rebuild...[/bold]")
db.cursor.execute("DELETE FROM experiment_definitions")
db.cursor.execute("DELETE FROM experiment_executions")
db.cursor.execute("DELETE FROM links")
db.conn.commit()
console.print(" [dim]Indexing definitions...[/dim]")
db.index_definitions()
console.print(" [dim]Indexing execution history...[/dim]")
db.index_executions()
def_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_definitions").fetchone()[0]
exec_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_executions").fetchone()[0]
console.print(f"[green]✓ Full rebuild complete ({def_count} definitions, {exec_count} executions)[/green]")
else:
# Smart incremental sync
console.print("[bold]Syncing experiment database...[/bold]")
# Check if database is empty (first run or deleted)
def_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_definitions").fetchone()[0]
if def_count == 0:
# Empty database - do full index
console.print(" [dim]Database empty - performing full index...[/dim]")
db.index_definitions()
db.index_executions()
def_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_definitions").fetchone()[0]
exec_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_executions").fetchone()[0]
console.print(f"[green]✓ Indexed {def_count} definitions, {exec_count} executions[/green]")
else:
# Incremental sync - only changed files
new, updated, unchanged = sync_database_incremental(db)
# Also sync executions (always incremental)
old_exec_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_executions").fetchone()[0]
db.index_executions()
new_exec_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_executions").fetchone()[0]
new_execs = new_exec_count - old_exec_count
console.print(f"[green]✓ Definitions:[/green] {new} new, {updated} updated, {unchanged} unchanged")
if new_execs > 0:
console.print(f"[green]✓ Executions:[/green] {new_execs} new")
else:
console.print(f"[dim] Executions: No new records[/dim]")
console.print(f"\n[dim]Database: {db.db_path}[/dim]")
db.close()
@experiment.command('doctor')
@click.option('--migration-only', is_flag=True, help='Check only migration status (quick)')
@click.option('--fix', is_flag=True, help='Auto-fix issues where possible')
@click.option('--verbose', is_flag=True, help='Show detailed diagnostics')
def experiment_doctor(migration_only, fix, verbose):
"""Comprehensive health check for experiment system.
Checks migration status, validates experiments, and verifies database.
This is the recommended command to run after pulling new changes.
Examples:
hae experiment doctor # Full health check
hae experiment doctor --migration-only # Quick migration status
hae experiment doctor --fix # Auto-repair issues
hae experiment doctor --verbose # Detailed diagnostics
"""
from hae.migration import (
check_migration_status,
find_old_path_references,
validate_yaml_syntax,
validate_dependencies,
validate_file_references,
fix_yaml_paths
)
from hae.db import ExperimentDB
from rich.table import Table
console.print("[bold cyan]🏥 DiOodMi Experiment Health Check[/bold cyan]\n")
# Track overall status
errors = 0
warnings = 0
# ========================================================================
# 1. Migration Status Check
# ========================================================================
console.print("[bold]Migration Status:[/bold]")
status = check_migration_status()
if status.old_definitions_exists:
console.print(f" [red]❌ OLD:[/red] experiments/definitions/ ({status.old_definitions_count} files)")
errors += 1
if status.old_history_exists:
console.print(f" [red]❌ OLD:[/red] experiments/history/ ({status.old_history_count} records)")
errors += 1
if status.new_definitions_exists:
console.print(f" [green]✅ NEW:[/green] definitions/ ({status.new_definitions_count} files)")
else:
console.print(f" [yellow]⚠️ NEW:[/yellow] definitions/ not found")
warnings += 1
if status.new_history_exists:
console.print(f" [dim] NEW: experiment-history/ ({status.new_history_count} records)[/dim]")
if status.needs_migration:
console.print(f" [cyan]💡 Action:[/cyan] Run 'hae experiment migrate'\n")
else:
console.print(f" [green]✓ Migration complete[/green]\n")
# If migration-only flag, stop here
if migration_only:
summary = f"Migration {'needed' if status.needs_migration else 'complete'}"
console.print(f"[dim]{summary}[/dim]")
sys.exit(1 if status.needs_migration else 0)
# ========================================================================
# 2. Experiment Definitions Validation
# ========================================================================
console.print("[bold]Experiment Definitions:[/bold]")
yaml_dir = Path("definitions/experiments")
if not yaml_dir.exists():
console.print(" [yellow]⚠️ definitions/experiments/ not found[/yellow]\n")
warnings += 1
else:
# YAML syntax validation
yaml_files = list(yaml_dir.glob("*.yaml"))
valid_count = 0
invalid_files = []
for yaml_file in yaml_files:
result = validate_yaml_syntax(yaml_file)
if result.valid:
valid_count += 1
else:
invalid_files.append((yaml_file.name, result.error))
if invalid_files:
console.print(f" [red]❌ YAML Syntax ({valid_count}/{len(yaml_files)} valid):[/red]")
for filename, error in invalid_files[:5]: # Show first 5
error_short = error[:60] + "..." if len(error) > 60 else error
console.print(f" • {filename}: {error_short}")
if len(invalid_files) > 5:
console.print(f" ... and {len(invalid_files) - 5} more")
errors += len(invalid_files)
else:
console.print(f" [green]✅ YAML Syntax ({len(yaml_files)}/{len(yaml_files)} valid)[/green]")
# Dependency validation
dep_issues = validate_dependencies(yaml_dir)
if dep_issues:
console.print(f" [red]❌ Dependencies ({len(dep_issues)} issues):[/red]")
for issue in dep_issues[:5]: # Show first 5
console.print(f" • {issue.experiment_id} → {issue.missing_dependency} ({issue.dependency_type} not found)")
if len(dep_issues) > 5:
console.print(f" ... and {len(dep_issues) - 5} more")
errors += len(dep_issues)
else:
console.print(f" [green]✅ Dependencies (no broken references)[/green]")
# File reference validation
file_issues = validate_file_references(yaml_dir)
if file_issues:
console.print(f" [yellow]⚠️ Missing Files ({len(file_issues)} warnings):[/yellow]")
for issue in file_issues[:5]: # Show first 5
console.print(f" • {issue.missing_file} (ref in {issue.experiment_id})")
if len(file_issues) > 5:
console.print(f" ... and {len(file_issues) - 5} more")
warnings += len(file_issues)
else:
console.print(f" [green]✅ File References (all files found)[/green]")
# Old path references
path_issues = find_old_path_references(yaml_dir)
if path_issues:
console.print(f" [yellow]⚠️ Old Path References ({len(path_issues)} found):[/yellow]")
for issue in path_issues[:3]: # Show first 3
console.print(f" • {issue.file.name}:{issue.line} ({issue.old_path})")
if len(path_issues) > 3:
console.print(f" ... and {len(path_issues) - 3} more")
warnings += len(path_issues)
if fix:
console.print(f" [cyan]🔧 Fixing path references...[/cyan]")
fixed_count, actions = fix_yaml_paths(yaml_dir, dry_run=False)
console.print(f" [green]✓ Fixed {fixed_count} files[/green]")
warnings -= len(path_issues) # Remove warnings since we fixed them
console.print()
# ========================================================================
# 3. Database Status
# ========================================================================
console.print("[bold]Database:[/bold]")
db_path = Path(".hae/experiments.db")
if not db_path.exists():
console.print(" [yellow]⚠️ .hae/experiments.db not found[/yellow]")
console.print(" [cyan]💡 Action:[/cyan] Run 'hae experiment sync'\n")
warnings += 1
else:
try:
db = ExperimentDB()
def_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_definitions").fetchone()[0]
exec_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_executions").fetchone()[0]
if def_count == 0 and status.new_definitions_exists:
console.print(f" [yellow]⚠️ Database empty (0 definitions)[/yellow]")
console.print(" [cyan]💡 Action:[/cyan] Run 'hae experiment sync'\n")
warnings += 1
else:
console.print(f" [green]✅ Database OK ({def_count} definitions, {exec_count} executions)[/green]")
# Integrity check
if verbose:
result = db.cursor.execute("PRAGMA integrity_check").fetchone()
if result[0] == 'ok':
console.print(" [green]✅ Integrity check passed[/green]")
else:
console.print(f" [red]❌ Integrity check failed: {result[0]}[/red]")
errors += 1
db.close()
console.print()
except Exception as e:
console.print(f" [red]❌ Database error: {e}[/red]\n")
errors += 1
# ========================================================================
# 4. Summary
# ========================================================================
console.print("[bold]Summary:[/bold]")
if errors == 0 and warnings == 0:
console.print("[green]✅ All healthy - no issues found![/green]")
sys.exit(0)
elif errors == 0:
console.print(f"[yellow]⚠️ {warnings} warning(s) found[/yellow]")
sys.exit(0)
else:
console.print(f"[red]❌ {errors} error(s), {warnings} warning(s) found[/red]")
if not fix:
console.print("[dim]Run with --fix to attempt automatic repairs[/dim]")
sys.exit(1)
@experiment.command('migrate')
@click.option('--dry-run', is_flag=True, default=True, help='Preview changes without applying (default)')
@click.option('--apply', is_flag=True, help='Apply migration (bypasses dry-run)')
@click.option('--no-backup', is_flag=True, help='Skip backup creation (not recommended)')
@click.option('--force', is_flag=True, help='Skip confirmation prompts')
def experiment_migrate(dry_run, apply, no_backup, force):
"""Migrate experiment structure from old to new format.
This command safely migrates your experiments from the old directory
structure (experiments/definitions/) to the new structure (definitions/).
By default, shows a dry-run preview and asks for confirmation.
Examples:
hae experiment migrate # Preview migration (dry-run)
hae experiment migrate --apply # Apply migration with confirmation
hae experiment migrate --apply --force # Apply without confirmation
"""
from hae.migration import (
check_migration_status,
create_backup,
migrate_directories,
fix_yaml_paths
)
from hae.db import ExperimentDB
console.print("[bold cyan]🔄 Experiment Structure Migration[/bold cyan]\n")
# Check if migration is needed
status = check_migration_status()
if not status.needs_migration:
console.print("[green]✓ Already migrated - no action needed![/green]")
console.print("[dim]Current structure:[/dim]")
console.print(f" • definitions/ ({status.new_definitions_count} files)")
console.print(f" • experiment-history/ ({status.new_history_count} records)")
return
# Determine if we're actually applying or just showing dry-run
actually_apply = apply and not dry_run
if actually_apply:
console.print("[bold]Migration Plan:[/bold]\n")
else:
console.print("[bold][DRY RUN] Preview of changes:[/bold]\n")
# ========================================================================
# 1. Backup
# ========================================================================
if not no_backup:
if actually_apply:
console.print("1. [bold]Backup:[/bold]")
try:
backup_dir = create_backup()
console.print(f" [green]✓ Created backup: {backup_dir}/[/green]\n")
except Exception as e:
console.print(f" [red]❌ Backup failed: {e}[/red]")
console.print("[yellow]Aborting migration for safety[/yellow]")
sys.exit(1)
else:
from datetime import datetime  # Local import, matching this module's lazy-import pattern elsewhere
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
console.print(f"1. [bold]Backup:[/bold]")
console.print(f" ✓ Would create .migration-backup-{timestamp}/\n")
else:
console.print("1. [bold]Backup:[/bold] Skipped (--no-backup)\n")
# ========================================================================
# 2. Move Directories
# ========================================================================
console.print("2. [bold]Move Directories:[/bold]")
success, actions = migrate_directories(dry_run=(not actually_apply))
for action in actions:
if action.startswith("Would"):
console.print(f" {action}")
elif action.startswith("Moved") or action.startswith("Removed"):
console.print(f" [green]✓ {action}[/green]")
elif action.startswith("Error"):
console.print(f" [red]❌ {action}[/red]")
if not success:
console.print("\n[red]Migration failed - check errors above[/red]")
sys.exit(1)
console.print()
# ========================================================================
# 3. Fix Path References
# ========================================================================
console.print("3. [bold]Update Path References:[/bold]")
yaml_dir = Path("definitions/experiments")
if yaml_dir.exists():
fixed_count, fix_actions = fix_yaml_paths(yaml_dir, dry_run=(not actually_apply))
if fixed_count > 0:
for action in fix_actions[:5]: # Show first 5
if action.startswith("Would"):
console.print(f" {action}")
elif action.startswith("Fixed"):
console.print(f" [green]✓ {action}[/green]")
elif action.startswith("Error"):
console.print(f" [red]❌ {action}[/red]")
if len(fix_actions) > 5:
console.print(f" ... and {len(fix_actions) - 5} more")
else:
console.print(" [dim]No path references to fix[/dim]")
else:
console.print(" [yellow]⚠️ definitions/experiments/ not found (will be created)[/yellow]")
console.print()
# ========================================================================
# 4. Initialize Database
# ========================================================================
console.print("4. [bold]Initialize Database:[/bold]")
if actually_apply:
console.print(" [green]✓ Run 'hae experiment sync' after migration to build database[/green]")
else:
console.print(" ✓ Would run 'hae experiment sync' to build database")
console.print()
# ========================================================================
# Confirmation & Apply
# ========================================================================
if not actually_apply:
console.print("[yellow]This was a DRY RUN - no changes were made[/yellow]\n")
if not force:
            if click.confirm("Apply migration?", default=False):
console.print("\n[cyan]Applying migration...[/cyan]\n")
# Recursively call with --apply --force
import subprocess
result = subprocess.run(
["hae", "experiment", "migrate", "--apply", "--force"] +
(["--no-backup"] if no_backup else []),
cwd=Path.cwd()
)
sys.exit(result.returncode)
else:
console.print("[dim]Migration cancelled - run 'hae experiment migrate --apply' when ready[/dim]")
sys.exit(0)
else:
console.print("[dim]Run with --apply to execute migration[/dim]")
sys.exit(0)
else:
console.print("[green]✅ Migration complete![/green]\n")
console.print("[cyan]Verifying...[/cyan]")
# Run doctor to verify
import subprocess
subprocess.run(["hae", "experiment", "doctor", "--migration-only"], cwd=Path.cwd())
@main.group()
def compare():
"""Compare files from different runs or implementations.
Supports comparison of images, JSON metrics, NPY arrays, time-series, and text files.
"""
pass
@compare.command('images')
@click.argument('files', nargs=-1, type=click.Path(exists=True))
@click.option('--pattern', multiple=True, help='Glob pattern to find images (can be used multiple times)')
@click.option('--dirs', multiple=True, help='Directories to search for images (can be used multiple times)')
@click.option('--match-by', type=click.Choice(['name', 'index', 'metadata']), default='name',
help='How to match files across directories (default: name)')
@click.option('--plot', type=click.Choice(['grid', 'diff', 'before-after', 'statistics', 'multi-panel']),
default='grid', help='Plot type (default: grid)')
@click.option('--grid-cols', type=int, help='Number of columns for grid plot')
@click.option('--labels', multiple=True, help='Labels for images (use --labels for each label, or comma-separated)')
@click.option('--title', help='Plot title')
@click.option('--output', type=click.Path(), help='Output file path (for single comparison)')
@click.option('--output-dir', type=click.Path(), help='Output directory (for batch comparisons)')
@click.option('--output-format', type=click.Choice(['png', 'pdf', 'svg', 'html']), default='png',
help='Output format (default: png)')
@click.option('--dpi', type=int, default=100, help='DPI for output (default: 100)')
@click.option('--diff', is_flag=True, help='Show pixel differences')
@click.option('--metrics', multiple=True, help='Metrics to compute (mse, ssim, mean_diff, max_diff)')
@click.option('--verbose', '-v', is_flag=True, help='Verbose output')
@click.option('--dry-run', is_flag=True, help='Show what would be done without executing')
def compare_images(files, pattern, dirs, match_by, plot, grid_cols, labels, title, output,
output_dir, output_format, dpi, diff, metrics, verbose, dry_run):
"""Compare image files.
Can compare specific files, files matching patterns, or files in directories.
Examples:
hae compare images img1.png img2.png img3.png --plot grid --labels "A" "B" "C"
hae compare images --pattern "**/anomaly_maps/*.png" --plot grid --output-dir ./comparisons
hae compare images --dirs dir1 dir2 dir3 --match-by name --plot grid --labels "process" "evaluator" "ad"
"""
from .commands.compare import compare_images_cli
# Convert click types to lists
patterns = list(pattern) if pattern else []
dirs_list = list(dirs) if dirs else None
# Handle labels - support both comma-separated and multiple flags
labels_list = []
if labels:
for label in labels:
# Support comma-separated labels: --labels "a,b,c"
if ',' in label:
labels_list.extend([l.strip() for l in label.split(',')])
else:
labels_list.append(label)
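    # e.g. --labels "proc,eval" --labels "ad" -> labels_list == ['proc', 'eval', 'ad']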
metrics_list = list(metrics) if metrics else None
# If files are provided directly, add them as patterns
if files:
patterns.extend([str(f) for f in files])
try:
compare_images_cli(
patterns=patterns,
dirs=dirs_list,
match_by=match_by,
plot=plot,
grid_cols=grid_cols,
labels=labels_list,
title=title,
output=output,
output_dir=output_dir,
output_format=output_format,
dpi=dpi,
diff=diff,
metrics=metrics_list,
verbose=verbose,
dry_run=dry_run
)
console.print("[green]✓ Comparison complete[/green]")
except Exception as e:
console.print(f"[red]✗ Error: {e}[/red]")
if verbose:
import traceback
traceback.print_exc()
sys.exit(1)
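# Sketch of the simple pixel metrics accepted by --metrics (assumed
# definitions; the real implementations live in .commands.compare):
#
#   import numpy as np
#   a = np.asarray(img_a, dtype=np.float64)  # img_a, img_b: same-shape images
#   b = np.asarray(img_b, dtype=np.float64)
#   mse = float(np.mean((a - b) ** 2))
#   mean_diff = float(np.mean(np.abs(a - b)))
#   max_diff = float(np.max(np.abs(a - b)))
#   # ssim would come from skimage.metrics.structural_similarity(a, b, data_range=...)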
@compare.command('json')
@click.argument('files', nargs=-1, type=click.Path(exists=True))
@click.option('--pattern', multiple=True, help='Glob pattern to find JSON files')
@click.option('--dirs', multiple=True, help='Directories to search for JSON files')
@click.option('--extract', multiple=True, help='Extract specific values (e.g., "evaluation.auroc")')
@click.option('--extract-all', is_flag=True, help='Extract all metrics')
@click.option('--diff', is_flag=True, help='Show differences')
@click.option('--table', is_flag=True, help='Show as table')
@click.option('--compute-deltas', is_flag=True, help='Compute differences/deltas')
@click.option('--output', type=click.Path(), help='Output file path')
@click.option('--output-format', type=click.Choice(['csv', 'json', 'markdown', 'table']), default='table',
help='Output format (default: table)')
@click.option('--verbose', '-v', is_flag=True, help='Verbose output')
@click.option('--dry-run', is_flag=True, help='Show what would be done without executing')
def compare_json(files, pattern, dirs, extract, extract_all, diff, table, compute_deltas,
output, output_format, verbose, dry_run):
"""Compare JSON files (metrics, results, etc.).
Examples:
hae compare json metrics1.json metrics2.json --extract evaluation.auroc --table
hae compare json --pattern "**/final_metrics.json" --extract-all --output comparison.csv
"""
console.print("[yellow]JSON comparison not yet implemented[/yellow]")
console.print("[dim]This feature is planned but not yet available[/dim]")
@compare.command('npy')
@click.argument('files', nargs=-1, type=click.Path(exists=True))
@click.option('--pattern', multiple=True, help='Glob pattern to find NPY files')
@click.option('--dirs', multiple=True, help='Directories to search for NPY files')
@click.option('--metrics', multiple=True, help='Metrics to compute (mse, mae, correlation)')
@click.option('--output', type=click.Path(), help='Output file path')
@click.option('--verbose', '-v', is_flag=True, help='Verbose output')
@click.option('--dry-run', is_flag=True, help='Show what would be done without executing')
def compare_npy(files, pattern, dirs, metrics, output, verbose, dry_run):
"""Compare NPY array files.
Examples:
hae compare npy array1.npy array2.npy --metrics mse,mae
hae compare npy --pattern "**/*.npy" --metrics mse --output comparison.txt
"""
console.print("[yellow]NPY comparison not yet implemented[/yellow]")
console.print("[dim]This feature is planned but not yet available[/dim]")
@compare.command('time-series')
@click.argument('files', nargs=-1, type=click.Path(exists=True))
@click.option('--pattern', multiple=True, help='Glob pattern to find time-series files')
@click.option('--metric', help='Metric to extract (e.g., "training.loss")')
@click.option('--x-axis', help='X-axis field (e.g., "epoch")')
@click.option('--plot-type', type=click.Choice(['overlaid', 'separate', 'diff']), default='overlaid',
help='Plot type (default: overlaid)')
@click.option('--labels', multiple=True, help='Labels for series')
@click.option('--output', type=click.Path(), help='Output file path')
@click.option('--verbose', '-v', is_flag=True, help='Verbose output')
@click.option('--dry-run', is_flag=True, help='Show what would be done without executing')
def compare_time_series(files, pattern, metric, x_axis, plot_type, labels, output, verbose, dry_run):
"""Compare time-series data.
Examples:
hae compare time-series run1/metrics.json run2/metrics.json --metric training.loss --x-axis epoch
"""
console.print("[yellow]Time-series comparison not yet implemented[/yellow]")
console.print("[dim]This feature is planned but not yet available[/dim]")
@compare.command('text')
@click.argument('files', nargs=-1, type=click.Path(exists=True))
@click.option('--pattern', multiple=True, help='Glob pattern to find text files')
@click.option('--diff', is_flag=True, help='Line-by-line diff')
@click.option('--extract-patterns', multiple=True, help='Extract values using regex patterns')
@click.option('--summary', is_flag=True, help='Summary comparison')
@click.option('--output', type=click.Path(), help='Output file path')
@click.option('--verbose', '-v', is_flag=True, help='Verbose output')
@click.option('--dry-run', is_flag=True, help='Show what would be done without executing')
def compare_text(files, pattern, diff, extract_patterns, summary, output, verbose, dry_run):
"""Compare text/log files.
Examples:
hae compare text log1.txt log2.txt --diff --output diff.txt
hae compare text --pattern "**/*.log" --extract-patterns "loss=([0-9.]+)" --summary
"""
console.print("[yellow]Text comparison not yet implemented[/yellow]")
console.print("[dim]This feature is planned but not yet available[/dim]")
if __name__ == '__main__':
main()
name: iris-eval-ad-detect-and-evaluate-256
description: |
  Complete pipeline that evaluates a DeCo-Diff model with AnomalyDetector on the Iris dataset.
sweep:
type: grid
params:
sweep_dataset: [ "80_Real_merged_head", "280_False_merged_defect_modulo_obvious_defect", "280_False_merged_obvious_defect", "280_False_merged_normal_with_perfect_plain_10_with_split_only_test", "280_False_merged_normal_with_perfect_plain_10_with_split_only_train"]
variables:
# Dataset configuration
dataset: iris
data_dir: D:/HDS_Mask/SL_DL_Ori/save/EUV
object_category: all
model_size: UNet_L
image_size: 256
center_crop: 256
model_path: D:/HDS/DiOodMi___250917/results/decodiff_iris_basic_augmentation_with_split/checkpoints/epoch-857.pt
csv_split_file: "../splits/iris-split_{sweep_dataset}.csv" # Custom CSV split file
eval_output_dir_ad: "{EXECUTION_DIR}/{sweep_dataset}"
  save_plot_path: "{EXECUTION_DIR}/{sweep_dataset}/save_plot_path"
# Evaluation configuration (deterministic settings required!)
crop_size: 256
patch_size: 256 # For eval-process (processor uses patch_size)
batch_size: 10
pad_px: 2
stride: 1
directions: h v
tta_batch_size: 2
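  # NOTE (assumed semantics, from the detector's TTA design): pad_px / stride /
  # directions enumerate the test-time pixel shifts, so pad_px: 2 with stride: 1
  # along h and v would yield offsets of +/-1 and +/-2 per axis plus the
  # identity view; tta_batch_size then chunks the batch_size x num_shifts views
  # so the combined TTA batch stays within GPU memory.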
# Evaluation processing parameters
reverse_steps: 5 # Number of reverse diffusion steps
num_workers: 0 # Number of DataLoader workers (0 for Windows compatibility)
vae_type: ema # VAE type (ema or mse)
shift_method: mirror # TTA shift augmentation method
fuse_method: pct25 # Method to fuse multiple shifted predictions
img_enlarge_px: 4 # Image enlargement for interpolation method
# Processor-specific parameters
annotation_dir: ./annotations # Directory containing annotation JSON files (REQUIRED for eval-process)
split: test # Data split to process (for eval-process)
# Reconstruction saving options (optional)
save_reconstructions_flag: "" # Flag for saving VAE and Deco-Diff reconstructions as PNG
save_reconstructions_npy_flag: "" # Flag for saving raw reconstruction arrays as NPY files
save_raw_anomaly_maps_flag: "" # Flag for saving raw anomaly map arrays as NPY files
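  # The *_flag variables hold the literal CLI flag text; an empty string leaves
  # the option disabled. Hypothetical example (flag spelling assumed from the
  # variable name):
  #   save_reconstructions_flag: "--save-reconstructions"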
# Visualization variant options (optional)
save_image_variants_flag: "" # Variants to save: continuous binary grayscale absolute
save_colormaps_flag: "" # Colormaps: jet hot viridis plasma magma inferno turbo
save_normalization: minmax # Normalization method: minmax zscore percentile raw
normal_center: 125.0 # Center value for grayscale normalization
save_binary_thresholds_flag: "" # Thresholds: e.g., "--save-binary-thresholds 5.0 10.0 15.0"
# Dataset sampling options (optional - set to null to use full dataset)
bootstrap_samples: null
  bootstrap_seed: 42
  first_n: null
last_n: null
range_start: null
range_count: 10
random_n: null
random_n_with_replacement: null
random_seed: 42
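  # Example (semantics assumed from the names): to evaluate a fixed slice, set
  # range_start: 0 and keep range_count: 10; range_count should have no effect
  # while range_start is null. The bootstrap_* / random_* options presumably
  # select alternative sampling modes and are not meant to be combined.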
# Backend settings for determinism
torch_deterministic: true
cudnn_benchmark: false
allow_tf32: false
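  # These flags are presumably wired to the standard PyTorch determinism APIs,
  # roughly (sketch; the exact plumbing lives in the evaluator code):
  #   torch.use_deterministic_algorithms(True)        # torch_deterministic
  #   torch.backends.cudnn.benchmark = False          # cudnn_benchmark
  #   torch.backends.cuda.matmul.allow_tf32 = False   # allow_tf32
  #   torch.backends.cudnn.allow_tf32 = False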
experiments:
  # ===================================================================
  # Detect and Evaluate (saves TTA fundamentals to the NPY cache, then
  # evaluates in the same run)
  # ===================================================================
  - name: detect-and-evaluate
    description: "Detection and evaluation in a single combined step"
pipeline: ../pipelines/eval-ad.yaml
variables:
mode: "detect_and_evaluate"
model_path: "{model_path}"
dataset: "{dataset}"
data_dir: "{data_dir}"
object_category: "{object_category}"
csv_split_file: "{csv_split_file}"
model_size: "{model_size}"
crop_size: "{crop_size}"
batch_size: "{batch_size}"
pad_px: "{pad_px}"
tta_batch_size: "{tta_batch_size}"
stride: "{stride}"
directions: "{directions}"
shift_method: "{shift_method}"
fuse_method: "{fuse_method}"
img_enlarge_px: "{img_enlarge_px}"
reverse_steps: "{reverse_steps}"
num_workers: "{num_workers}"
vae_type: "{vae_type}"
output_dir: "{eval_output_dir_ad}"
# Reconstruction saving options
save_reconstructions_flag: "{save_reconstructions_flag}"
save_reconstructions_npy_flag: "{save_reconstructions_npy_flag}"
save_raw_anomaly_maps_flag: "{save_raw_anomaly_maps_flag}"
# Visualization options
save_image_variants_flag: "{save_image_variants_flag}"
save_colormaps_flag: "{save_colormaps_flag}"
save_normalization: "{save_normalization}"
normal_center: "{normal_center}"
save_binary_thresholds_flag: "{save_binary_thresholds_flag}"
# Evaluation sampling options (set to null to use full dataset)
bootstrap_samples: "{bootstrap_samples}"
bootstrap_seed: "{bootstrap_seed}"
first_n: "{eval_first_n}"
last_n: "{last_n}"
range_start: "{range_start}"
range_count: "{range_count}"
random_n: "{random_n}"
random_n_with_replacement: "{random_n_with_replacement}"
random_seed: "{random_seed}"
backend:
type: local
  parallel: false # Run experiments sequentially
  continue_on_error: false # Stop the pipeline if a step fails